import json inset_th=1 #_config=json.load(open("config.json","r")) _config={ "sug_based_list":["dispute","plaintiff"], "sug_pool_list":["corpus3835","2022~2023"], "embedder_list":["ftlf","ftrob"], "based_index":1, "pool_index":1, "emb_index":1, "sug_th":20, "cluster_epsilon":0.67, "similiar_trace_back_th":0.98, "back_ground_RGB":[77, 6, 39] } emb_dim_lst=[768,1024] bilstm_len_lst=[19,13] cnn_len_lst=[32,18] emb_dim=emb_dim_lst[_config["emb_index"]] bilstm_len=bilstm_len_lst[_config["based_index"]] cnn_len=cnn_len_lst[_config["based_index"]] sug_type=_config["sug_based_list"][_config["based_index"]] pool_type=_config["sug_pool_list"][_config["pool_index"]] emb_type=_config["embedder_list"][_config["emb_index"]] sug_th=_config["sug_th"] clust_th=_config["cluster_epsilon"] _th=_config["similiar_trace_back_th"] bg_rgb=(_config["back_ground_RGB"][0],_config["back_ground_RGB"][1],_config["back_ground_RGB"][2]) import os,sys #_gpu=(1==1) #if not _gpu: # os.environ['CUDA_VISIBLE_DEVICES'] = '-1' import cv2#opencv-python 4.6.0.66 import colorama from colorama import Fore,Style,Back import json import numpy as np from numpy.linalg import norm from sentence_transformers import SentenceTransformer from tqdm import tqdm import tensorflow as tf from tensorflow.keras.models import load_model #--------------------------------------- def logistic(x_r,y_r,x_e,_proba=True): from sklearn import linear_model from sklearn.inspection import permutation_importance model=linear_model.LogisticRegression(max_iter=100000) model.fit(x_r,y_r) p_e=model.predict(x_e) prob_e=model.predict_proba(x_e) prob_sum=[i[1] for i in prob_e] return (prob_sum if _proba else p_e) def cos_sim(a,b): return np.dot(a,b)/(norm(a)*norm(b)) def replace_all(t,rp_lst,k,_type=0): temp=t for _e in rp_lst: if _type==-1: temp=temp.replace(_e,k+_e) elif _type==1: temp=temp.replace(_e,_e+k) else: temp=temp.replace(_e,k) return temp def jl(file_path): with open(file_path, "r", encoding="utf8") as json_file: json_list = list(json_file) return [json.loads(json_str) for json_str in json_list] def lst_2_dict(lst): _dict={i["filename"]:[i["p_point"],i["d_point"],i["Controversy"]] for i in lst} return _dict def clust_2_dict(clust): _dict={} ct=0 for i in clust: if len(clust[i])==1: _dict[clust[i][0]]=-1 else: ct+=1 for _e in clust[i]: _dict[_e]=ct return _dict def clust_label(clust): _dict={} for i in clust: for _e in clust[i]: if len(clust[i])>1: _dict[_e]=i else: _dict[_e]='-1' return _dict #----------------------------- def clust_core(clust,vec_lst,id_lst,_type="mean"): _dict={} for i in clust: if _type=="head": _dict[i]=vec_lst[id_lst.index(clust[i][0])] elif _type=="central": tp_lst=np.array([vec_lst[id_lst.index(_e)] for _e in clust[i]]) temp=np.average(tp_lst, axis=0) cs_lst=[[cos_sim(_e,temp),list(_e)] for _e in tp_lst] _dict[i]=max(cs_lst)[-1] else:#_type=="mean" tp_lst=np.array([vec_lst[id_lst.index(_e)] for _e in clust[i]]) _dict[i]=np.average(tp_lst, axis=0) return _dict def clust_search(core_dict,target,clust_th=0.65): temp=max([[cos_sim(target,core_dict[i]),i] for i in core_dict]) ot_,label_=temp return label_ if ot_>=clust_th else '-1' def vec2img(vec_lst1,clust_lst1,vec_lst2,clust_lst2,r): tp_lst1=[[vec_lst1[i],clust_lst1[i]] for i in range(len(clust_lst1))] tp_lst2=[[vec_lst2[i],clust_lst2[i]] for i in range(len(clust_lst2))] lst1=sorted(tp_lst1,key=lambda x:x[1]) lst2=sorted(tp_lst2,key=lambda x:x[1]) m_lst=lst1+lst2 _img=[[255 for _ee in range(len(m_lst))] for _e in range(len(m_lst))] for i in range(len(m_lst)): for j in range(len(m_lst)): if ir else temp/r*128 _tp=int(_tp-1) _img[i][j]=_tp _img[j][i]=_tp return _img def img_resize(_img,_max_size): return cv2.resize(np.array(_img).astype('float32'), (_max_size, _max_size), interpolation=cv2.INTER_AREA).tolist() def cnn_load(_device="/gpu:0"): global cnn_model with tf.device(_device): cnn_model=load_model("./models/"+sug_type+"_"+emb_type+"_cnn.dat") cnn_model.load_weights("./models/"+sug_type+"_"+emb_type+"_cnn_best.hdf5") def bilstm_load(_device="/gpu:0"): global bilstm_model with tf.device(_device): bilstm_model=load_model("./models/"+sug_type+"_"+emb_type+"_sa.dat") bilstm_model.load_weights("./models/"+sug_type+"_"+emb_type+"_sa_best.hdf5") #--------------------------------------- _tranpose=(1==1) from colorama import Fore,Style,Back from pretty_html_table import build_table import pandas as pd def html_hl(lst): #font_path = "./font/TaipeiSansTCBeta-Regular.ttf" #font = ImageFont.truetype(font_path, font_size) tp_lst=[] for i in lst: temp=""+i["content"]+"" tp_lst.append(temp) return "".join(tp_lst) def ansi_to_html(_f,file_path,_tranpose=True): if _tranpose: _dict={"item":["plaintiff","p_point","score"],_f["target"]+"(target)":["plaintiff_anchor2","p_point_anchor2",""],_f["case_id"]:["plaintiff_anchor1","p_point_anchor1","score_anchor"]} else: _dict={"case_name":[_f["case_id"],_f["target"]+"(target)"],"plaintiff":["plaintiff_anchor1","plaintiff_anchor2"],"p_point":["p_point_anchor1","p_point_anchor2"],"score":["","score_anchor"]} p1=html_hl(_f["plaintiff_case1"]) p2=html_hl(_f["plaintiff_case2"]) p_point1=html_hl(_f["p_point_case1"]) p_point2=html_hl(_f["p_point_case2"]) score_="\n=0.75 else "yellow" if _f["ensemble_pred"]>=0.5 else "red")+"\">"+str(_f["ensemble_pred"])+"" #score_=""++"\">"+str(_f["ensemble_pred"])+"" df=pd.DataFrame(_dict) html_table_blue_light = build_table(df, 'blue_light') #print(type(html_table_blue_light)) injection="" #"" html_table_blue_light=html_table_blue_light[:html_table_blue_light.find("")+7]+injection+html_table_blue_light[html_table_blue_light.find("")+7:] html_table_blue_light=html_table_blue_light.replace("plaintiff_anchor1",p1).replace("plaintiff_anchor2",p2)\ .replace("p_point_anchor1",p_point1).replace("p_point_anchor2",p_point2)\ .replace("score_anchor",score_) with open(file_path, 'w',) as f: f.write(html_table_blue_light) return html_table_blue_light #--------------------------------------- from PIL import Image, ImageDraw, ImageFont # Dictionary mapping colorama codes to RGB colors ANSI_BG_COLORS = { Fore.BLACK: (0, 0, 0), Fore.RED: (255, 0, 0), Fore.GREEN: (0, 255, 0), Fore.YELLOW: (255, 255, 0), Fore.BLUE: (0, 0, 255), Fore.MAGENTA: (255, 0, 255), Fore.CYAN: (0, 255, 255), Fore.WHITE: (255, 255, 255), Fore.RESET: (0, 0, 0), # Reset to black Back.BLACK: (0, 0, 0), Back.RED: (255, 0, 0), Back.GREEN: (0, 255, 0), Back.YELLOW: (255, 255, 0), Back.BLUE: (0, 0, 255), Back.MAGENTA: (255, 0, 255), Back.CYAN: (0, 255, 255), Back.WHITE: (255, 255, 255), '\033[0m': bg_rgb # Reset to White background } ANSI_COLORS={_e:"#"+str(hex(1*256*256*256+ANSI_BG_COLORS[_e][0]*256*256+ANSI_BG_COLORS[_e][1]*256+ANSI_BG_COLORS[_e][2]))[3:] for _e in ANSI_BG_COLORS} def ansi_to_image(ansi_text, font_size=20, image_path="./test.png"): global bg_rgb font_path = "./font/TaipeiSansTCBeta-Regular.ttf" font = ImageFont.truetype(font_path, font_size) # Split the text into lines lines = ansi_text.split('\n') # Calculate image size max_width = 0 total_height = 0 line_heights = [] for line in lines: text_width, text_height = font.getsize(line) max_width = max(max_width, text_width) total_height += text_height line_heights.append(text_height) # Create a blank image image = Image.new('RGB', (max_width, total_height), color=bg_rgb) draw = ImageDraw.Draw(image) y = 0 for line, line_height in zip(lines, line_heights): x = 0 segments = line.split('\033') anchor_bg_color=(255,255,255) for segment in segments: #print(segment) if segment and segment[-1]=='m': code= segment[:-1] anchor_bg_color = ANSI_BG_COLORS.get(f'\033{code}m', anchor_bg_color) #text_width, text_height = draw.textsize(text, font=font) #draw.rectangle([x, y, x + text_width, y + line_height], fill=(255, 255, 255)) #draw.text((x, y), text, font=font, fill=anchor_bg_color) x += 0 if 'm' in segment: code, text = segment.split('m', 1) font_color = ANSI_BG_COLORS.get(f'\033{code}m', anchor_bg_color) text_width, text_height = draw.textsize(text, font=font) draw.rectangle([x, y, x + text_width, y + line_height], anchor_bg_color) draw.text((x, y), text, font=font, fill=font_color) x += text_width else: text = segment text_width, text_height = draw.textsize(text, font=font) draw.text((x, y), text, font=font, fill=(255,255,255)) x += text_width y += line_height # Save the image image.save(image_path) return image_path # 示例ANSI文本 #ansi_content = '\033[44m555\033[0m\n111\033[41m555\033[0m' # 將ANSI轉換為圖像 #image_path = ansi_to_image(ansi_content) # #--------------------------------------- def suggesting(the_pool,target_name,case_dict): global ANSI_COLORS,_th,c_th,sug_th,corpus_dict,corpus_pd_f,vec_lst,id_lst,sen_lst,corpus_clust_label,_cluster_core_dict,_embedder global bilstm_len,cnn_len,emb_dim,inset_th,clust_th lst_2=[_e for _e in case_dict["p_point"]][:bilstm_len] #for _e in lst2: # temp=_embedder.encode(_e) # vec_lst_2.append() vec_lst_2=[_embedder.encode(_e) for _e in lst_2] clst_2=[clust_search(_cluster_core_dict,_e,clust_th) for _e in vec_lst_2] plst_2=replace_all("".join(case_dict["plaintiff"]),key_lst,sp_key,1).split(sp_key) v_plst_2=[_embedder.encode(_e) for _e in plst_2] print(clst_2) rt_lst=[] for i in tqdm(the_pool): if target_name==i: continue lst_1=[_e for _e in corpus_dict[i]] id_lst_1=[id_lst[sen_lst.index(_e)] for _e in lst_1] vec_lst_1=[vec_lst[sen_lst.index(_e)] for _e in lst_1]#[_embedder.encode(_e) for _e in lst_1] clst_1=[corpus_clust_label[_e] for _e in id_lst_1]#[clust_search(_cluster_core_dict,_e,0.68) for _e in vec_lst_1] #print(clst_1) inset=sorted([_e for _e in set(clst_1)&set(clst_2) if _e!=-1]) temp_ot={} if len(inset)>=max(1,inset_th): temp_ot["target"]=target_name temp_ot["inset"]=inset #print(len(inset)) _img=img_resize(vec2img(vec_lst_1,clst_1,vec_lst_2,clst_2,clust_th),cnn_len) cnn_pred=cnn_model.predict(np.array([_img])/255) _con1,_con2=[],[] for tp_i in range(bilstm_len): if len(lst_1)>tp_i: _con1.append(vec_lst_1[tp_i]) else: _con1.append([0]*emb_dim) for tp_i in range(bilstm_len): if len(lst_2)>tp_i: _con2.append(vec_lst_2[tp_i]) else: _con2.append([0]*emb_dim) _con1=np.array([_con1]) _con2=np.array([_con2]) print(len(_con1),len(_con2),len(_con2[0])) #_con1=list(np.array(vec_lst_1).reshape(len(lst_1)*emb_dim))+[0]*(emb_dim*(bilstm_len-len(lst_1))) if len(lst_1)<=bilstm_len else list(np.array(vec_lst_1).reshape(len(lst_1)*emb_dim))[:bilstm_len*emb_dim] #_con2=list(np.array(vec_lst_2).reshape(len(lst_2)*emb_dim))+[0]*(emb_dim*(bilstm_len-len(lst_2))) if len(lst_2)<=bilstm_len else list(np.array(vec_lst_2).reshape(len(lst_2)*emb_dim))[:bilstm_len*emb_dim] bilstm_pred=bilstm_model.predict([_con1,_con2]) temp_ot["cnn_pred"]=float(cnn_pred[0][0]) temp_ot["bilstm_pred"]=float(bilstm_pred[0][0]) #print(cnn_pred) #print(bilstm_pred) x_e=[[bilstm_pred[0][0],cnn_pred[0][0]]] ensemble_pred=logistic(x_r,y_r,x_e) temp_ot["ensemble_pred"]=float(ensemble_pred[0]) #print(ensemble_pred) pre_lst_1=[[color_lst[inset.index(clst_1[_e]) % len(color_lst)],Fore.WHITE,lst_1[_e],Style.RESET_ALL] if clst_1[_e] in inset else [Style.RESET_ALL,lst_1[_e]] for _e in range(len(lst_1))] pre_lst_2=[[color_lst[inset.index(clst_2[_e]) % len(color_lst)],Fore.WHITE,lst_2[_e],Style.RESET_ALL] if clst_2[_e] in inset else [Style.RESET_ALL,lst_2[_e]] for _e in range(len(lst_2))] vlst_1=[[vec_lst_1[_e],pre_lst_1[_e][0]] for _e in range(len(pre_lst_1)) if len(pre_lst_1[_e])==4] vlst_2=[[vec_lst_2[_e],pre_lst_2[_e][0]] for _e in range(len(pre_lst_2)) if len(pre_lst_2[_e])==4] #print(lst_1) plst_1=replace_all("".join(corpus_pd_f[i.replace("_",",")][0]),key_lst,sp_key,1).split(sp_key) v_plst_1=[_embedder.encode(_e) for _e in plst_1] cs_p1=[max([[cos_sim(_e,_v[0]),_v[-1]] for _v in vlst_1]) for _e in v_plst_1] cs_p2=[max([[cos_sim(_e,_v[0]),_v[-1]] for _v in vlst_2]) for _e in v_plst_2] pre_lst_p1=[[cs_p1[_e][-1],Fore.WHITE,plst_1[_e],Style.RESET_ALL] if cs_p1[_e][0]>_th else [Style.RESET_ALL,plst_1[_e]] for _e in range(len(cs_p1))] pre_lst_p2=[[cs_p2[_e][-1],Fore.WHITE,plst_2[_e],Style.RESET_ALL] if cs_p2[_e][0]>_th else [Style.RESET_ALL,plst_2[_e]] for _e in range(len(cs_p2))] #if max_dp=0.75 else Fore.YELLOW if temp_ot["ensemble_pred"]>=0.5 else Fore.RED)+str(temp_ot["ensemble_pred"])+Style.RESET_ALL+"\n" tp_str+=Fore.MAGENTA+"---plaintiff_case1---"+Style.RESET_ALL+"\n" tp_str+="".join(draw_lst_p1)+Style.RESET_ALL+"\n" tp_str+=Fore.MAGENTA+"---p_point_case1---"+Style.RESET_ALL+"\n" tp_str+="".join(draw_lst_1)+Style.RESET_ALL+"\n" ### tp_str+=Fore.BLUE+"target"+Style.RESET_ALL+"\n" tp_str+=Fore.MAGENTA+"---plaintiff_case2---"+Style.RESET_ALL+"\n" tp_str+="".join(draw_lst_p2)+Style.RESET_ALL+"\n" tp_str+=Fore.MAGENTA+"---p_point_case2---"+Style.RESET_ALL+"\n" tp_str+="".join(draw_lst_2)+Style.RESET_ALL+"\n" #tp_str+="---------------------"+"\n" temp_ot["output"]=tp_str rt_lst.append(temp_ot) print(tp_str) ot=sorted(rt_lst,key=lambda x:x["ensemble_pred"],reverse=True) ot_lst=[i["output"] for i in ot[:sug_th]] for i in ot[:sug_th]: file=open("./json_file/"+str(target_name).replace(",","_")+"&"+str(i["case_id"])+".json","w",encoding='utf8') json.dump({_e:i[_e] for _e in i if _e!="output"},file,indent=4,ensure_ascii=False) file.close() return ot_lst,ot[:sug_th] #--------------------------------------- _dir_lst=["../gpt4_0409_p_3/","../taide_llama3_8b_3/"] _dir=_dir_lst[0] sp_key="@" emb_model="ftrob" emb_model_path={\ "lf":"thunlp/Lawformer",\ "rob":'hfl/chinese-roberta-wwm-ext-large',\ "ftlf":"./sbert_pretrained_model/training-lawformer-clause_th10_100k_task-bs100-e2-2023-10-28/", "ftrob":"./sbert_pretrained_model/training-roberta-clause_th10_100k_task-bs100-e2-2023-10-27",\ } color_lst=[Back.BLUE,Back.GREEN,Back.MAGENTA,Back.YELLOW,Back.RED,Back.CYAN]#[Fore.RED,Fore.GREEN,Fore.YELLOW,Fore.BLUE,Fore.MAGENTA,Fore.CYAN] log_f=json.load(open("./src/plaintiff_logistic_features.json","r"))["BiLSTM_CNN"] x_r=np.array(log_f)[:,:-1] y_r=np.array(log_f)[:,-1] #pd_path,dis_path,s_path,v_path,c_path,t_path,cr_path,br_path=["TAIDE-LX-8B.jsonl","llama3_taide_8b_re_3_o_c.json","sentence.json","vector.json","hdb_cluster.json","hdb_ternary_array.json","hdb_cnn_result.json","hdb_sa_result.json"] pd_f=corpus_pd_f=json.load(open("./src/corpus3835_raw.json","r"))["claim"] s_f=json.load(open("./src/plaintiff_corpus3835_sen.json","r")) v_f=json.load(open("./src/plaintiff_corpus3835_vec.json","r"))#json.load(open(_dir+v_path,"r")) o_c_f=json.load(open("./src/plaintiff_corpus3835_cluster.json","r"))["clusters"] c_f=clust_2_dict(o_c_f) t_f=json.load(open("./src/plaintiff_ter.json","r")) if pool_type=="corpus3835": corpus_clust_label=clust_label(o_c_f) vec_lst=v_f["vector"] id_lst=v_f["id"] sen_lst=s_f["sentence"] corpus_dict={} for i in range(len(id_lst)): fid=id_lst[i].split("@")[0] if fid not in corpus_dict: corpus_dict[fid]=[sen_lst[i]] else: corpus_dict[fid].append(sen_lst[i]) corpus_pd_f=json.load(open("./src/corpus3835_raw.json","r"))["claim"] else: vec_f=json.load(open("./src/plaintiff_2022~2023_vec.json","r")) vec_lst=[_e for i in vec_f for _e in vec_f[i]] corpus_dict=json.load(open("./src/plaintiff_2022~2023_raw.json","r")) corpus_pd_f=json.load(open("./src/2022~2023_raw.json","r"))["claim"] corpus_clust_f=json.load(open("./src/plaintiff_2022~2023_clust.json","r")) sen_lst=[_e for i in corpus_dict for _e in corpus_dict[i]] id_lst=[i+"@"+str(_e) for i in corpus_dict for _e in range(len(corpus_dict[i]))] corpus_clust_label={_e:corpus_clust_f[_e[:_e.find("@")]][int(_e[_e.find("@")+1:])] for _e in id_lst} ### new_point_f=lst_2_dict(jl("../law/2022~2023/gpt-4-turbo-0409-0.3-new22_23.jsonl")) new_pd_f=json.load(open("../law/2022~2023/new22_23_3k3_corpus_raw.json","r"))["claim"] key_lst=[",","。","?","?","!","!",";",":",";",":"]#["。","?","?","!","!",";",":",";",":"] _embedder = SentenceTransformer(emb_model_path[emb_model]) cnn_model =... bilstm_model =... """#fifo cnn_load() bilstm_load() """ cnn_load("/cpu:0") bilstm_load("/cpu:0") #""" _cluster_core_dict=clust_core(o_c_f,v_f["vector"],v_f["id"],"central") #--------------------------------------- from colorama import Fore,Style,Back import gradio as gr def case_sug(file_name,plaintiff,p_point): global new_pd_f,new_point_f,corpus_dict print(file_name) #print(point_f) #print(list(pd_f.keys()).index(file_name)) if file_name not in new_pd_f: print("file not found") file_name="user_input" else: plaintiff=new_pd_f[file_name][0] p_point=new_point_f[file_name][0] global sug_th p_point=p_point.split("。") if type(p_point)==type("111") else p_point _pool=[i for i in corpus_dict] _case_dict={"plaintiff":plaintiff,"p_point":p_point} print(_case_dict,[type(_case_dict[_e]) for _e in _case_dict]) ot,ot_dict=suggesting(_pool,file_name,_case_dict) #ot=[Back.BLUE+dispute+Style.RESET_ALL]*10 output_list=[] print("-----") print(len(ot_dict)) out_path="./out_of_range.html" for i in range(sug_th): if i