|
import json |
|
# Minimum number of cluster labels a pool case must share with the target
# before it is scored (the suggesters compare against max(1, inset_th)).
inset_th = 1

# Run configuration; each *_index entry selects one element of a list below.
_config = {
    "sug_based_list": ["dispute", "plaintiff"],    # suggestion mode (sug_type)
    "sug_pool_list": ["corpus3835", "2022~2023"],  # candidate pool (pool_type)
    "embedder_list": ["ftlf", "ftrob"],            # embedder key (emb_type)
    "based_index": 0,
    "pool_index": 1,
    "emb_index": 1,
    "sug_th": 20,                                  # number of suggestions kept
    "cluster_epsilon": 0.67,                       # clust_search threshold
    "similiar_trace_back_th": 0.98,                # fragment highlight threshold
    "back_ground_RGB": [77, 6, 39],                # rendered image background
}

# Model dimensions per option: emb_dim follows the embedder choice, the
# BiLSTM/CNN input lengths follow the suggestion mode.
emb_dim_lst = [768, 1024]
bilstm_len_lst = [19, 13]
cnn_len_lst = [32, 18]

emb_dim = emb_dim_lst[_config["emb_index"]]
bilstm_len, cnn_len = (
    bilstm_len_lst[_config["based_index"]],
    cnn_len_lst[_config["based_index"]],
)

sug_type, pool_type, emb_type = (
    _config["sug_based_list"][_config["based_index"]],
    _config["sug_pool_list"][_config["pool_index"]],
    _config["embedder_list"][_config["emb_index"]],
)

sug_th = _config["sug_th"]
clust_th = _config["cluster_epsilon"]
_th = _config["similiar_trace_back_th"]

bg_rgb = tuple(_config["back_ground_RGB"][channel] for channel in range(3))
|
|
|
|
|
|
|
import os,sys |
|
|
|
|
|
|
|
|
|
|
|
import cv2 |
|
import colorama |
|
from colorama import Fore,Style,Back |
|
import json |
|
import numpy as np |
|
from numpy.linalg import norm |
|
from sentence_transformers import SentenceTransformer |
|
from tqdm import tqdm |
|
import tensorflow as tf |
|
from tensorflow.keras.models import load_model |
|
|
|
# Fitted ensemble models keyed by the exact training data they were fit on.
_logistic_model_cache = {}


def logistic(x_r, y_r, x_e, _proba=True):
    """Score ``x_e`` with a logistic-regression model trained on ``(x_r, y_r)``.

    The suggestion loops call this once per candidate with the same training
    data, so the fitted model is cached by the training arrays' content and
    only fit once. With scikit-learn's default (lbfgs) solver the fit is
    deterministic, so reusing it gives the same predictions as refitting.

    Args:
        x_r: training features.
        y_r: training labels.
        x_e: rows to score.
        _proba: if True return the positive-class probability per row,
            otherwise the predicted labels.

    Returns:
        A list of positive-class probabilities (``_proba=True``) or the
        array returned by ``model.predict``.
    """
    from sklearn import linear_model

    x_train = np.asarray(x_r)
    y_train = np.asarray(y_r)
    if x_train.dtype == object or y_train.dtype == object:
        # Object arrays serialize pointers, not values; never cache those.
        cache_key = None
    else:
        cache_key = (
            x_train.shape, x_train.dtype.str, x_train.tobytes(),
            y_train.shape, y_train.dtype.str, y_train.tobytes(),
        )

    model = _logistic_model_cache.get(cache_key) if cache_key is not None else None
    if model is None:
        model = linear_model.LogisticRegression(max_iter=100000)
        model.fit(x_r, y_r)
        if cache_key is not None:
            _logistic_model_cache[cache_key] = model

    if _proba:
        return [row[1] for row in model.predict_proba(x_e)]
    return model.predict(x_e)
|
|
|
def cos_sim(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    There is no zero-vector guard: a zero norm makes the denominator 0 and
    NumPy returns nan/inf (with a RuntimeWarning) instead of raising.
    """
    magnitude = norm(a) * norm(b)
    return np.dot(a, b) / magnitude
|
def replace_all(t, rp_lst, k, _type=0):
    """Apply one ``str.replace`` per token in ``rp_lst``, in order.

    _type == -1 inserts ``k`` before each token, _type == 1 appends ``k``
    after it, and any other value replaces the token with ``k``. Each
    replacement runs on the result of the previous one, so a token listed
    twice is processed twice.
    """
    result = t
    for token in rp_lst:
        if _type == -1:
            substitute = k + token
        elif _type == 1:
            substitute = token + k
        else:
            substitute = k
        result = result.replace(token, substitute)
    return result
|
def jl(file_path):
    """Load a JSON Lines file (UTF-8) and return one parsed object per line."""
    records = []
    with open(file_path, "r", encoding="utf8") as handle:
        for raw_line in handle:
            records.append(json.loads(raw_line))
    return records
|
def lst_2_dict(lst):
    """Index point records by file name.

    Every record must have "filename", "p_point", "d_point" and
    "Controversy"; the stored value is [p_point, d_point, Controversy].
    A later record with the same filename overwrites an earlier one.
    """
    by_name = {}
    for record in lst:
        by_name[record["filename"]] = [
            record["p_point"],
            record["d_point"],
            record["Controversy"],
        ]
    return by_name
|
def clust_2_dict(clust):
    """Map every member id to a sequential cluster number.

    ``clust`` maps a cluster key to its member list. Singleton clusters map
    their member to -1. Each other cluster gets the next number starting at
    1, in iteration order; an empty member list still consumes a number.
    """
    mapping = {}
    next_id = 0
    for members in clust.values():
        if len(members) == 1:
            mapping[members[0]] = -1
            continue
        next_id += 1
        mapping.update(dict.fromkeys(members, next_id))
    return mapping
|
def clust_label(clust):
    """Map every member id to its cluster key, or the string '-1' for singletons."""
    labels = {}
    for key, members in clust.items():
        tag = key if len(members) > 1 else '-1'
        for member in members:
            labels[member] = tag
    return labels
|
|
|
def clust_core(clust, vec_lst, id_lst, _type="mean"):
    """Compute one representative vector per cluster.

    Args:
        clust: mapping cluster key -> list of member ids.
        vec_lst: vectors aligned with ``id_lst``.
        id_lst: ids; a member's vector is ``vec_lst`` at the id's first
            position (same as ``id_lst.index``).
        _type: "head" -> vector of the first member;
            "central" -> the member vector (as a list) with the highest
            cosine similarity to the members' mean;
            anything else -> the mean vector (NumPy array).

    Returns:
        dict cluster key -> representative vector.
    """
    # id_lst.index() was a linear scan per member (quadratic overall); a
    # position map built once keeps the same first-occurrence semantics.
    position = {}
    for pos, member_id in enumerate(id_lst):
        position.setdefault(member_id, pos)

    cores = {}
    for key, members in clust.items():
        if _type == "head":
            cores[key] = vec_lst[position[members[0]]]
            continue
        member_vecs = np.array([vec_lst[position[m]] for m in members])
        centroid = np.average(member_vecs, axis=0)
        if _type == "central":
            # [similarity, vector] pairs: ties fall back to comparing vectors.
            cores[key] = max([cos_sim(v, centroid), list(v)] for v in member_vecs)[-1]
        else:
            cores[key] = centroid
    return cores
|
|
|
def clust_search(core_dict, target, clust_th=0.65):
    """Return the label of the cluster core most similar to ``target``.

    ``core_dict`` maps cluster label -> core vector. The best core is chosen
    by cosine similarity; equal similarities are resolved by the larger
    label, as the [similarity, label] comparison always did. The label is
    returned only when its similarity is >= ``clust_th``; otherwise, and
    also for an empty ``core_dict`` (previously a ValueError from max()),
    the no-cluster marker '-1' is returned.
    """
    if not core_dict:
        return '-1'
    best_sim, best_label = max(
        [cos_sim(target, core_dict[label]), label] for label in core_dict
    )
    return best_label if best_sim >= clust_th else '-1'
|
|
|
def vec2img(vec_lst1, clust_lst1, vec_lst2, clust_lst2, r):
    """Build the symmetric pairwise cosine-similarity matrix fed to the CNN.

    Each case's sentences are sorted (stably) by cluster label. Then case 1
    is stacked before case 2. Cell (i, j) for i != j holds
    int(p - 1), where p = (s - r) / (1 - r) * 128 + 127 if s > r else s / r * 128
    and s is the cosine similarity. The diagonal stays 255. Returns nested lists.
    """
    first = sorted(
        ([vec_lst1[pos], label] for pos, label in enumerate(clust_lst1)),
        key=lambda pair: pair[1],
    )
    second = sorted(
        ([vec_lst2[pos], label] for pos, label in enumerate(clust_lst2)),
        key=lambda pair: pair[1],
    )
    ordered = first + second
    size = len(ordered)

    matrix = [[255] * size for _ in range(size)]
    for row in range(size):
        for col in range(row + 1, size):
            sim = cos_sim(ordered[row][0], ordered[col][0])
            if sim > r:
                pixel = (sim - r) / (1 - r) * 128 + 127
            else:
                pixel = sim / r * 128
            pixel = int(pixel - 1)
            matrix[row][col] = pixel
            matrix[col][row] = pixel
    return matrix
|
def img_resize(_img, _max_size):
    """Resize a 2-D matrix to ``_max_size`` x ``_max_size`` with cv2 INTER_AREA.

    The input is converted to float32 first; the result is returned as nested lists.
    """
    source = np.array(_img).astype('float32')
    resized = cv2.resize(source, (_max_size, _max_size), interpolation=cv2.INTER_AREA)
    return resized.tolist()
|
def cnn_load(_device="/gpu:0"):
    """Load the CNN for the configured (sug_type, emb_type) into global ``cnn_model``.

    The architecture comes from ./models/<sug_type>_<emb_type>_cnn.dat and the
    weights from the matching _cnn_best.hdf5 file, placed on ``_device``.
    """
    global cnn_model
    model_prefix = "./models/" + sug_type + "_" + emb_type
    with tf.device(_device):
        cnn_model = load_model(model_prefix + "_cnn.dat")
        cnn_model.load_weights(model_prefix + "_cnn_best.hdf5")
|
def bilstm_load(_device="/gpu:0"):
    """Load the BiLSTM for the configured (sug_type, emb_type) into global ``bilstm_model``.

    The architecture comes from ./models/<sug_type>_<emb_type>_sa.dat and the
    weights from the matching _sa_best.hdf5 file, placed on ``_device``.
    """
    global bilstm_model
    model_prefix = "./models/" + sug_type + "_" + emb_type
    with tf.device(_device):
        bilstm_model = load_model(model_prefix + "_sa.dat")
        bilstm_model.load_weights(model_prefix + "_sa_best.hdf5")
|
|
|
# Table layout flag. NOTE(review): not read in this file as far as visible;
# ansi_to_html / ansi_to_html_dis take their own `_tranpose=True` argument.
_tranpose = True
|
from colorama import Fore,Style,Back |
|
from pretty_html_table import build_table |
|
import pandas as pd |
|
|
|
|
|
def html_hl(lst):
    """Render highlighted text segments as concatenated ``<mark>`` elements.

    Each item is a dict with "background_color", "font_color" and "content"
    keys. The content is raw case text (possibly typed into the Gradio
    form). It is therefore HTML-escaped, so characters such as ``<`` or
    ``&`` cannot break or inject markup into the generated report.
    """
    import html

    parts = []
    for segment in lst:
        parts.append(
            "<mark style=\"background:" + segment["background_color"]
            + ";color:" + segment["font_color"] + "\">"
            + html.escape(segment["content"], quote=False)
            + "</mark>"
        )
    return "".join(parts)
|
def ansi_to_html_dis(_f, file_path, _tranpose=True):
    """Render one dispute-mode suggestion record as an HTML table and save it.

    Args:
        _f: suggestion dict from suggesting_dis; reads "target", "case_id",
            "plaintiff_case1/2", "defendant_case1/2", "dispute_case1/2" and
            "ensemble_pred".
        file_path: output .html path (overwritten, UTF-8).
        _tranpose: True puts items on rows and the two cases in columns;
            False puts the cases on rows.

    Returns:
        The HTML string that was written.
    """
    target_col = _f["target"] + "(target)"
    if _tranpose:
        table = {
            "item": ["plaintiff", "defendant", "dispute", "score"],
            target_col: ["plaintiff_anchor2", "defendant_anchor2", "dispute_anchor2", ""],
            _f["case_id"]: ["plaintiff_anchor1", "defendant_anchor1", "dispute_anchor1", "score_anchor"],
        }
    else:
        table = {
            "case_name": [_f["case_id"], target_col],
            "plaintiff": ["plaintiff_anchor1", "plaintiff_anchor2"],
            "defendant": ["defendant_anchor1", "defendant_anchor2"],
            "dispute": ["dispute_anchor1", "dispute_anchor2"],
            "score": ["", "score_anchor"],
        }

    pred = _f["ensemble_pred"]
    colour = "green" if pred >= 0.75 else "yellow" if pred >= 0.5 else "red"
    # Placeholders are swapped for markup after build_table runs (presumably
    # so the table builder does not escape the <mark> tags); order preserved.
    replacements = [
        ("plaintiff_anchor1", html_hl(_f["plaintiff_case1"])),
        ("plaintiff_anchor2", html_hl(_f["plaintiff_case2"])),
        ("defendant_anchor1", html_hl(_f["defendant_case1"])),
        ("defendant_anchor2", html_hl(_f["defendant_case2"])),
        ("dispute_anchor1", html_hl(_f["dispute_case1"])),
        ("dispute_anchor2", html_hl(_f["dispute_case2"])),
        ("score_anchor", "\n<mark style=\"background:#ffffff;color:" + colour + "\">" + str(pred) + "</mark>"),
    ]

    html_table = build_table(pd.DataFrame(table), 'blue_light')

    injection = "<meta charset=\"UTF-8\">"
    insert_at = html_table.find("<thead>")
    if insert_at >= 0:
        insert_at += len("<thead>")
        html_table = html_table[:insert_at] + injection + html_table[insert_at:]
    else:
        # No <thead>: the old find()+7 arithmetic would have cut into the markup.
        html_table = injection + html_table

    for placeholder, markup in replacements:
        html_table = html_table.replace(placeholder, markup)

    # Write in UTF-8 to match the declared charset (case text is Chinese).
    with open(file_path, 'w', encoding="utf-8") as f:
        f.write(html_table)
    return html_table
|
def ansi_to_html(_f, file_path, _tranpose=True):
    """Render one plaintiff-mode suggestion record as an HTML table and save it.

    Args:
        _f: suggestion dict from suggesting; reads "target", "case_id",
            "plaintiff_case1/2", "p_point_case1/2" and "ensemble_pred".
        file_path: output .html path (overwritten, UTF-8).
        _tranpose: True puts items on rows and the two cases in columns;
            False puts the cases on rows.

    Returns:
        The HTML string that was written.
    """
    target_col = _f["target"] + "(target)"
    if _tranpose:
        table = {
            "item": ["plaintiff", "p_point", "score"],
            target_col: ["plaintiff_anchor2", "p_point_anchor2", ""],
            _f["case_id"]: ["plaintiff_anchor1", "p_point_anchor1", "score_anchor"],
        }
    else:
        table = {
            "case_name": [_f["case_id"], target_col],
            "plaintiff": ["plaintiff_anchor1", "plaintiff_anchor2"],
            "p_point": ["p_point_anchor1", "p_point_anchor2"],
            "score": ["", "score_anchor"],
        }

    pred = _f["ensemble_pred"]
    colour = "green" if pred >= 0.75 else "yellow" if pred >= 0.5 else "red"
    # Placeholders are swapped for markup after build_table runs (presumably
    # so the table builder does not escape the <mark> tags); order preserved.
    replacements = [
        ("plaintiff_anchor1", html_hl(_f["plaintiff_case1"])),
        ("plaintiff_anchor2", html_hl(_f["plaintiff_case2"])),
        ("p_point_anchor1", html_hl(_f["p_point_case1"])),
        ("p_point_anchor2", html_hl(_f["p_point_case2"])),
        ("score_anchor", "\n<mark style=\"background:#ffffff;color:" + colour + "\">" + str(pred) + "</mark>"),
    ]

    html_table = build_table(pd.DataFrame(table), 'blue_light')

    injection = "<meta charset=\"UTF-8\">"
    insert_at = html_table.find("<thead>")
    if insert_at >= 0:
        insert_at += len("<thead>")
        html_table = html_table[:insert_at] + injection + html_table[insert_at:]
    else:
        # No <thead>: the old find()+7 arithmetic would have cut into the markup.
        html_table = injection + html_table

    for placeholder, markup in replacements:
        html_table = html_table.replace(placeholder, markup)

    # Write in UTF-8 to match the declared charset (case text is Chinese).
    with open(file_path, 'w', encoding="utf-8") as f:
        f.write(html_table)
    return html_table
|
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
|
|
|
|
|
|
# Colorama escape code -> RGB triple used by ansi_to_image and (as hex) by
# the HTML reports. Style.RESET_ALL ('\033[0m') maps to the configured
# background colour bg_rgb.
ANSI_BG_COLORS = {
    Fore.BLACK: (0, 0, 0),
    Fore.RED: (255, 0, 0),
    Fore.GREEN: (0, 255, 0),
    Fore.YELLOW: (255, 255, 0),
    Fore.BLUE: (0, 0, 255),
    Fore.MAGENTA: (255, 0, 255),
    Fore.CYAN: (0, 255, 255),
    Fore.WHITE: (255, 255, 255),
    Fore.RESET: (0, 0, 0),
    Back.BLACK: (0, 0, 0),
    Back.RED: (255, 0, 0),
    Back.GREEN: (0, 255, 0),
    Back.YELLOW: (255, 255, 0),
    Back.BLUE: (0, 0, 255),
    Back.MAGENTA: (255, 0, 255),
    Back.CYAN: (0, 255, 255),
    Back.WHITE: (255, 255, 255),
    '\033[0m': bg_rgb,
}

# The same table as "#rrggbb" strings. Adding 1 << 24 forces a fixed
# 7-hex-digit number whose leading "1" is dropped, zero-padding each channel.
ANSI_COLORS = {
    code: "#" + format((1 << 24) + rgb[0] * 65536 + rgb[1] * 256 + rgb[2], "x")[1:]
    for code, rgb in ANSI_BG_COLORS.items()
}
|
def ansi_to_image(ansi_text, font_size=20, image_path="./test.png"):
    """Render ANSI-coloured text to an image file and return ``image_path``.

    Lines are split on '\\n' and drawn top to bottom with the font
    ./font/TaipeiSansTCBeta-Regular.ttf on a ``bg_rgb`` canvas. Escape codes
    are looked up in ANSI_BG_COLORS:
      * a bare code (no text after it) sets the background for later text;
      * a code followed by text sets that text's colour (a reset code
        restores white text on the ``bg_rgb`` background);
      * text with no preceding code is drawn white.
    The background starts as white at the beginning of every line.
    """
    font = ImageFont.truetype("./font/TaipeiSansTCBeta-Regular.ttf", font_size)
    white = (255, 255, 255)
    reset_code = '\033[0m'

    def _text_size(text):
        # Pillow 10 removed FreeTypeFont.getsize / ImageDraw.textsize; the
        # right/bottom edge of getbbox() is the closest replacement.
        if hasattr(font, "getbbox"):
            _left, _top, right, bottom = font.getbbox(text)
            return right, bottom
        return font.getsize(text)

    lines = ansi_text.split('\n')
    line_heights = []
    max_width = 0
    for line in lines:
        text_width, text_height = _text_size(line)
        max_width = max(max_width, text_width)
        line_heights.append(text_height)

    image = Image.new('RGB', (max_width, sum(line_heights)), color=bg_rgb)
    draw = ImageDraw.Draw(image)

    y = 0
    for line, line_height in zip(lines, line_heights):
        x = 0
        anchor_bg_color = white
        for index, segment in enumerate(line.split('\033')):
            if index == 0 or 'm' not in segment:
                # Text with no escape code in front. The first piece used to
                # be split on its first 'm', silently dropping its text.
                draw.text((x, y), segment, font=font, fill=white)
                x += _text_size(segment)[0]
                continue

            code, text = segment.split('m', 1)
            full_code = f'\033{code}m'
            if not text:
                # A bare code only changes the background of later text.
                anchor_bg_color = ANSI_BG_COLORS.get(full_code, anchor_bg_color)
                continue

            if full_code == reset_code:
                # ANSI_BG_COLORS maps reset to bg_rgb, which used to paint the
                # text in the background colour (invisible); use defaults.
                anchor_bg_color = bg_rgb
                font_color = white
            else:
                font_color = ANSI_BG_COLORS.get(full_code, anchor_bg_color)

            text_width = _text_size(text)[0]
            draw.rectangle([x, y, x + text_width, y + line_height], anchor_bg_color)
            draw.text((x, y), text, font=font, fill=font_color)
            x += text_width
        y += line_height

    image.save(image_path)
    return image_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def suggesting_dis(the_pool, target_name, case_dict):
    """Rank pool cases against a target case in dispute mode.

    Args:
        the_pool: iterable of case ids (keys of the module-level corpus_dict).
        target_name: id of the target case, or "user_input" for typed text.
            The target itself is skipped if it appears in the pool.
        case_dict: {"dispute": list of dispute sentences,
            "plaintiff"/"defendant": text or list of text pieces (joined with "")}.

    Returns:
        (outputs, records) for the top sug_th candidates by "ensemble_pred":
        the ANSI-coloured text reports and the result dicts. Each record
        (without "output") is also written to
        ./json_file/<target with "," -> "_">&<case_id>.json.

    A pool case is scored only if it shares at least max(1, inset_th)
    cluster labels (other than the '-1' no-cluster marker) with the target's
    dispute sentences. Scores come from the CNN (pairwise-similarity image)
    and BiLSTM (padded sentence vectors) models, blended by logistic().
    Relies on module-level state: corpus_dict, corpus_pd_f, vec_lst, id_lst,
    sen_lst, corpus_clust_label, _cluster_core_dict, _embedder, cnn_model,
    bilstm_model, x_r, y_r, key_lst, sp_key, color_lst, ANSI_COLORS and the
    config thresholds.
    """
    reset = Style.RESET_ALL

    def _fragments(text):
        """Split joined text into fragments, each ending at a key_lst mark."""
        return replace_all("".join(text), key_lst, sp_key, 1).split(sp_key)

    def _pad(vectors):
        """First bilstm_len vectors, zero-padded to exactly bilstm_len rows."""
        rows = list(vectors[:bilstm_len])
        rows.extend([0] * emb_dim for _ in range(bilstm_len - len(rows)))
        return rows

    def _colour_by_cluster(sentences, labels, inset):
        """ANSI pieces per sentence; sentences in a shared cluster get its colour."""
        return [
            [color_lst[inset.index(label) % len(color_lst)], Fore.WHITE, sentence, reset]
            if label in inset else [reset, sentence]
            for sentence, label in zip(sentences, labels)
        ]

    def _best_match(vec, anchors):
        """[similarity, colour] of the most similar highlighted sentence."""
        return max([cos_sim(vec, anchor_vec), colour] for anchor_vec, colour in anchors)

    def _colour_by_similarity(matches, fragments):
        """ANSI pieces per fragment; fragments above _th take their match's colour."""
        return [
            [match[-1], Fore.WHITE, fragment, reset] if match[0] > _th else [reset, fragment]
            for match, fragment in zip(matches, fragments)
        ]

    def _html_segments(pieces_lst):
        """Convert ANSI pieces to the dicts html_hl expects."""
        segments = []
        for pieces in pieces_lst:
            if len(pieces) == 4:
                segments.append({"background_color": ANSI_COLORS[pieces[0]], "font_color": ANSI_COLORS[pieces[1]], "content": pieces[-2]})
            else:
                segments.append({"background_color": ANSI_COLORS[pieces[0]], "font_color": ANSI_COLORS[Fore.WHITE], "content": pieces[-1]})
        return segments

    def _section(title, pieces_lst):
        """A magenta title line followed by the coloured text line."""
        body = "".join("".join(pieces) for pieces in pieces_lst)
        return Fore.MAGENTA + title + reset + "\n" + body + reset + "\n"

    def _is_target(case_id):
        # Pool ids use "_" where the raw-claim tables (and hence target
        # names) use "," (see the corpus_pd_f lookup below): compare both.
        return case_id == target_name or case_id.replace("_", ",") == str(target_name)

    # Target side: dispute sentences, their vectors and nearest cluster labels.
    lst_2 = list(case_dict["dispute"])[:bilstm_len]
    vec_lst_2 = [_embedder.encode(s) for s in lst_2]
    clst_2 = [clust_search(_cluster_core_dict, v, clust_th) for v in vec_lst_2]
    plst_2 = _fragments(case_dict["plaintiff"])
    dlst_2 = _fragments(case_dict["defendant"])
    v_plst_2 = [_embedder.encode(s) for s in plst_2]
    v_dlst_2 = [_embedder.encode(s) for s in dlst_2]
    print(clst_2)

    # sen_lst.index() per sentence was O(corpus); map each sentence to its
    # first position once instead.
    sen_pos = {}
    for pos, sentence in enumerate(sen_lst):
        sen_pos.setdefault(sentence, pos)
    target_labels = set(clst_2)

    rt_lst = []
    for i in tqdm(the_pool):
        if _is_target(i):
            continue
        lst_1 = list(corpus_dict[i])
        positions = [sen_pos[s] for s in lst_1]
        id_lst_1 = [id_lst[p] for p in positions]
        vec_lst_1 = [vec_lst[p] for p in positions]
        clst_1 = [corpus_clust_label[sid] for sid in id_lst_1]

        # clust_label / clust_search mark unclustered sentences with the
        # string '-1'; the old `!= -1` test compared against the int only.
        inset = sorted(label for label in set(clst_1) & target_labels if label not in (-1, '-1'))
        if len(inset) < max(1, inset_th):
            continue
        temp_ot = {"target": target_name, "inset": inset}

        # Model scores: CNN on the similarity image, BiLSTM on the two padded
        # sentence sequences, blended by the logistic ensemble.
        _img = img_resize(vec2img(vec_lst_1, clst_1, vec_lst_2, clst_2, clust_th), cnn_len)
        cnn_pred = cnn_model.predict(np.array([_img]) / 255)
        _con1 = np.array([_pad(vec_lst_1)])
        _con2 = np.array([_pad(vec_lst_2)])
        print(len(_con1), len(_con2), len(_con2[0]))
        bilstm_pred = bilstm_model.predict([_con1, _con2])
        temp_ot["cnn_pred"] = float(cnn_pred[0][0])
        temp_ot["bilstm_pred"] = float(bilstm_pred[0][0])
        ensemble_pred = logistic(x_r, y_r, [[bilstm_pred[0][0], cnn_pred[0][0]]])
        temp_ot["ensemble_pred"] = float(ensemble_pred[0])

        # Highlighting: dispute sentences coloured by shared cluster; party
        # fragments coloured after their most similar highlighted sentence.
        pre_lst_1 = _colour_by_cluster(lst_1, clst_1, inset)
        pre_lst_2 = _colour_by_cluster(lst_2, clst_2, inset)
        vlst_1 = [[vec, pieces[0]] for vec, pieces in zip(vec_lst_1, pre_lst_1) if len(pieces) == 4]
        vlst_2 = [[vec, pieces[0]] for vec, pieces in zip(vec_lst_2, pre_lst_2) if len(pieces) == 4]

        raw_claim = corpus_pd_f[i.replace("_", ",")]
        plst_1 = _fragments(raw_claim[0])
        dlst_1 = _fragments(raw_claim[1])
        v_plst_1 = [_embedder.encode(s) for s in plst_1]
        v_dlst_1 = [_embedder.encode(s) for s in dlst_1]

        pre_lst_p1 = _colour_by_similarity([_best_match(v, vlst_1) for v in v_plst_1], plst_1)
        pre_lst_d1 = _colour_by_similarity([_best_match(v, vlst_1) for v in v_dlst_1], dlst_1)
        pre_lst_p2 = _colour_by_similarity([_best_match(v, vlst_2) for v in v_plst_2], plst_2)
        pre_lst_d2 = _colour_by_similarity([_best_match(v, vlst_2) for v in v_dlst_2], dlst_2)

        temp_ot["case_id"] = i
        temp_ot["plaintiff_case1"] = _html_segments(pre_lst_p1)
        temp_ot["defendant_case1"] = _html_segments(pre_lst_d1)
        temp_ot["dispute_case1"] = _html_segments(pre_lst_1)
        temp_ot["plaintiff_case2"] = _html_segments(pre_lst_p2)
        temp_ot["defendant_case2"] = _html_segments(pre_lst_d2)
        temp_ot["dispute_case2"] = _html_segments(pre_lst_2)

        score = temp_ot["ensemble_pred"]
        score_colour = Fore.GREEN if score >= 0.75 else Fore.YELLOW if score >= 0.5 else Fore.RED
        tp_str = (
            Fore.BLUE + str(i) + reset + "\n"
            + score_colour + str(score) + reset + "\n"
            + _section("---plaintiff_case1---", pre_lst_p1)
            + _section("---defendant_case1---", pre_lst_d1)
            + _section("---dispute_case1---", pre_lst_1)
            + Fore.BLUE + "target" + reset + "\n"
            + _section("---plaintiff_case2---", pre_lst_p2)
            + _section("---defendant_case2---", pre_lst_d2)
            + _section("---dispute_case2---", pre_lst_2)
        )
        temp_ot["output"] = tp_str
        rt_lst.append(temp_ot)
        print(tp_str)

    ranked = sorted(rt_lst, key=lambda rec: rec["ensemble_pred"], reverse=True)[:sug_th]
    for rec in ranked:
        json_path = "./json_file/" + str(target_name).replace(",", "_") + "&" + str(rec["case_id"]) + ".json"
        with open(json_path, "w", encoding='utf8') as file:
            json.dump({key: rec[key] for key in rec if key != "output"}, file, indent=4, ensure_ascii=False)
    return [rec["output"] for rec in ranked], ranked
|
def suggesting(the_pool, target_name, case_dict):
    """Rank pool cases against a target case in plaintiff mode.

    Args:
        the_pool: iterable of case ids (keys of the module-level corpus_dict).
        target_name: id of the target case, or "user_input" for typed text.
            The target itself is skipped if it appears in the pool.
        case_dict: {"p_point": list of point sentences,
            "plaintiff": text or list of text pieces (joined with "")}.

    Returns:
        (outputs, records) for the top sug_th candidates by "ensemble_pred":
        the ANSI-coloured text reports and the result dicts. Each record
        (without "output") is also written to
        ./json_file/<target with "," -> "_">&<case_id>.json.

    A pool case is scored only if it shares at least max(1, inset_th)
    cluster labels (other than the '-1' no-cluster marker) with the target's
    points. Scores come from the CNN (pairwise-similarity image) and BiLSTM
    (padded sentence vectors) models, blended by logistic(). Relies on
    module-level state: corpus_dict, corpus_pd_f, vec_lst, id_lst, sen_lst,
    corpus_clust_label, _cluster_core_dict, _embedder, cnn_model,
    bilstm_model, x_r, y_r, key_lst, sp_key, color_lst, ANSI_COLORS and the
    config thresholds.
    """
    reset = Style.RESET_ALL

    def _fragments(text):
        """Split joined text into fragments, each ending at a key_lst mark."""
        return replace_all("".join(text), key_lst, sp_key, 1).split(sp_key)

    def _pad(vectors):
        """First bilstm_len vectors, zero-padded to exactly bilstm_len rows."""
        rows = list(vectors[:bilstm_len])
        rows.extend([0] * emb_dim for _ in range(bilstm_len - len(rows)))
        return rows

    def _colour_by_cluster(sentences, labels, inset):
        """ANSI pieces per sentence; sentences in a shared cluster get its colour."""
        return [
            [color_lst[inset.index(label) % len(color_lst)], Fore.WHITE, sentence, reset]
            if label in inset else [reset, sentence]
            for sentence, label in zip(sentences, labels)
        ]

    def _best_match(vec, anchors):
        """[similarity, colour] of the most similar highlighted sentence."""
        return max([cos_sim(vec, anchor_vec), colour] for anchor_vec, colour in anchors)

    def _colour_by_similarity(matches, fragments):
        """ANSI pieces per fragment; fragments above _th take their match's colour."""
        return [
            [match[-1], Fore.WHITE, fragment, reset] if match[0] > _th else [reset, fragment]
            for match, fragment in zip(matches, fragments)
        ]

    def _html_segments(pieces_lst):
        """Convert ANSI pieces to the dicts html_hl expects."""
        segments = []
        for pieces in pieces_lst:
            if len(pieces) == 4:
                segments.append({"background_color": ANSI_COLORS[pieces[0]], "font_color": ANSI_COLORS[pieces[1]], "content": pieces[-2]})
            else:
                segments.append({"background_color": ANSI_COLORS[pieces[0]], "font_color": ANSI_COLORS[Fore.WHITE], "content": pieces[-1]})
        return segments

    def _section(title, pieces_lst):
        """A magenta title line followed by the coloured text line."""
        body = "".join("".join(pieces) for pieces in pieces_lst)
        return Fore.MAGENTA + title + reset + "\n" + body + reset + "\n"

    def _is_target(case_id):
        # Pool ids use "_" where the raw-claim tables (and hence target
        # names) use "," (see the corpus_pd_f lookup below), so a plain
        # `target_name == case_id` never matched: compare both spellings.
        return case_id == target_name or case_id.replace("_", ",") == str(target_name)

    # Target side: point sentences, their vectors and nearest cluster labels.
    lst_2 = list(case_dict["p_point"])[:bilstm_len]
    vec_lst_2 = [_embedder.encode(s) for s in lst_2]
    clst_2 = [clust_search(_cluster_core_dict, v, clust_th) for v in vec_lst_2]
    plst_2 = _fragments(case_dict["plaintiff"])
    v_plst_2 = [_embedder.encode(s) for s in plst_2]
    print(clst_2)

    # sen_lst.index() per sentence was O(corpus); map each sentence to its
    # first position once instead.
    sen_pos = {}
    for pos, sentence in enumerate(sen_lst):
        sen_pos.setdefault(sentence, pos)
    target_labels = set(clst_2)

    rt_lst = []
    for i in tqdm(the_pool):
        if _is_target(i):
            continue
        lst_1 = list(corpus_dict[i])
        positions = [sen_pos[s] for s in lst_1]
        id_lst_1 = [id_lst[p] for p in positions]
        vec_lst_1 = [vec_lst[p] for p in positions]
        clst_1 = [corpus_clust_label[sid] for sid in id_lst_1]

        # clust_label / clust_search mark unclustered sentences with the
        # string '-1'; the old `!= -1` test compared against the int only.
        inset = sorted(label for label in set(clst_1) & target_labels if label not in (-1, '-1'))
        if len(inset) < max(1, inset_th):
            continue
        temp_ot = {"target": target_name, "inset": inset}

        # Model scores: CNN on the similarity image, BiLSTM on the two padded
        # sentence sequences, blended by the logistic ensemble.
        _img = img_resize(vec2img(vec_lst_1, clst_1, vec_lst_2, clst_2, clust_th), cnn_len)
        cnn_pred = cnn_model.predict(np.array([_img]) / 255)
        _con1 = np.array([_pad(vec_lst_1)])
        _con2 = np.array([_pad(vec_lst_2)])
        print(len(_con1), len(_con2), len(_con2[0]))
        bilstm_pred = bilstm_model.predict([_con1, _con2])
        temp_ot["cnn_pred"] = float(cnn_pred[0][0])
        temp_ot["bilstm_pred"] = float(bilstm_pred[0][0])
        ensemble_pred = logistic(x_r, y_r, [[bilstm_pred[0][0], cnn_pred[0][0]]])
        temp_ot["ensemble_pred"] = float(ensemble_pred[0])

        # Highlighting: point sentences coloured by shared cluster; plaintiff
        # fragments coloured after their most similar highlighted sentence.
        pre_lst_1 = _colour_by_cluster(lst_1, clst_1, inset)
        pre_lst_2 = _colour_by_cluster(lst_2, clst_2, inset)
        vlst_1 = [[vec, pieces[0]] for vec, pieces in zip(vec_lst_1, pre_lst_1) if len(pieces) == 4]
        vlst_2 = [[vec, pieces[0]] for vec, pieces in zip(vec_lst_2, pre_lst_2) if len(pieces) == 4]

        plst_1 = _fragments(corpus_pd_f[i.replace("_", ",")][0])
        v_plst_1 = [_embedder.encode(s) for s in plst_1]
        pre_lst_p1 = _colour_by_similarity([_best_match(v, vlst_1) for v in v_plst_1], plst_1)
        pre_lst_p2 = _colour_by_similarity([_best_match(v, vlst_2) for v in v_plst_2], plst_2)

        temp_ot["case_id"] = i
        temp_ot["plaintiff_case1"] = _html_segments(pre_lst_p1)
        temp_ot["p_point_case1"] = _html_segments(pre_lst_1)
        temp_ot["plaintiff_case2"] = _html_segments(pre_lst_p2)
        temp_ot["p_point_case2"] = _html_segments(pre_lst_2)

        score = temp_ot["ensemble_pred"]
        score_colour = Fore.GREEN if score >= 0.75 else Fore.YELLOW if score >= 0.5 else Fore.RED
        tp_str = (
            Fore.BLUE + str(i) + reset + "\n"
            + score_colour + str(score) + reset + "\n"
            + _section("---plaintiff_case1---", pre_lst_p1)
            + _section("---p_point_case1---", pre_lst_1)
            + Fore.BLUE + "target" + reset + "\n"
            + _section("---plaintiff_case2---", pre_lst_p2)
            + _section("---p_point_case2---", pre_lst_2)
        )
        temp_ot["output"] = tp_str
        rt_lst.append(temp_ot)
        print(tp_str)

    ranked = sorted(rt_lst, key=lambda rec: rec["ensemble_pred"], reverse=True)[:sug_th]
    for rec in ranked:
        json_path = "./json_file/" + str(target_name).replace(",", "_") + "&" + str(rec["case_id"]) + ".json"
        with open(json_path, "w", encoding='utf8') as file:
            json.dump({key: rec[key] for key in rec if key != "output"}, file, indent=4, ensure_ascii=False)
    return [rec["output"] for rec in ranked], ranked
|
|
|
|
|
|
|
_dir_lst = ["../gpt4_0409_p_3/", "../taide_llama3_8b_3/"]
_dir = _dir_lst[0]

# Marker that replace_all inserts after punctuation; text is then split on it.
sp_key = "@"

# Use the embedder selected in _config (was hard-coded "ftrob"), so sentence
# vectors stay consistent with emb_dim and the model files picked by emb_type.
emb_model = emb_type

emb_model_path = {
    "lf": "thunlp/Lawformer",
    "rob": 'hfl/chinese-roberta-wwm-ext-large',
    "ftlf": "./sbert_pretrained_model/training-lawformer-clause_th10_100k_task-bs100-e2-2023-10-28/",
    "ftrob": "./sbert_pretrained_model/training-roberta-clause_th10_100k_task-bs100-e2-2023-10-27",
}

# Highlight colours, cycled over the shared cluster labels.
color_lst = [Back.BLUE, Back.GREEN, Back.MAGENTA, Back.YELLOW, Back.RED, Back.CYAN]

# Training data for the logistic ensemble: each row is
# [bilstm score, cnn score, label].
# NOTE(review): the plaintiff feature file is loaded for both suggestion
# types; confirm this is intended for "dispute".
with open("./src/plaintiff_logistic_features.json", "r") as _log_file:
    log_f = json.load(_log_file)["BiLSTM_CNN"]
_log_arr = np.array(log_f)
x_r = _log_arr[:, :-1]
y_r = _log_arr[:, -1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_json(path):
    """Parse the JSON file at ``path``, closing the handle afterwards."""
    with open(path, "r") as json_file:
        return json.load(json_file)


# Input files per suggestion type. "raw".."ter" describe the 3835-case
# corpus (clusters, cluster cores); the pool_* files replace the sentence /
# vector / cluster data when the candidate pool is not "corpus3835".
_DATA_FILES = {
    "plaintiff": {
        "raw": "./src/corpus3835_raw.json",
        "sen": "./src/plaintiff_corpus3835_sen.json",
        "vec": "./src/plaintiff_corpus3835_vec.json",
        "cluster": "./src/plaintiff_corpus3835_cluster.json",
        "ter": "./src/plaintiff_ter.json",
        "pool_vec": "./src/plaintiff_2022~2023_vec.json",
        "pool_raw": "./src/plaintiff_2022~2023_raw.json",
        "pool_claim": "./src/2022~2023_raw.json",
        "pool_clust": "./src/plaintiff_2022~2023_clust.json",
    },
    "dispute": {
        "raw": "./src/corpus3835_raw_dis.json",
        "sen": "./src/dispute_corpus3835_sen.json",
        "vec": "./src/dispute_corpus3835_vec.json",
        "cluster": "./src/dispute_corpus3835_cluster.json",
        "ter": "./src/dispute_ter.json",
        "pool_vec": "./src/dispute_2022~2023_vec.json",
        "pool_raw": "./src/dispute_2022~2023_raw.json",
        "pool_claim": "./src/new22_23_3k3_corpus_raw.json",
        "pool_clust": "./src/dispute_22~23_clust.json",
    },
}

if sug_type in _DATA_FILES:
    _paths = _DATA_FILES[sug_type]
    pd_f = corpus_pd_f = _load_json(_paths["raw"])["claim"]
    s_f = _load_json(_paths["sen"])
    v_f = _load_json(_paths["vec"])
    o_c_f = _load_json(_paths["cluster"])["clusters"]
    c_f = clust_2_dict(o_c_f)
    t_f = _load_json(_paths["ter"])

    if pool_type == "corpus3835":
        corpus_clust_label = clust_label(o_c_f)
        vec_lst = v_f["vector"]
        id_lst = v_f["id"]
        sen_lst = s_f["sentence"]
        # Group sentences by case id (the part of the sentence id before "@").
        corpus_dict = {}
        for _k, _sid in enumerate(id_lst):
            corpus_dict.setdefault(_sid.split("@")[0], []).append(sen_lst[_k])
        # corpus_pd_f already holds this claim table (the old code re-read
        # the same file here).
    else:
        vec_f = _load_json(_paths["pool_vec"])
        vec_lst = [vec for case_id in vec_f for vec in vec_f[case_id]]
        corpus_dict = _load_json(_paths["pool_raw"])
        corpus_pd_f = _load_json(_paths["pool_claim"])["claim"]
        corpus_clust_f = _load_json(_paths["pool_clust"])

        sen_lst = [sen for case_id in corpus_dict for sen in corpus_dict[case_id]]
        id_lst = [case_id + "@" + str(k) for case_id in corpus_dict for k in range(len(corpus_dict[case_id]))]
        # Built directly from (case id, index) instead of re-parsing the id
        # at its first "@", which broke for case ids containing "@".
        corpus_clust_label = {
            case_id + "@" + str(k): corpus_clust_f[case_id][k]
            for case_id in corpus_dict
            for k in range(len(corpus_dict[case_id]))
        }
|
|
|
|
|
|
|
# Target-side data for the demo: extracted points per filename and the raw
# claim table of the same 2022~2023 cases.
new_point_f = lst_2_dict(jl("./src/gpt-4-turbo-0409-0.3-new22_23.jsonl"))
with open("./src/new22_23_3k3_corpus_raw.json", "r") as _raw_file:
    new_pd_f = json.load(_raw_file)["claim"]

# Punctuation after which party text is split into fragments. dict.fromkeys
# drops repeated marks: replace_all processes a duplicate twice ("?@@"),
# which produced empty fragments after splitting on sp_key.
key_lst = list(dict.fromkeys([",","。","?","?","!","!",";",":",";",":"]))

_embedder = SentenceTransformer(emb_model_path[emb_model])

# Placeholders; cnn_load / bilstm_load assign the real models.
cnn_model = ...
bilstm_model = ...

# Models run on CPU. For the default "/gpu:0" device ("fifo" variant) call
# cnn_load() and bilstm_load() without arguments instead.
cnn_load("/cpu:0")
bilstm_load("/cpu:0")

# Representative ("central") vector per corpus cluster, used by clust_search.
_cluster_core_dict = clust_core(o_c_f, v_f["vector"], v_f["id"], "central")
|
|
|
|
|
from colorama import Fore,Style,Back |
|
|
|
import gradio as gr |
|
|
|
def case_sug_dis(file_name, plaintiff, defendant, p_point, d_point, dispute_list):
    """Gradio handler for dispute-mode suggestions.

    If ``file_name`` is a key of new_pd_f, all other fields are replaced by
    that case's stored texts and points. Otherwise the typed fields are used
    and the target is labelled "user_input". String point/dispute fields
    are split into sentences on "。".

    Returns:
        Exactly sug_th file paths: ./html_file/test<k>.html for each returned
        suggestion, padded with ./out_of_range.html.
    """
    if file_name not in new_pd_f:
        print("file not found")
        file_name = "user_input"
    else:
        plaintiff = new_pd_f[file_name][0]
        defendant = new_pd_f[file_name][1]
        p_point = new_point_f[file_name][0]
        d_point = new_point_f[file_name][1]
        dispute_list = new_point_f[file_name][2]

    def _sentences(value):
        # The old code called "。".split(value), i.e. split the delimiter by
        # the user's text (losing it, and raising on an empty string).
        return value.split("。") if isinstance(value, str) else value

    p_point = _sentences(p_point)
    d_point = _sentences(d_point)
    dispute_list = _sentences(dispute_list)

    _case_dict = {"plaintiff": plaintiff, "defendant": defendant, "p_point": p_point, "d_point": d_point, "dispute": dispute_list}
    _ot, ot_dict = suggesting_dis(list(corpus_dict), file_name, _case_dict)

    print("-----")
    print(len(ot_dict))
    out_path = "./out_of_range.html"
    output_list = []
    for k in range(sug_th):
        if k < len(ot_dict):
            _path = "./html_file/test" + str(k) + ".html"
            ansi_to_html_dis(ot_dict[k], _path)
            output_list.append(_path)
        else:
            output_list.append(out_path)
    return output_list
|
def case_sug(file_name, plaintiff, p_point):
    """Gradio handler for plaintiff-mode suggestions.

    If ``file_name`` is a key of new_pd_f, the plaintiff text and points are
    taken from the stored case. Otherwise the typed fields are used and the
    target is labelled "user_input". A string ``p_point`` is split into
    sentences on "。".

    Returns:
        Exactly sug_th file paths: ./html_file/test<k>.html for each returned
        suggestion, padded with ./out_of_range.html.
    """
    print(file_name)
    if file_name in new_pd_f:
        plaintiff = new_pd_f[file_name][0]
        p_point = new_point_f[file_name][0]
    else:
        print("file not found")
        file_name = "user_input"

    if type(p_point) is str:
        p_point = p_point.split("。")

    case_fields = {"plaintiff": plaintiff, "p_point": p_point}
    print(case_fields, [type(case_fields[field]) for field in case_fields])
    _reports, ranked = suggesting(list(corpus_dict), file_name, case_fields)

    print("-----")
    print(len(ranked))
    paths = []
    for slot in range(sug_th):
        if slot >= len(ranked):
            paths.append("./out_of_range.html")
            continue
        html_path = "./html_file/test" + str(slot) + ".html"
        ansi_to_html(ranked[slot], html_path)
        paths.append(html_path)
    return paths
|
# Output component: gr.File is available from Gradio 3 onwards, while the
# legacy gr.outputs namespace was removed in Gradio 4; fall back for old installs.
_file_output = getattr(gr, "File", None) or gr.outputs.File

# Serve the UI for the configured suggestion type. Each output slot receives
# one HTML report path from the handler.
# NOTE(review): share=True requests a public Gradio share link; confirm intended.
if sug_type == "plaintiff":
    demo = gr.Interface(fn=case_sug, inputs=["text", "text", "text"], outputs=[_file_output() for _ in range(sug_th)])
    demo.launch(share=True, server_port=4096, show_error=True)
elif sug_type == "dispute":
    demo = gr.Interface(fn=case_sug_dis, inputs=["text", "text", "text", "text", "text", "text"], outputs=[_file_output() for _ in range(sug_th)])
    demo.launch(share=True, server_port=2048, show_error=True)
|
|