"""Gen cmat (corr matrix) for de/en text."""
# pylint: disable=invalid-name, too-many-branches

from typing import List, Optional

import more_itertools as mit
import numpy as np

# from logzero import logger
from loguru import logger
from tqdm import tqdm

# from model_pool import load_model_s
# from hf_model_s_cpu import model_s

# load_model_s directly
from st_mlbee.load_model_s import load_model_s

# from st_mlbee.cos_matrix2 import cos_matrix2
from .cos_matrix2 import cos_matrix2

# Load the default encoder once at import time; a failure here is logged
# and re-raised so the import itself fails loudly.
try:
    # default model-s mikeee/model_s_512
    model_s = load_model_s()

    # alternative: load_model_s("model_s_512v2")  # model-s mikeee/model-s-512v2
except Exception as exc:
    logger.error(exc)
    raise


def _encode(model, texts: List[str], bsize: int, pbar) -> list:
    """Encode texts in chunks of bsize via model.encode, ticking pbar once per chunk.

    Returns a flat list of the per-item results of model.encode. Any exception
    raised by model.encode is logged and re-raised.
    """
    vecs: list = []
    for chunk in mit.chunked(texts, bsize):
        try:
            vec = model.encode(chunk)
        except Exception as exc:
            logger.error(exc)
            raise
        vecs.extend(vec)
        pbar.update()
    return vecs


def gen_cmat(
    text1: List[str],
    text2: List[str],
    bsize: int = 32,  # default batch_size of model.encode
    model=None,
) -> np.ndarray:
    """Gen corr matrix for texts.

    Each of text1 and text2 is encoded in batches of ``bsize`` with
    ``model.encode`` (with a tqdm progress bar over all batches), and the two
    embedding arrays are passed to ``cos_matrix2``.

    Args:
    ----
        text1: typically '''...'''.splitlines(); a bare str is treated as a
            one-element list
        text2: typically '''...'''.splitlines(); a bare str is treated as a
            one-element list
        bsize: batch size for model.encode, default 32; it is coerced with
            int() and any value <= 0 falls back to 32
        model: object providing ``encode(list_of_str)``; defaults to the
            module-level model_s (mikeee/model_s_512)

        text1 = 'this is a test'
        text2 = 'another test'

    Returns:
    -------
        numpy array of cmat, computed as cos_matrix2(vec2, vec1) -- note that
        text2's vectors are passed first (presumably rows follow text2;
        confirm against cos_matrix2)

    Raises:
    ------
        Any exception from model.encode or cos_matrix2, after logging it.
    """
    if model is None:
        model = model_s

    bsize = int(bsize)
    if bsize <= 0:
        bsize = 32

    # Accept a single string for either argument.
    if isinstance(text1, str):
        text1 = [text1]
    if isinstance(text2, str):
        # Previously this assigned to text1, discarding text1 and leaving
        # text2 a bare str that was then chunked character by character.
        text2 = [text2]

    # Total number of batches for the progress bar: ceil(len / bsize) per text.
    tot = (len(text1) + bsize - 1) // bsize + (len(text2) + bsize - 1) // bsize

    with tqdm(total=tot) as pbar:
        vec1 = _encode(model, text1, bsize, pbar)
        vec2 = _encode(model, text2, bsize, pbar)

    # NOTE(review): with an empty text1/text2 the corresponding array is
    # np.array([]) of shape (0,); whether cos_matrix2 accepts that is not
    # visible here -- confirm.
    try:
        # note the order vec2, vec1
        cmat = cos_matrix2(np.array(vec2), np.array(vec1))
    except Exception as exc:
        logger.exception(exc)
        raise

    return cmat