File size: 5,479 Bytes
a560ed2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import streamlit as st
from predict import run_prediction
from io import StringIO
import json
import spacy
from spacy import displacy
from transformers import AutoTokenizer, AutoModelForTokenClassification,RobertaTokenizer,pipeline
import torch
import nltk
from nltk.tokenize import sent_tokenize
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd

# Module-level setup for the NLP models used by the app.
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the downloads and model loads below presumably repeat on each rerun —
# consider moving them into a Streamlit resource cache.
nltk.download('punkt', quiet=True)  # tokenizer data needed by sent_tokenize
nlp = spacy.load("en_core_web_sm")  # used by the NER button

# set_page_config must be the first Streamlit command the script executes.
st.set_page_config(layout="wide")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#SUSTAIN STARTS
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
# The "sustanability" spelling is passed to BERTClass unchanged on purpose:
# BERTClass is defined elsewhere and may depend on this exact value.
model_sustain = BERTClass(2, "sustanability")
model_sustain.to(device)
# NOTE(review): torch.load unpickles the checkpoint; only load files from a
# trusted source.
model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])


def get_sustainability(text, non_sustainable_threshold=4.384316,
                       sustainable_threshold=1.423736):
    """Label each sentence of *text* as sustainable, non-sustainable or neutral.

    The text is split into sentences with NLTK's ``sent_tokenize``. The
    sentences are then scored in a single ``do_predict`` call using the
    module-level ``model_sustain`` and ``tokenizer_sus``. Element ``[1]`` of
    the ``do_predict`` result is paired with the sentences in order. It is
    assumed to hold one score per sentence (TODO confirm against
    ``fin_readability_sustainability.do_predict``).

    Each score is bucketed, checking the non-sustainable bound first:

    * ``score >= non_sustainable_threshold`` -> ``'non-sustainable'``
    * ``score <= sustainable_threshold``     -> ``'sustainable'``
    * otherwise                               -> ``'-'``

    Args:
        text: Free text to analyse.
        non_sustainable_threshold: Inclusive lower bound for the
            ``'non-sustainable'`` label. The default is the original
            hard-coded cut-off.
        sustainable_threshold: Inclusive upper bound for the ``'sustainable'``
            label. The default is the original hard-coded cut-off.

    Returns:
        A list of ``(sentence, label)`` tuples in sentence order. The list is
        empty when *text* is empty, whitespace-only or ``None``.
    """
    # Nothing to score: skip the model round-trip entirely. How do_predict
    # handles an empty frame is not visible from here.
    if not text or not text.strip():
        return []

    df = pd.DataFrame({'sentence': sent_tokenize(text)})
    scores = do_predict(model_sustain, tokenizer_sus, df)[1]

    highlight = []
    for sent, prob in zip(df['sentence'].values, scores):
        if prob >= non_sustainable_threshold:
            highlight.append((sent, 'non-sustainable'))
        elif prob <= sustainable_threshold:
            highlight.append((sent, 'sustainable'))
        else:
            highlight.append((sent, '-'))
    return highlight

#SUSTAIN ENDS

# Summarization
def summarize_text(text):
    """Return an abstractive summary of *text*.

    A Hugging Face ``summarization`` pipeline backed by the
    ``knkarthick/MEETING_SUMMARY`` model is built on each call and applied to
    *text*. The ``summary_text`` field of its first output is returned.
    """
    summarize = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
    first_output, *_ = summarize(text)
    return first_output['summary_text']

##Forward Looking Statement
#def fls(text):
   # fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")
  #  results = fls_model(split_in_sentences(text))
    #return make_spans(text,results) 
    
##Company Extraction
#ner=pipeline('ner',model='Jean-Baptiste/camembert-ner-with-dates',tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")
#def fin_ner(text):
    #replaced_spans = ner(text)
   # return replaced_spans  
    
     
      
        
def load_questions(path='questions.txt'):
    """Return the full question texts stored one per line in *path*.

    Lines are returned exactly as ``readlines`` yields them, trailing newline
    included, so existing callers see unchanged values. The file is decoded
    as UTF-8 regardless of the platform's default locale.

    Args:
        path: File to read. Defaults to ``'questions.txt'`` in the current
            working directory.
    """
    with open(path, encoding='utf-8') as f:
        return f.readlines()


def load_questions_short(path='questionshort.txt'):
    """Return the short question labels stored one per line in *path*.

    Lines are returned exactly as ``readlines`` yields them, trailing newline
    included, so existing callers see unchanged values. The file is decoded
    as UTF-8 regardless of the platform's default locale.

    Args:
        path: File to read. Defaults to ``'questionshort.txt'`` in the current
            working directory.
    """
    with open(path, encoding='utf-8') as f:
        return f.readlines()


# Question lists for the main page. `questions` holds the full query text
# sent to the model; `questions_short` holds the labels shown in the
# selectbox. The main page pairs the two lists by position.
questions = load_questions()
questions_short = load_questions_short()

### DEFINE SIDEBAR
st.sidebar.title("Interactive Contract Analysis")

st.sidebar.header('CONTRACT UPLOAD')

# upload contract
user_upload = st.sidebar.file_uploader('Please upload your contract', type=['txt'],
                                       accept_multiple_files=False)

# Initial text of the "Contract" text area on the main page. It defaults to an
# empty string so the page still renders before anything is uploaded, or when
# the upload cannot be read.
contract_data = ""

# process upload
if user_upload is not None:
    print(user_upload.name, user_upload.type)
    extension = user_upload.name.split('.')[-1].lower()
    if extension == 'txt':
        print('text file uploaded')
        try:
            contract_data = user_upload.getvalue().decode("utf-8")
        except UnicodeDecodeError:
            st.warning('Could not read the uploaded file as UTF-8 text, please try again')
    else:
        st.warning('Unknown uploaded file type, please try again')

results_drop = ['1', '2', '3']
number_results = st.sidebar.selectbox('Select number of results', results_drop)

### DEFINE MAIN PAGE
st.header("Legal Contract Review Demo")
# Editable contract text, pre-filled from the uploaded file. This is the text
# the Analyze button sends to the model.
paragraph = st.text_area(
    label="Contract",
    height=300,
    value=contract_data,
)

# The user picks a short label. The full question at the same position in
# `questions` is the one that is actually asked.
questions_drop = questions_short
question_short = st.selectbox(
    label='Choose one of the 41 queries from the CUAD dataset:',
    options=questions_drop,
)
idxq = questions_drop.index(question_short)
question = questions[idxq]


# Streamlit re-executes this whole script on every widget interaction, and
# st.button() is True only during the rerun triggered by its own click. A plain
# module-level `raw_answer` would therefore already be reset to "" when
# "Check Sustainability", "Summarize" or "NER" is clicked. The last extracted
# answer is kept in st.session_state across reruns instead.
if "raw_answer" not in st.session_state:
    st.session_state["raw_answer"] = ""

if st.button('Analyze'):
    # A fresh analysis replaces any answer left over from a previous run.
    st.session_state["raw_answer"] = ""
    # `question` comes from readlines() and keeps its trailing newline, so a
    # bare length check could never fail. Test the stripped text instead.
    if paragraph.strip() and question.strip():
        print('getting predictions')
        with st.spinner(text='Analysis in progress...'):
            predictions = run_prediction([question], paragraph, 'marshmellow77/roberta-base-cuad',
                                         n_best_size=5)
        answer = ""
        if predictions['0'] != "":
            # NOTE(review): assumes run_prediction has just written its n-best
            # candidates to nbest.json, keyed like `predictions` — confirm.
            with open("nbest.json", encoding="utf-8") as jf:
                data = json.load(jf)
            # Slice instead of indexing, so asking for more results than there
            # are candidates cannot raise IndexError.
            for i, candidate in enumerate(data['0'][:int(number_results)]):
                # NOTE(review): this keeps the *last* displayed candidate, not
                # the most probable one — confirm that is intended.
                st.session_state["raw_answer"] = candidate['text']
                answer += f"Answer {i+1}: {candidate['text']} -- \n"
                answer += f"Probability: {round(candidate['probability']*100,1)}%\n\n"
        if not answer:
            answer = 'No answer found in document'
        st.success(answer)
    else:
        st.write("Unable to call model, please select question and contract")

raw_answer = st.session_state["raw_answer"]

if st.button('Check Sustainability'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        st.write(get_sustainability(raw_answer))

if st.button('Summarize'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        st.write(summarize_text(raw_answer))

if st.button('NER'):
    if raw_answer == "":
        st.write("Unable to call model, please select question and contract")
    else:
        doc = nlp(raw_answer)
        # Outside Jupyter, displacy.render returns an HTML string. st.write
        # would not render that markup, so pass it to st.markdown with HTML
        # enabled.
        st.markdown(displacy.render(doc, style="ent"), unsafe_allow_html=True)