"""Streamlit demo for interactive legal-contract review.

The user uploads a plain-text contract and picks one of the CUAD questions.
The question is answered with ``predict.run_prediction`` using the
``marshmellow77/roberta-base-cuad`` model. The best answer can then be passed
to three follow-up analyses: sustainability labelling, summarization, and
spaCy named-entity rendering.

Streamlit re-executes this whole script on every widget interaction.
Expensive loads are therefore cached, and the latest answer is kept in
``st.session_state`` so the follow-up buttons can see it.
"""
import json
from io import StringIO

import nltk
import pandas as pd
import spacy
import streamlit as st
import torch
from nltk.tokenize import sent_tokenize
from spacy import displacy
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    RobertaTokenizer,
    pipeline,
)

from fin_readability_sustainability import BERTClass, do_predict
from predict import run_prediction

# Must be the first Streamlit command the script executes.
st.set_page_config(layout="wide")


def _resource_cache(func):
    """Cache *func*'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when the installed Streamlit provides it.
    Otherwise falls back to the legacy ``st.cache``. Output mutation is
    allowed in that case because models are not meant to be re-hashed.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_resource_cache
def _ensure_nltk_punkt():
    """Download NLTK's 'punkt' data (needed by ``sent_tokenize``) once per process."""
    nltk.download('punkt')
    return True


@_resource_cache
def _load_spacy():
    """Load the small English spaCy pipeline used for NER rendering."""
    return spacy.load("en_core_web_sm")


_ensure_nltk_punkt()
nlp = _load_spacy()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#SUSTAIN STARTS
SUSTAINABILITY_CHECKPOINT = 'sustainability_model.bin'
# Score cut-offs applied to the per-sentence outputs of do_predict.
# Their origin is not shown here; presumably they were calibrated offline for
# this checkpoint. Scores strictly between the two are labelled '-'.
NON_SUSTAINABLE_MIN_SCORE = 4.384316
SUSTAINABLE_MAX_SCORE = 1.423736


@_resource_cache
def _load_sustainability_model():
    """Build the RoBERTa tokenizer and the sustainability classifier.

    Loads the weights from ``SUSTAINABILITY_CHECKPOINT``.
    """
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    # The "sustanability" spelling is passed through unchanged to the
    # project's BERTClass. It may be a key it relies on, so do not "fix" it.
    model = BERTClass(2, "sustanability")
    model.to(device)
    checkpoint = torch.load(SUSTAINABILITY_CHECKPOINT, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): model.eval() is not called here. Confirm that do_predict
    # switches to eval mode, otherwise dropout stays active at inference.
    return tokenizer, model


tokenizer_sus, model_sustain = _load_sustainability_model()


def get_sustainability(text):
    """Label each sentence of *text* by sustainability.

    Returns a list of ``(sentence, label)`` tuples. ``label`` is one of:
    - ``'non-sustainable'`` when the score is >= NON_SUSTAINABLE_MIN_SCORE;
    - ``'sustainable'`` when the score is <= SUSTAINABLE_MAX_SCORE;
    - ``'-'`` otherwise.

    Text that yields no sentences returns an empty list.
    """
    sentences = sent_tokenize(text)
    if not sentences:
        # Avoid handing an empty frame to the model.
        return []
    df = pd.DataFrame({'sentence': sentences})
    predictions = do_predict(model_sustain, tokenizer_sus, df)
    # NOTE(review): element [1] of do_predict's result is assumed to hold one
    # score per sentence, in df order. Confirm in fin_readability_sustainability.
    highlight = []
    for sent, prob in zip(df['sentence'].values, predictions[1]):
        if prob >= NON_SUSTAINABLE_MIN_SCORE:
            label = 'non-sustainable'
        elif prob <= SUSTAINABLE_MAX_SCORE:
            label = 'sustainable'
        else:
            label = '-'
        highlight.append((sent, label))
    return highlight
#SUSTAIN ENDS


##Summarization
@_resource_cache
def _load_summarizer():
    """Build the Hugging Face summarization pipeline once per process."""
    return pipeline("summarization", model="knkarthick/MEETING_SUMMARY")


def summarize_text(text):
    """Return the model-generated summary string for *text*."""
    resp = _load_summarizer()(text)
    return resp[0]['summary_text']


##Forward Looking Statement
#def fls(text):
#    fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")
#    results = fls_model(split_in_sentences(text))
#    return make_spans(text,results)

##Company Extraction
#ner=pipeline('ner',model='Jean-Baptiste/camembert-ner-with-dates',tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")
#def fin_ner(text):
#    replaced_spans = ner(text)
#    return replaced_spans


def _read_lines(path):
    """Return the lines of the UTF-8 text file *path* without line terminators."""
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\r\n") for line in f]


def load_questions():
    """Return the full CUAD question texts, one per line of questions.txt."""
    return _read_lines('questions.txt')


def load_questions_short():
    """Return the short question labels, one per line of questionshort.txt.

    These are index-aligned with load_questions().
    """
    return _read_lines('questionshort.txt')


questions = load_questions()
questions_short = load_questions_short()

### DEFINE SIDEBAR
st.sidebar.title("Interactive Contract Analysis")
st.sidebar.header('CONTRACT UPLOAD')

# upload contract
user_upload = st.sidebar.file_uploader('Please upload your contract', type=['txt'], accept_multiple_files=False)

# Default to an empty contract. Without this, contract_data would be
# undefined (NameError) before any upload, or after a rejected one.
contract_data = ""
if user_upload is not None:
    print(user_upload.name, user_upload.type)
    extension = user_upload.name.split('.')[-1].lower()
    if extension == 'txt':
        print('text file uploaded')
        try:
            contract_data = user_upload.getvalue().decode("utf-8")
        except UnicodeDecodeError:
            st.warning('Unknown uploaded file type, please try again')
    else:
        st.warning('Unknown uploaded file type, please try again')

results_drop = ['1', '2', '3']
number_results = st.sidebar.selectbox('Select number of results', results_drop)

### DEFINE MAIN PAGE
st.header("Legal Contract Review Demo")
paragraph = st.text_area(label="Contract", value=contract_data, height=300)
questions_drop = questions_short
question_short = st.selectbox('Choose one of the 41 queries from the CUAD dataset:', questions_drop)
idxq = questions_drop.index(question_short)
question = questions[idxq]

# A plain local would be reset to "" on the rerun triggered by the follow-up
# buttons below, so they could never see the answer. Keep it in session state.
if "raw_answer" not in st.session_state:
    st.session_state["raw_answer"] = ""

if st.button('Analyze'):
    if paragraph and question:
        print('getting predictions')
        with st.spinner(text='Analysis in progress...'):
            predictions = run_prediction([question], paragraph, 'marshmellow77/roberta-base-cuad', n_best_size=5)
            if predictions['0'] == "":
                candidates = []
            else:
                # run_prediction also writes its n-best candidates to nbest.json.
                with open("nbest.json", encoding="utf-8") as jf:
                    data = json.load(jf)
                # Slice rather than index so fewer candidates than requested
                # cannot raise IndexError.
                candidates = data['0'][:int(number_results)]
            if candidates:
                answer = "".join(
                    f"Answer {i + 1}: {cand['text']} -- \n"
                    f"Probability: {round(cand['probability'] * 100, 1)}%\n\n"
                    for i, cand in enumerate(candidates)
                )
                # Follow-up analyses run on the best-ranked candidate.
                st.session_state["raw_answer"] = candidates[0]['text']
            else:
                answer = 'No answer found in document'
                st.session_state["raw_answer"] = ""
            st.success(answer)
    else:
        st.write("Unable to call model, please select question and contract")

raw_answer = st.session_state["raw_answer"]

if st.button('Check Sustainability'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        st.write(get_sustainability(raw_answer))

if st.button('Summarize'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        st.write(summarize_text(raw_answer))

if st.button('NER'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        doc = nlp(raw_answer)
        # displacy.render returns an HTML string. st.write would show it as
        # escaped text, so render it as HTML instead.
        # NOTE(review): this relies on displacy escaping the document text.
        # Confirm that for the installed spaCy version, since it comes from
        # user uploads.
        st.markdown(displacy.render(doc, style="ent"), unsafe_allow_html=True)