Spaces:
Sleeping
Sleeping
File size: 2,492 Bytes
2c2081e b9d6157 2c2081e 76b2171 34815a7 2c2081e 34815a7 2c2081e e883c39 34815a7 b9d6157 2c2081e e883c39 2c2081e b9d6157 2c2081e b9d6157 2c2081e b9d6157 2c2081e b9d6157 2c2081e b9d6157 2c2081e 76b2171 2c2081e 76b2171 2c2081e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
"""Gen cmat for de/en text."""
# pylint: disable=invalid-name, too-many-branches
from typing import List, Optional
import more_itertools as mit
import numpy as np
# from logzero import logger
from loguru import logger
from tqdm import tqdm
# from model_pool import load_model_s
# from hf_model_s_cpu import model_s # load_model_s directly
from st_mlbee.load_model_s import load_model_s
# from st_mlbee.cos_matrix2 import cos_matrix2
from .cos_matrix2 import cos_matrix2
# Disabled earlier variant of the model loading that follows, parked in a
# throwaway string bound to `_` so that it is never executed.
# NOTE(review): the text calls `logger.erorr` (typo for `logger.error`);
# fix that before ever reviving this snippet.
_ = """
try:
model_s = load_model_s()
except Exception as exc:
logger.erorr(exc)
raise
"""
# Load the default model once, at import time; gen_cmat falls back to it
# whenever the caller does not pass a model of their own.
# The default is model-s (mikeee/model_s_512). An alternative would be
# load_model_s("model_s_512v2"), i.e. model-s mikeee/model-s-512v2.
try:
    model_s = load_model_s()
except Exception as exc:
    # Record the failure, then let it propagate so the import itself fails.
    logger.error(exc)
    raise
def gen_cmat(
    text1: List[str],
    text2: List[str],
    bsize: int = 32,  # default batch_size of model.encode
    model=None,
) -> np.ndarray:
    """Gen corr matrix for two lists of texts.

    Both texts are encoded in batches of ``bsize`` with ``model.encode``
    (a single tqdm progress bar covers all batches), and the two sets of
    vectors are then handed to ``cos_matrix2``.

    Args:
    ----
        text1: list of strings, typically '''...'''.splitlines();
            a bare string is treated as a one-element list
        text2: list of strings, typically '''...'''.splitlines();
            a bare string is treated as a one-element list
        bsize: batch size, default 32; a non-positive value falls back to 32
        model: object with an ``encode`` method taking a list of strings,
            default the module-level ``model_s`` (mikeee/model_s_512)

    Returns:
    -------
        numpy array of cmat, computed as ``cos_matrix2(vec2, vec1)``
        (presumably rows follow text2 and columns follow text1 --
        confirm against cos_matrix2)

    Raises:
    ------
        Any exception raised by ``model.encode`` or ``cos_matrix2`` is
        logged and re-raised unchanged.

    Example:
    -------
        text1 = 'this is a test'
        text2 = 'another test'
        cmat = gen_cmat(text1, text2)

    """
    if model is None:
        model = model_s

    bsize = int(bsize)
    if bsize <= 0:
        bsize = 32

    # Accept a single string for either argument.
    if isinstance(text1, str):
        text1 = [text1]
    if isinstance(text2, str):
        # Fixed: this used to overwrite text1 with [text2], leaving text2 a
        # bare string that was then chunked character by character.
        text2 = [text2]

    vec1 = []
    vec2 = []

    # Total number of batches over both texts (ceil division per text).
    len1 = len(text1)
    len2 = len(text2)
    tot = len1 // bsize + bool(len1 % bsize)
    tot += len2 // bsize + bool(len2 % bsize)

    with tqdm(total=tot) as pbar:
        # Same batching/encoding procedure for both texts, text1 first.
        for texts, vecs in ((text1, vec1), (text2, vec2)):
            for chunk in mit.chunked(texts, bsize):
                try:
                    vec = model.encode(chunk)
                except Exception as exc:
                    logger.error(exc)
                    raise
                vecs.extend(vec)
                pbar.update()

    try:
        # note the order vec2, vec1
        cmat = cos_matrix2(np.array(vec2), np.array(vec1))
    except Exception as exc:
        logger.exception(exc)
        raise

    return cmat
|