import streamlit as st
import os
import json
import fitz
import re
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForSequenceClassification, BertTokenizer, BertModel,T5Tokenizer, T5ForConditionalGeneration,AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
def is_new_file_upload(uploaded_file):
    """Tell whether *uploaded_file* differs from the upload seen last in this session.

    The previous upload is remembered in ``st.session_state.last_uploaded_file``
    as a ``{'name': ..., 'size': ...}`` dict, and that record is refreshed
    whenever a new upload is detected.

    Returns:
        True on the first upload of the session, or when either the name or
        the size changed; False when the same name and size are seen again.
    """
    if 'last_uploaded_file' in st.session_state:
        previous = st.session_state.last_uploaded_file
        same_file = (uploaded_file.name == previous['name']
                     and uploaded_file.size == previous['size'])
        if same_file:
            return False
    # First upload of the session, or a different file: remember it.
    st.session_state.last_uploaded_file = {'name': uploaded_file.name, 'size': uploaded_file.size}
    return True
def combined_similarity(similarity, sentence, query, stop_words=None):
    """Blend an embedding similarity with the keyword overlap between two texts.

    Args:
        similarity: cosine similarity between the sentence and query embeddings.
        sentence: candidate sentence text.
        query: user query text.
        stop_words: optional collection of lower-case words to ignore. When
            omitted, ``st.session_state.stop_words`` is used.

    Returns:
        ``(combined_score, similarity, commonality)``. ``commonality`` is the
        fraction (0.0-1.0) of distinct non-stop query words that also occur in
        the sentence. ``combined_score = similarity + commonality``, so with a
        cosine similarity in [-1, 1] the combined score lies in [-1, 2].
    """
    import string

    if stop_words is None:
        stop_words = st.session_state.stop_words

    def _content_words(text):
        # Compare case-insensitively and ignore surrounding punctuation, so a
        # sentence word "Death" matches the query word "death?".
        words = (word.strip(string.punctuation).lower() for word in text.split())
        return {word for word in words if word and word not in stop_words}

    sentence_words = _content_words(sentence)
    query_words = _content_words(query)
    # max(..., 1) guards against a query made up only of stop words / punctuation.
    commonality = len(sentence_words & query_words) / max(len(query_words), 1)
    return similarity + commonality, similarity, commonality
def contradiction_detection(premise, hypothesis):
    """Classify the NLI relation between *premise* and *hypothesis*.

    Uses ``st.session_state.roberta_tokenizer`` and
    ``st.session_state.roberta_model``, which are loaded from
    "roberta-large-mnli" during start-up.

    Returns:
        A one-element set holding the predicted label: "Contradiction",
        "Neutral" or "Entailment". The set wrapper matches the original return
        value, which callers pass straight to ``st.write``.
    """
    tokenizer = st.session_state.roberta_tokenizer
    model = st.session_state.roberta_model
    # Encode the pair and place the tensors wherever the model lives.
    inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True).to(model.device)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over the raw logits picks the same class as argmax over softmax.
    predicted_class = torch.argmax(logits, dim=1).item()
    # Index order of the model outputs: 0 = contradiction, 1 = neutral, 2 = entailment.
    labels = ["Contradiction", "Neutral", "Entailment"]
    print(f"Prediction: {labels[predicted_class]}")
    return {labels[predicted_class]}
if 'is_initialized' not in st.session_state:
    # One-time, per-session setup: NLTK data, stop words, and the BERT / RoBERTa models.
    nltk.download('punkt')
    # NLTK 3.9+ loads sent_tokenize's model from 'punkt_tab'; older releases just
    # report the package as unknown without raising.
    nltk.download('punkt_tab')
    nltk.download('stopwords')
    st.session_state.stop_words = set(stopwords.words('english'))
    # Use the GPU when one is available instead of failing outright on CPU-only hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
    st.session_state.roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
    st.session_state.roberta_model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")
    # Set the flag last. If any download/load above raises, the next rerun retries
    # instead of leaving the session flagged as initialized but missing resources.
    st.session_state['is_initialized'] = True
if 'list_count' in st.session_state:
    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
    if 'paragraph_sentence_encodings' not in st.session_state:
        # Embed every sentence of every extracted paragraph once per upload.
        # Each entry of paragraph_sentence_encodings is [paragraph_dict, sentence_encodings].
        # sentence_encodings holds a [sentence, CLS embedding] pair per sentence, or None
        # for skipped sentences (questions and fragments shorter than 4 characters).
        print("start embedding paragarphs")
        bert_tokenizer = st.session_state.bert_tokenizer
        bert_model = st.session_state.bert_model
        paragraphs = st.session_state.restored_paragraphs
        total_paragraphs = len(paragraphs)
        read_progress_bar = st.progress(0)
        paragraph_sentence_encodings = []
        for index, paragraph in enumerate(paragraphs):
            # PDF text blocks break lines with "\n". Replace it with a space so the last
            # word of one line is not glued to the first word of the next.
            paragraph_text = paragraph['paragraph'].replace("\n", " ")
            sentence_encodings = []
            for sentence in sent_tokenize(paragraph_text):
                stripped = sentence.strip()
                if stripped.endswith('?') or len(stripped) < 4:
                    sentence_encodings.append(None)
                    continue
                sentence_tokens = bert_tokenizer(
                    sentence, return_tensors="pt", padding=True, truncation=True
                ).to(bert_model.device)
                with torch.no_grad():
                    # Last-layer hidden state of the first ([CLS]) token, moved to NumPy.
                    sentence_encoding = bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
                sentence_encodings.append([sentence, sentence_encoding])
            paragraph_sentence_encodings.append([paragraph, sentence_encodings])
            # Fraction of paragraphs finished. Dividing by the total (not total - 1)
            # cannot raise ZeroDivisionError for a one-paragraph document.
            read_progress_bar.progress((index + 1) / total_paragraphs)
        # Publish only the complete result, so an interrupted pass is redone on the next rerun.
        st.session_state.paragraph_sentence_encodings = paragraph_sentence_encodings
        st.rerun()
big_text = """
Knowledge Extraction A
"""
# Display the styled text
st.markdown(big_text, unsafe_allow_html=True)
uploaded_pdf_file = st.file_uploader("Upload a PDF file", type=['pdf'])

# NOTE(review): the two "download" captions below contain no link although they
# tell the user to download a sample; the URLs appear to be missing — confirm.
st.markdown('Sample 1 download and then upload to above', unsafe_allow_html=True)
# The query list was a "..." literal split by a raw newline (a SyntaxError).
# <br> keeps the intended line break in the HTML-enabled markdown.
st.markdown(
    "sample queries for above file: <br>What is death? What is a lucid dream? "
    "What is the seat of consciousness?",
    unsafe_allow_html=True,
)
st.markdown('Sample 2 download and then upload to above', unsafe_allow_html=True)
st.markdown(
    "sample queries for above file: <br>what does nontechnical managers worry about? "
    "what if you put all the knowledge, frameworks, and tips from this book to full use? "
    "tell me about AI agent",
    unsafe_allow_html=True,
)
if uploaded_pdf_file is not None:
    if is_new_file_upload(uploaded_pdf_file):
        print("is new file uploaded")
        # A new document invalidates the cached query and sentence embeddings.
        if 'prev_query' in st.session_state:
            del st.session_state['prev_query']
        if 'paragraph_sentence_encodings' in st.session_state:
            del st.session_state['paragraph_sentence_encodings']

        save_path = './uploaded_files'
        os.makedirs(save_path, exist_ok=True)
        # The file name is supplied by the client. basename() drops any directory
        # components (e.g. "../") so the file cannot be written outside save_path.
        file_name = os.path.basename(uploaded_pdf_file.name) or 'uploaded.pdf'
        st.session_state.uploaded_path = os.path.join(save_path, file_name)
        with open(st.session_state.uploaded_path, "wb") as f:
            f.write(uploaded_pdf_file.getbuffer())
        st.success(f'Saved file {file_name} in {save_path}')

        # Matches a line that starts with a list marker ("1.", "a.", "*" or "-").
        list_pattern = re.compile(r'^\s*((?:\d+\.|[a-zA-Z]\.|[*-])\s+.+)', re.MULTILINE)
        restored_paragraphs = []
        # The context manager closes the PDF once extraction is done.
        with fitz.open(st.session_state.uploaded_path) as doc:
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                for block in page.get_text("blocks"):
                    # PyMuPDF block tuple: (x0, y0, x1, y1, text, block_no, block_type).
                    # block_type 1 marks an image block whose text is only a placeholder.
                    _x0, _y0, _x1, _y1, text, _block_no, block_type = block
                    if block_type == 1 or not text.strip():
                        continue
                    # Normalize whitespace-only lines so blank-line splitting below works.
                    text = re.sub(r'\n\s+\n', '\n\n', text.strip())
                    contains_list = bool(list_pattern.search(text))
                    if re.search(r'\n{2,}', text):
                        # One paragraph per blank-line-separated chunk; keep the whole block as context.
                        for substring in re.split(r'\n{2,}', text):
                            if substring.strip():
                                restored_paragraphs.append(
                                    {"paragraph": substring, "containsList": contains_list,
                                     "page_num": page_num, "text": text})
                    else:
                        restored_paragraphs.append(
                            {"paragraph": text, "containsList": contains_list,
                             "page_num": page_num, "text": None})

        st.session_state.restored_paragraphs = restored_paragraphs
        st.session_state.list_count = len(restored_paragraphs)
        st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
        st.rerun()
if 'paragraph_sentence_encodings' in st.session_state:
    query = st.text_input("Enter your query")
    if query:
        # Re-score paragraphs only when the query text changes.
        if 'prev_query' not in st.session_state or st.session_state.prev_query != query:
            st.session_state.prev_query = query
            st.session_state.premise = query
            bert_model = st.session_state.bert_model
            query_tokens = st.session_state.bert_tokenizer(
                query, return_tensors="pt", padding=True, truncation=True
            ).to(bert_model.device)
            with torch.no_grad():  # inference only
                # [CLS]-token embedding, moved to CPU as a NumPy array.
                query_encoding = bert_model(**query_tokens).last_hidden_state[:, 0, :].cpu().numpy()

            encodings = st.session_state.paragraph_sentence_encodings
            total_count = len(encodings)
            processing_progress_bar = st.progress(0)
            paragraph_scores = []
            for index, (paragraph, sentence_encodings) in enumerate(encodings):
                sentence_similarities = []
                for sentence_encoding in sentence_encodings:
                    if not sentence_encoding:  # skipped question / short fragment
                        continue
                    sentence, embedding = sentence_encoding
                    similarity = cosine_similarity(query_encoding, embedding)[0][0]
                    combined_score, _, commonality_score = combined_similarity(similarity, sentence, query)
                    sentence_similarities.append((combined_score, sentence, commonality_score))
                sentence_similarities.sort(reverse=True, key=lambda s: s[0])
                # Paragraph score = mean of its (up to) three best sentence scores.
                top_three_sentences = sentence_similarities[:3]
                if top_three_sentences:
                    top_three_avg_similarity = np.mean([s[0] for s in top_three_sentences])
                    top_three_avg_commonality = np.mean([s[2] for s in top_three_sentences])
                else:
                    top_three_avg_similarity = 0
                    top_three_avg_commonality = 0
                paragraph_scores.append(
                    (top_three_avg_similarity, top_three_avg_commonality,
                     {'top_three_sentences': top_three_sentences, 'original_text': paragraph})
                )
                # Dividing by the total (not total - 1) cannot raise for a single paragraph.
                processing_progress_bar.progress((index + 1) / total_count)
            st.session_state.paragraph_scores = sorted(paragraph_scores, key=lambda p: p[0], reverse=True)

        if 'paragraph_scores' in st.session_state:
            st.write("Top scored paragraphs and their scores:")
            for _similarity, _commonality, paragraph in st.session_state.paragraph_scores[:5]:
                # Label each top sentence against the query with the NLI model.
                for _score, hypothesis, _commonality_score in paragraph['top_three_sentences']:
                    st.write("hypothesis: ", hypothesis)
                    st.write(contradiction_detection(st.session_state.premise, hypothesis))
                st.write("Original Paragraph: ", paragraph['original_text'])
# A Member will be considered Actively at Work if he or she is able and available for active performance of all of his or her regular duties