import json
import os
import string

import nltk
import numpy as np
import streamlit as st
import torch
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BertModel,
    BertTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
def is_new_file_upload(uploaded_file):
    """Tell whether ``uploaded_file`` differs from the last upload seen.

    Uploads are compared by their (name, size) pair only. The pair of the
    latest new upload is stored in ``st.session_state.last_uploaded_file``,
    which is refreshed every time this function returns True.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.

    Returns:
        True for a first or different upload, False for a re-submission of
        the same (name, size) pair.
    """
    fingerprint = {'name': uploaded_file.name, 'size': uploaded_file.size}

    if 'last_uploaded_file' in st.session_state:
        seen = st.session_state.last_uploaded_file
        same_name = seen['name'] == fingerprint['name']
        same_size = seen['size'] == fingerprint['size']
        if same_name and same_size:
            # Same file re-submitted: keep the stored fingerprint untouched.
            return False

    # First upload of the session, or a different file.
    st.session_state.last_uploaded_file = fingerprint
    return True
def combined_similarity(similarity, sentence, query, stop_words=None):
    """Blend an embedding similarity with a keyword-overlap score.

    Args:
        similarity: Cosine similarity between the sentence and query embeddings.
        sentence: Candidate sentence text.
        query: User query text.
        stop_words: Lower-case words ignored when counting overlap. Defaults to
            ``st.session_state.stop_words``. The parameter exists so callers
            (and tests) can supply their own set.

    Returns:
        A tuple ``(combined_score, similarity, commonality)``:

        - ``commonality`` is the fraction of the query's content words that
          also occur in the sentence, in [0, 1].
        - ``combined_score`` is ``similarity + commonality``.

        With a cosine similarity in [-1, 1], the combined score lies in
        [-1, 2]; it is not normalized back into [-1, 1].
    """
    if stop_words is None:
        stop_words = st.session_state.stop_words

    def content_words(text):
        # Case-fold and strip surrounding punctuation before both the stop-word
        # test and the overlap test. Otherwise "Death" in a sentence never
        # matches "death?" in a query, even though stop words were already
        # compared lower-cased.
        normalized = (token.strip(string.punctuation).lower() for token in text.split())
        return {word for word in normalized if word and word not in stop_words}

    query_words = content_words(query)
    common_count = len(content_words(sentence) & query_words)
    # max(..., 1) keeps a query made only of stop words from dividing by zero.
    commonality = common_count / max(len(query_words), 1)
    return similarity + commonality, similarity, commonality
def paraphrase(sentence):
    """Return one sampled paraphrase of ``sentence`` from the T5 paraphraser.

    Uses ``st.session_state.paraphrase_tokenizer`` and
    ``st.session_state.paraphrase_model``. Generation uses sampling
    (``do_sample=True``), so repeated calls with the same input can return
    different text.

    Args:
        sentence: Text to paraphrase.

    Returns:
        The decoded paraphrase, with special tokens removed.
    """
    tokenizer = st.session_state.paraphrase_tokenizer
    model = st.session_state.paraphrase_model

    # Task prefix prepended to the input; presumably the prompt format the
    # Vamsi/T5_Paraphrase_Paws checkpoint expects.
    text = "paraphrase: " + sentence + " "
    # padding="max_length" is the supported spelling of the deprecated
    # pad_to_max_length=True: it pads to the tokenizer's model_max_length.
    encoding = tokenizer(text, padding="max_length", return_tensors="pt")
    # Follow the model's device instead of assuming CUDA.
    input_ids = encoding["input_ids"].to(model.device)
    attention_mask = encoding["attention_mask"].to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=256,
        do_sample=True,
        top_k=120,
        top_p=0.95,
        early_stopping=False,
        num_return_sequences=1,
        repetition_penalty=1.5,
    )
    # Exactly one sequence is requested, so decode the first row.
    return tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
@st.cache_resource
def _load_shared_resources():
    """Download NLTK data and load the BERT and T5 models once per server process.

    ``st.cache_resource`` shares the returned objects across all browser
    sessions. Without it, every new session downloaded the NLTK data again
    and put its own copy of both models on the GPU.

    Returns:
        A dict mapping the ``st.session_state`` keys used by the rest of the
        app to the loaded objects.
    """
    nltk.download('punkt')
    nltk.download('stopwords')
    return {
        'stop_words': set(stopwords.words('english')),
        'bert_tokenizer': BertTokenizer.from_pretrained("bert-base-uncased"),
        'bert_model': BertModel.from_pretrained("bert-base-uncased").to('cuda'),
        'paraphrase_tokenizer': AutoTokenizer.from_pretrained("Vamsi/T5_Paraphrase_Paws"),
        'paraphrase_model': AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws").to('cuda'),
    }


# Expose the shared resources under the session-state keys the rest of the
# script reads.
if 'is_initialized' not in st.session_state:
    for resource_name, resource in _load_shared_resources().items():
        st.session_state[resource_name] = resource
    # Set the flag only after loading succeeded, so a failed load is retried
    # on the next rerun instead of leaving the session without models.
    st.session_state['is_initialized'] = True
# Once a document is loaded, show its size and (first time only) embed every
# sentence of every paragraph with BERT.
if 'list_count' in st.session_state:
    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
    if 'paragraph_sentence_encodings' not in st.session_state:
        print("start embedding paragarphs")
        read_progress_bar = st.progress(0)
        bert_tokenizer = st.session_state.bert_tokenizer
        bert_model = st.session_state.bert_model
        paragraphs = st.session_state.restored_paragraphs

        # Built locally and stored only when complete. A rerun that interrupts
        # this loop then cannot leave a truncated list behind that later runs
        # would treat as a finished embedding.
        paragraph_sentence_encodings = []
        for index, paragraph in enumerate(paragraphs):
            # max(..., 1) avoids the ZeroDivisionError for a one-paragraph file.
            # The bar still runs from 0 to 1.
            read_progress_bar.progress(index / max(len(paragraphs) - 1, 1))
            sentence_encodings = []
            # assumes each paragraph is a dict with a string 'text' field — TODO confirm JSON schema
            for sentence in sent_tokenize(paragraph['text']):
                stripped = sentence.strip()
                # Questions and fragments shorter than 4 characters are stored
                # as None placeholders; the query step skips falsy entries.
                if stripped.endswith('?') or len(stripped) < 4:
                    sentence_encodings.append(None)
                    continue
                tokens = bert_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(
                    bert_model.device)
                with torch.no_grad():
                    # First-token ([CLS]) hidden state as the sentence vector, shape (1, hidden).
                    encoding = bert_model(**tokens).last_hidden_state[:, 0, :].cpu().numpy()
                sentence_encodings.append([sentence, encoding])
            paragraph_sentence_encodings.append([paragraph, sentence_encodings])

        st.session_state.paragraph_sentence_encodings = paragraph_sentence_encodings
        st.rerun()
big_text = """
Knowledge Extraction A
"""
# Display the styled text
st.markdown(big_text, unsafe_allow_html=True)
uploaded_json_file = st.file_uploader("Upload a pre-processed file",
type=['json'])
st.markdown(
f'Sample 1 download and then upload to above',
unsafe_allow_html=True)
st.markdown("sample queries for above file:
What is death? What is a lucid dream? What is the seat of consciousness?",unsafe_allow_html=True)
st.markdown(
f'Sample 2 download and then upload to above',
unsafe_allow_html=True)
st.markdown("sample queries for above file:
what does nontechnical managers worry about? what if you put all the knowledge, frameworks, and tips from this book to full use? tell me about AI agent",unsafe_allow_html=True)
# Handle a newly uploaded pre-processed JSON file: save a copy to disk, parse
# it into session state and rerun so the embedding step runs on it.
if uploaded_json_file is not None and is_new_file_upload(uploaded_json_file):
    print("is new file uploaded")
    # Forget everything derived from the previous file. 'paragraph_scores',
    # 'restored_paragraphs' and 'list_count' are cleared as well, so a
    # rejected upload cannot leave the previous document to be re-embedded,
    # nor stale rankings to be re-paraphrased.
    for stale_key in ('paraphrased_paragrpahs', 'prev_query', 'paragraph_sentence_encodings',
                      'paragraph_scores', 'restored_paragraphs', 'list_count'):
        if stale_key in st.session_state:
            del st.session_state[stale_key]

    save_path = './uploaded_files'
    os.makedirs(save_path, exist_ok=True)
    # The file name comes from the client; keep only its last path component
    # so the write cannot land outside save_path.
    file_name = os.path.basename(uploaded_json_file.name)
    st.session_state.uploaded_path = os.path.join(save_path, file_name)
    with open(st.session_state.uploaded_path, "wb") as f:
        f.write(uploaded_json_file.getbuffer())  # Write the file to the specified location
    st.success(f'Saved file {file_name} in {save_path}')

    try:
        # getvalue() returns the whole buffer regardless of the stream position.
        restored = json.loads(uploaded_json_file.getvalue())
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors
        st.write('Invalid JSON file.')
    else:
        # The embedding step reads paragraph['text'], so require a list of
        # objects that have that key.
        if isinstance(restored, list) and all(isinstance(item, dict) and 'text' in item for item in restored):
            st.session_state.restored_paragraphs = restored
            st.session_state.list_count = len(restored)
            st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
            # Rerun only on success: the embedding step sits earlier in the
            # script. Rerunning after a failure would wipe the error message
            # before the user could read it.
            st.rerun()
        else:
            st.write('The JSON content is not a list of objects with a "text" field.')
# Number of best-ranked paragraphs that are paraphrased and displayed.
_TOP_PARAGRAPHS = 5


def _score_paragraphs(query, paragraph_sentence_encodings):
    """Rank paragraphs by how well their best sentences match ``query``.

    Args:
        query: The user's query text.
        paragraph_sentence_encodings: ``[paragraph, sentence_encodings]``
            pairs as built by the embedding step. Each sentence entry is
            ``[sentence, encoding]`` or ``None``; ``None`` entries are skipped.

    Returns:
        A list of ``(avg_score, avg_commonality, info)`` tuples sorted by
        ``avg_score`` descending. The averages cover the paragraph's up to
        three best-scoring sentences (0 if it has none). ``info`` has two keys:

        - ``'modified_text'``: the paragraph with those sentences moved to
          the front.
        - ``'original_text'``: the stored paragraph object.
    """
    bert_model = st.session_state.bert_model
    query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(
        bert_model.device)
    with torch.no_grad():  # inference only
        # First-token hidden state, same representation as the stored sentences.
        query_encoding = bert_model(**query_tokens).last_hidden_state[:, 0, :].cpu().numpy()

    progress_bar = st.progress(0)
    total = len(paragraph_sentence_encodings)
    paragraph_scores = []
    for index, (paragraph, sentence_encodings) in enumerate(paragraph_sentence_encodings):
        # max(..., 1) avoids the ZeroDivisionError for a single paragraph.
        progress_bar.progress(index / max(total - 1, 1))

        scored = []
        for entry in sentence_encodings:
            if not entry:
                continue
            sentence, encoding = entry
            similarity = cosine_similarity(query_encoding, encoding)[0][0]
            combined_score, _, commonality = combined_similarity(similarity, sentence, query)
            scored.append((combined_score, sentence, commonality))
        scored.sort(reverse=True, key=lambda item: item[0])

        # Slicing covers all three cases: >= 3, 1-2, or no scored sentences.
        top = scored[:3]
        avg_score = np.mean([item[0] for item in top]) if top else 0
        avg_commonality = np.mean([item[2] for item in top]) if top else 0
        top_texts = [item[1] for item in top]
        remaining_texts = [entry[0] for entry in sentence_encodings if entry and entry[0] not in top_texts]
        paragraph_scores.append(
            (avg_score, avg_commonality,
             {'modified_text': ' '.join(top_texts + remaining_texts), 'original_text': paragraph})
        )
    return sorted(paragraph_scores, key=lambda item: item[0], reverse=True)


def _paraphrase_top_paragraphs(paragraph_scores):
    """Paraphrase the reordered text of the top-ranked paragraphs.

    Each text is passed through ``paraphrase`` twice; presumably the second
    pass is meant to move the wording further from the source — confirm
    intent.

    Args:
        paragraph_scores: Ranked output of ``_score_paragraphs``.

    Returns:
        One paraphrased string per paragraph in ``paragraph_scores[:_TOP_PARAGRAPHS]``.
    """
    top = paragraph_scores[:_TOP_PARAGRAPHS]
    progress_bar = st.progress(0)
    paraphrased = []
    for index, (_, _, paragraph) in enumerate(top):
        paraphrased.append(paraphrase(paraphrase(paragraph['modified_text'])))
        # max(..., 1) avoids the ZeroDivisionError when only one paragraph ranks.
        progress_bar.progress(index / max(len(top) - 1, 1))
    return paraphrased


# Query box, ranking and results; only shown once the document is embedded.
if 'paragraph_sentence_encodings' in st.session_state:
    query = st.text_input("Enter your query")
    if query:
        if 'prev_query' not in st.session_state or st.session_state.prev_query != query:
            st.session_state.paragraph_scores = _score_paragraphs(
                query, st.session_state.paragraph_sentence_encodings)
            if 'paraphrased_paragrpahs' in st.session_state:
                del st.session_state['paraphrased_paragrpahs']
            # Record the query only after scoring finished. If a rerun
            # interrupts scoring, the next run then recomputes instead of
            # showing the previous query's ranking.
            st.session_state.prev_query = query

        if 'paragraph_scores' in st.session_state:
            if 'paraphrased_paragrpahs' not in st.session_state:
                # Assigned in one step after all generations finish, so an
                # interrupted run cannot leave a partial list that the display
                # loop below would index past.
                st.session_state.paraphrased_paragrpahs = _paraphrase_top_paragraphs(
                    st.session_state.paragraph_scores)
            st.write("Top scored paragraphs and their scores:")
            ranked = zip(st.session_state.paragraph_scores[:_TOP_PARAGRAPHS],
                         st.session_state.paraphrased_paragrpahs)
            for i, ((similarity_score, commonality_score, paragraph), paraphrased_text) in enumerate(ranked):
                st.write("Paraphrased Paragraph: ", paraphrased_text)
                if st.button(f"Show Original Paragraph {i + 1}", key=f"button_{i}"):
                    st.write(f"Similarity Score: {similarity_score}, Commonality Score: {commonality_score}")
                    # NOTE(review): 'original_text' holds the whole stored paragraph
                    # object, so this displays it as-is rather than just its 'text'.
                    st.write("Original Paragraph: ", paragraph['original_text'])