"""Fetch content from upload. | |
org ezbee_page.py. | |
""" | |
# pylint: disable=invalid-name | |
# pylint: disable=too-many-locals, too-many-return-statements, too-many-branches, too-many-statements, abstract-class-instantiated | |
import base64
import inspect
import io
import platform
from pathlib import Path
from functools import partial
from itertools import zip_longest

# import hanzidentifier
import logzero
import numpy as np
import pandas as pd
import pendulum
import streamlit as st
from about_time import about_time

# from ezbee.gen_pairs import gen_pairs  # aset2pairs?
from aset2pairs import aset2pairs
from icecream import ic
from loguru import logger as loggu
from logzero import logger
from set_loglevel import set_loglevel
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode

# from st_aggrid.grid_options_builder import GridOptionsBuilder
from streamlit import session_state as state

# from st_mlbee.t2s import t2s
from st_mlbee import st_mlbee
from st_mlbee.color_map import color_map
from st_mlbee.fetch_paste import fetch_paste
from st_mlbee.fetch_upload import fetch_upload
from st_mlbee.fetch_urls import fetch_urls

# from seg_text import seg_text
from st_mlbee.split_text import seg_text


def home():  # noqa
    """Run tasks.

    beetype
    sourcetype
    fetch_upload/fetch_paste, fetch_urls
    sourcecount
    align: para-align/sent-align
    save xlsx/tsv
    """
    if state.ns.sourcetype not in ["upload", "paste", "urls"]:
        st.write("Coming soooooooon...")
        return None

    # if state.ns.beetype not in ["ezbee", "dzbee", "debee"]:
    if state.ns.beetype not in ["mlbee"]:
        st.write("Coming soon...")
        return None

    # process sourcetype and fetch list1/list2
    list1 = list2 = []

    # fetch_upload/fetch_paste/fetch_urls
    if state.ns.sourcetype in ["upload"]:
        fetch_upload()
    elif state.ns.sourcetype in ["paste"]:
        fetch_paste()
    elif state.ns.sourcetype in ["urls"]:
        fetch_urls()
    else:
        st.warning(f"{state.ns.sourcetype}: Not implemented")
        return None

    # state.ns.list1 and state.ns.list2 are defined in the fetch_x above
    if state.ns.sentali:  # split to sents
        try:
            state.ns.list1 = seg_text(state.ns.list1)
        except Exception as exc:
            logger.exception(exc)
            raise
        try:
            state.ns.list2 = seg_text(state.ns.list2)
        except Exception as exc:
            logger.exception(exc)
            raise

    logger.debug("state.ns.updated: %s", state.ns.updated)

    # if not updated, quit. This does not quite work:
    # it only prevents the first run/missing upload.
    if not state.ns.updated:
        logger.debug(" not updated, early exit.")
        return None

    list1 = state.ns.list1[:]
    list2 = state.ns.list2[:]
    logger.debug("list1[:3]: %s", list1[:3])
    logger.debug("list2[:3]: %s", list2[:3])
    df = pd.DataFrame(zip_longest(list1, list2, fillvalue=""))
    try:
        # df.columns = ["text1", "text2"]
        df.columns = [f"text{i + 1}" for i in range(len(df.columns))]
    except Exception as exc:
        logger.debug("df: \n%s", df)
        logger.error("%s", exc)

    state.ns.df = df
    logger.debug("df: %s", df)

    # st.table(df)  # looks alright
    # equiv to st.markdown(df.to_markdown())?
    # styled pd dataframe?
    # bigger, no pagination
    # st.markdown(df.to_html(), unsafe_allow_html=True)

    # ag_grid: smallish, editable, probably slower
    # if "df" not in globals() or "df" not in locals():
    if "df" not in locals():
        logger.debug(" df not defined, return")
        return None
    if df.empty:
        logger.debug(" df.empty, return")
        return None
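    # Rough timing heuristic for the ETA message below: 0.4 s to 1 s per
    # non-empty block (0.66 s on average) on the default host, about 12x
    # faster on known beefier hosts, and another 1.4x faster when aligning
    # at sentence level.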
    # print estimated completion time
    len1 = len([elm.strip() for elm in list1 if elm.strip()])
    len2 = len([elm.strip() for elm in list2 if elm.strip()])
    len12 = len1 + len2

    time_min = 0.4
    time_max = 1
    time_av = 0.66
    uname = platform.uname()
    if "amz2" in uname.release or "forindo" in uname.node:
        time_min /= 12
        time_max /= 12
        time_av /= 12

    # reduce for sent align
    if state.ns.sentali:
        time_min /= 1.4
        time_max /= 1.4
        time_av /= 1.4

    # time0 = len12 * 0.4
    # time1 = len12 * 1
    # eta = pendulum.now() + pendulum.duration(seconds=len12 * 0.66)
    time0 = len12 * time_min
    time1 = len12 * time_max
    eta = pendulum.now() + pendulum.duration(seconds=len12 * time_av)

    in_words0 = pendulum.duration(seconds=time0).in_words()
    in_words1 = pendulum.duration(seconds=time1).in_words()
    diff_for_humans = eta.diff_for_humans()
    dt_str = eta.to_datetime_string()
    timezone_name = eta.timezone_name

    running_in = uname.node
    # streamlit.io mounts the app at /mount/src/mlbee and sets uname.node to 'localhost'
    if Path("/mount/src/mlbee").exists() and running_in.startswith("local"):
        running_in = "share.streamlit.io"
    _ = (
        f"running in {running_in} -- "
        f"processing {len1} + {len2} = {len12} blocks; "
        f"estimated time to complete: {in_words0} to {in_words1}; "
        f"eta: {diff_for_humans} ({dt_str} {timezone_name}) "
    )
    eta_msg = _
    # st.info(_)

    # only show this for upload
    if state.ns.sourcetype in ["upload"]:
        with st.expander("to be aligned", expanded=False):
            st.write(df)

    logger.info("Processing data... %s", state.ns.beetype)

    # if state.ns.beetype in ["ezbee", "dzbee", "debee"]:
    if state.ns.beetype in ["mlbee"]:
        with about_time() as t:
            # diggin...
            with st.spinner(f"{eta_msg}"):
                try:
                    # aset = globals()[state.ns.beetype](
                    aset = st_mlbee(
                        list1,
                        list2,
                        # eps=eps,
                        # min_samples=min_samples,
                    )
                except Exception as e:
                    logger.exception("aset = st_mlbee(...) exc: %s", e)
                    aset = ""
                    st.write("Collecting inputs...")
                    logger.debug("Collecting inputs...")
                    return None
        st.success(f"Done, took {t.duration_human}")
    else:
        try:
            filename = inspect.currentframe().f_code.co_filename  # type: ignore
        except Exception as e:
            logger.error(e)
            filename = ""
        try:
            lineno = inspect.currentframe().f_lineno  # type: ignore
        except Exception as e:
            logger.error(e)
            lineno = ""
        st.write(f"{state.ns.beetype} coming soon...{filename}:{lineno}")
        return None
    if aset:
        logger.debug("aset: %s...%s", aset[:3], aset[-3:])
        # logger.debug("aset[:10]: %s", aset[:10])
    if set_loglevel() <= 10:
        st.write(aset)

    # aligned_pairs = gen_pairs(list1, list2, aset)
    aligned_pairs = aset2pairs(list1, list2, aset)
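    # aset2pairs maps the alignment set (index pairs plus a likelihood score)
    # back onto the original texts, yielding (text1, text2, llh) triples.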
    if aligned_pairs:
        # logger.debug("%s...%s", aligned_pairs[:1], aligned_pairs[-1:])
        logger.debug("%s...", aligned_pairs[:1])

    df_a = pd.DataFrame(
        aligned_pairs, columns=["text1", "text2", "llh"], dtype="object"
    )
    if set_loglevel() <= 10:
        with st.expander("done aligned"):
            st.table(df_a.astype(str))
            # st.markdown(df_a.astype(str).to_markdown())
            # st.markdown(df_a.astype(str).to_numpy().tolist())

    # insert seq no
    df_a.insert(0, "sn", range(len(df_a)))
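    # Configure AgGrid: auto-sized pagination, and every column resizable,
    # height-adjusting, text-wrapping and editable so alignments can be
    # touched up in place.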
    gb = GridOptionsBuilder.from_dataframe(df_a)
    gb.configure_pagination(paginationAutoPageSize=True)
    options = {
        "resizable": True,
        "autoHeight": True,
        "wrapText": True,
        "editable": True,
    }
    gb.configure_default_column(**options)
    gridOptions = gb.build()

    # st.write("editable aligned (double-click a cell to edit, drag column header to adjust widths)")
    _ = "editable aligned (double-click a cell to edit, drag column header to adjust widths)"
    with st.expander(_, expanded=False):
        ag_df = AgGrid(
            # df,
            df_a,
            gridOptions=gridOptions,
            key="outside",
            reload_data=True,
            editable=True,
            # width="100%",  # width parameter is deprecated
            height=750,
            # fit_columns_on_grid_load=True,
            update_mode=GridUpdateMode.MODEL_CHANGED,
        )

    # pop("sn"): remove the sn column before saving
    df_a.pop("sn")
    # ### prep download ### #
    # taken from vizbee cb_save_xlsx
    # subset = list(df_a.columns[2:3])  # 3rd col
    subset = list(df_a.columns[2:])  # the llh column(s)
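    # color_map (from st_mlbee.color_map) presumably colors each cell by its
    # llh value so low-likelihood pairs stand out in the styled/exported sheet.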
    s_df = df_a.astype(str).style.applymap(color_map, subset=subset)
    if set_loglevel() <= 10:
        logger.debug(" showing styled aligned")
        with st.expander("styled aligned"):
            # st.dataframe(s_df)  # can't handle a styled df
            st.table(s_df)

    output = io.BytesIO()
    with pd.ExcelWriter(
        output, engine="xlsxwriter"
    ) as writer:  # pylint: disable=abstract-class-instantiated
        s_df.to_excel(writer, index=False, header=False, sheet_name="Sheet1")
        writer.sheets["Sheet1"].set_column("A:A", 70)
        writer.sheets["Sheet1"].set_column("B:B", 70)
    output.seek(0)
    val = output.getvalue()
    b64 = base64.b64encode(val)
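    # Embed the xlsx bytes as a base64 data: URI so st.markdown (with
    # unsafe_allow_html=True below) can render a clickable download link.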
filename = "" | |
if state.ns.src_filename: | |
filename = f"{state.ns.src_filename}-" | |
if state.ns.sentali: | |
extra = "aligned_sents" | |
else: | |
extra = "aligned_paras" | |
dl_xlsx = f'<a href="data:application/octet-stream;base64,{b64.decode()}" download="{filename}-{extra}.xlsx">Download aligned paras xlsx</a>' | |
_ = """ | |
output = io.BytesIO() | |
# df_a.astype(str).to_csv(output, sep="\t", index=False, header=False, encoding="gbk") | |
df_a.astype(object).to_csv(output, sep="\t", index=False, header=False, encoding="gbk") | |
output.seek(0) | |
val = output.getvalue() | |
b64 = base64.b64encode(val) | |
dl_tsv = f'<a href="data:application/octet-stream;base64,{b64.decode()}" download="{filename}aligned_paras.tsv">Download aligned paras tsv</a>' | |
# """ | |
col1_dl, col2_dl = st.columns(2) | |
with col1_dl: | |
st.markdown(dl_xlsx, unsafe_allow_html=True) | |
_ = """ | |
with col2_dl: | |
st.markdown(dl_tsv, unsafe_allow_html=True) | |
# """ | |
# reset | |
state.ns.updated = False | |
return None | |