import streamlit as st
import numpy as np
import pandas as pd
from PIL import Image, ImageOps
import time
from paddleocr import PaddleOCR
import os
# from dotenv import load_dotenv
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# Load environment variables
# load_dotenv()
# huggingface_token = os.getenv("HF_TOKEN")
# Load TinyBERT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Intel/dynamic_tinybert")
model = AutoModelForQuestionAnswering.from_pretrained("Intel/dynamic_tinybert")
# Initialize PaddleOCR
ocr = PaddleOCR(use_angle_cls=True, lang='en')
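# use_angle_cls=True enables the text-angle classifier so rotated text is still read;
# the English detection/recognition models are downloaded on first use if not already cached.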
# Team details
team_members = [
    {"name": "Aman Deep", "image": "aman.jpg"},  # Replace with actual paths to images
    {"name": "Nandini", "image": "nandini.jpg"},
    {"name": "Abhay Sharma", "image": "abhay.jpg"},
    {"name": "Ratan Prakash Mishra", "image": "anandimg.jpg"}
]
# Function to preprocess images for the model
def preprocess_image(image):
    """
    Preprocess the input image for model prediction.

    Args:
        image (PIL.Image): Input image in PIL format.

    Returns:
        np.ndarray: Preprocessed image array ready for prediction.
    """
    try:
        img = image.resize((128, 128), Image.LANCZOS)
        img_array = np.array(img)
        if img_array.ndim == 2:  # Grayscale image
            img_array = np.stack([img_array] * 3, axis=-1)
        elif img_array.shape[2] == 1:  # Single-channel image
            img_array = np.concatenate([img_array, img_array, img_array], axis=-1)
        img_array = img_array / 255.0
        img_array = np.expand_dims(img_array, axis=0)
        return img_array
    except Exception as e:
        print(f"Error processing image: {e}")
        return None
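
# Usage sketch for preprocess_image (not executed by the app; "sample.jpg" is a
# hypothetical local file used only for illustration):
#   arr = preprocess_image(Image.open("sample.jpg"))
#   # arr has shape (1, 128, 128, 3) with values scaled to [0, 1], or is None on failure.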
# Function to perform Q&A with TinyBERT
def answer_question(context, question):
    """
    Extract the answer to a question from the given context using TinyBERT.

    Args:
        context (str): The text to search for answers.
        question (str): The question to answer.

    Returns:
        str: The extracted answer or an error message.
    """
    try:
        tokens = tokenizer.encode_plus(question, context, return_tensors="pt", truncation=True)
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]
        # Perform question answering
        outputs = model(input_ids, attention_mask=attention_mask)
        start_scores = outputs.start_logits
        end_scores = outputs.end_logits
        answer_start = torch.argmax(start_scores)
        answer_end = torch.argmax(end_scores) + 1
        answer = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(input_ids[0][answer_start:answer_end])
        )
        return answer
    except Exception as e:
        return f"Error: {e}"
# Function to display team members in circular format
def display_team_members(members, max_members_per_row=4):
    num_members = len(members)
    num_rows = (num_members + max_members_per_row - 1) // max_members_per_row
    for i in range(num_rows):
        cols = st.columns(min(max_members_per_row, num_members - i * max_members_per_row))
        for j, member in enumerate(members[i * max_members_per_row:(i + 1) * max_members_per_row]):
            with cols[j]:
                img = Image.open(member["image"])
                st.image(img, use_column_width=True)
                st.write(member["name"])
# Function to simulate loading process with a progress bar
def simulate_progress():
    progress_bar = st.progress(0)
    for percent_complete in range(100):
        time.sleep(0.02)
        progress_bar.progress(percent_complete + 1)
# Page title
st.title("Product Listing Assistant")
# Sidebar navigation with task selector
st.sidebar.title("Navigation")
st.sidebar.write("Team Name: Sadhya")
app_mode = st.sidebar.selectbox("Choose the task", ["Welcome", "Project Details", "Team Details", "Extract Product Details"])
if app_mode == "Welcome":
    st.write("# Welcome to the Product Listing Assistant! 🎉")
elif app_mode == "Project Details":
    st.write("""
## Project Overview:
- Automates product listings from social media content.
- Extracts product details from posts using OCR and Q&A.
- Outputs structured, engaging, and optimized e-commerce listings.
""")
elif app_mode == "Team Details":
    st.write("## Meet Our Team:")
    display_team_members(team_members)
elif app_mode == "Extract Product Details":
    st.write("## Extract Product Details Using OCR and Q&A")
    post_url = st.text_input("Enter Post URL:")
    uploaded_files = st.file_uploader("Upload Product Images", type=["jpeg", "png", "jpg"], accept_multiple_files=True)
    user_question = st.text_input("Ask a question about the extracted details:")
    if uploaded_files:
        st.write("### Uploaded Images:")
        simulate_progress()
        for uploaded_image in uploaded_files:
            image = Image.open(uploaded_image)
            st.image(image, use_column_width=True)
            simulate_progress()
            # Perform OCR
            st.write("Extracting text from image...")
            result = ocr.ocr(np.array(image.convert("RGB")), cls=True)  # convert to RGB so RGBA/paletted uploads don't break OCR
            # Guard against images where nothing is detected (result[0] may be empty or None)
            extracted_text = " ".join([line[1][0] for line in result[0]]) if result and result[0] else ""
            st.write("Extracted Text:")
            st.text(extracted_text)
            # Use Q&A model
            if user_question and extracted_text:
                st.write("### Answer to your question:")
                answer = answer_question(extracted_text, user_question)
                st.write(answer)
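
# To run the app locally (assuming streamlit, paddleocr, torch, and transformers are installed):
#   streamlit run app.py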