# Page header residue from the hosting site: "Spaces: Sleeping".
from itertools import product | |
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
from PIL import Image, ImageOps | |
import time | |
from paddleocr import PaddleOCR | |
import os | |
# from dotenv import load_dotenv | |
import torch | |
from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
# Load environment variables | |
# load_dotenv() | |
# huggingface_token = os.getenv("HF_TOKEN") | |
# Load the TinyBERT question-answering model and the PaddleOCR engine.
#
# Streamlit re-executes this script from the top on every widget interaction,
# so the heavyweight objects are created inside st.cache_resource functions:
# they are built on the first run and the same instances are reused afterwards
# instead of being re-created on each rerun.
# NOTE(review): st.cache_resource requires Streamlit >= 1.18 and shares the
# cached objects between sessions -- confirm both are acceptable here.
@st.cache_resource
def _load_qa_components():
    """Return the (tokenizer, model) pair for "Intel/dynamic_tinybert"."""
    qa_tokenizer = AutoTokenizer.from_pretrained("Intel/dynamic_tinybert")
    qa_model = AutoModelForQuestionAnswering.from_pretrained("Intel/dynamic_tinybert")
    return qa_tokenizer, qa_model


@st.cache_resource
def _load_ocr_engine():
    """Return a PaddleOCR engine built with use_angle_cls=True and lang='en'."""
    return PaddleOCR(use_angle_cls=True, lang='en')


# Module-level names used by the rest of the script.
tokenizer, model = _load_qa_components()
# Initialize PaddleOCR
ocr = _load_ocr_engine()
# Team details: one dict per member with a display "name" and an "image" file.
# Replace the image entries with actual paths to images.
team_members = [
    {"name": member_name, "image": image_path}
    for member_name, image_path in (
        ("Aman Deep", "aman.jpg"),
        ("Nandini", "nandini.jpg"),
        ("Abhay Sharma", "abhay.jpg"),
        ("Ratan Prakash Mishra", "anandimg.jpg"),
    )
]
# Function to preprocess images for the model
def preprocess_image(image):
    """
    Preprocess the input image for model prediction.

    The image is converted to 3-channel RGB, resized to 128x128 with Lanczos
    resampling, scaled by 1/255 and given a leading batch axis.

    Args:
        image (PIL.Image): Input image in PIL format.

    Returns:
        np.ndarray: Array of shape (1, 128, 128, 3), or None if the image
        could not be processed (the error is printed).
    """
    try:
        # convert("RGB") normalises every input mode in one step: grayscale is
        # replicated to three channels, palette images are expanded to real
        # colours, and an alpha channel is dropped. Converting before resizing
        # lets Lanczos operate on colour values rather than palette indices.
        rgb_image = image.convert("RGB")
        resized = rgb_image.resize((128, 128), Image.LANCZOS)
        img_array = np.array(resized) / 255.0
        # Add the batch dimension expected by batch-oriented model APIs.
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        # Best-effort contract: report the problem and signal failure with None.
        print(f"Error processing image: {e}")
        return None
# Function to perform Q&A with TinyBERT
def answer_question(context, question):
    """
    Extract the answer to a question from the given context using TinyBERT.

    The question/context pair is encoded with the module-level ``tokenizer``
    and scored by the module-level ``model``; the answer is the token span
    between the highest-scoring start and end positions.

    Args:
        context (str): The text to search for answers.
        question (str): The question to answer.

    Returns:
        str: The extracted answer, or a message starting with "Error:" when
        an input is empty or the model call fails.
    """
    try:
        # Without any context or question there is nothing meaningful to
        # extract, so report it instead of running the model.
        if not context or not context.strip():
            return "Error: no context text provided."
        if not question or not question.strip():
            return "Error: no question provided."

        tokens = tokenizer.encode_plus(question, context, return_tensors="pt", truncation=True)
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        answer_start = int(torch.argmax(outputs.start_logits))
        # +1 because the slice end is exclusive and the end token is included.
        answer_end = int(torch.argmax(outputs.end_logits)) + 1

        answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[0][answer_start:answer_end])
        return tokenizer.convert_tokens_to_string(answer_tokens)
    except Exception as e:
        return f"Error: {e}"
# Function to display team members in a grid of columns
def display_team_members(members, max_members_per_row=4):
    """
    Render team members row by row in Streamlit columns.

    Each row holds at most ``max_members_per_row`` members; every cell shows
    the member's image followed by the member's name. If an image file cannot
    be opened, a warning is shown in its place and the name is still listed.

    Args:
        members (list[dict]): Entries with "name" and "image" (file path) keys.
        max_members_per_row (int): Maximum number of columns in one row.
    """
    for row_start in range(0, len(members), max_members_per_row):
        row_members = members[row_start:row_start + max_members_per_row]
        # A short final row gets only as many columns as it has members.
        cols = st.columns(len(row_members))
        for col, member in zip(cols, row_members):
            with col:
                try:
                    # Copy the pixels out so the file handle is closed promptly.
                    with Image.open(member["image"]) as handle:
                        photo = handle.copy()
                except OSError:
                    # Covers a missing file as well as an unreadable image.
                    photo = None

                if photo is None:
                    st.warning(f"Image not available: {member['image']}")
                else:
                    st.image(photo, use_column_width=True)
                st.write(member["name"])
# Function to simulate loading process with a progress bar
def simulate_progress():
    """
    Show a Streamlit progress bar and fill it from 1 to 100.

    The bar advances one step every 0.02 seconds, so a full run sleeps for
    about two seconds in total. The delay is purely cosmetic.
    """
    bar = st.progress(0)
    for filled in range(1, 101):
        time.sleep(0.02)
        bar.progress(filled)
# Title and description
st.title("Product Listing Assistant")

# Navbar with task tabs
st.sidebar.title("Navigation")
st.sidebar.write("Team Name: Sadhya")
app_mode = st.sidebar.selectbox(
    "Choose the task",
    ["Welcome", "Project Details", "Team Details", "Extract Product Details"],
)

if app_mode == "Welcome":
    st.write("# Welcome to the Product Listing Assistant! 🎉")
elif app_mode == "Project Details":
    st.write("""
## Project Overview:
- Automates product listings from social media content.
- Extracts product details from posts using OCR and Q&A.
- Outputs structured, engaging, and optimized e-commerce listings.
""")
elif app_mode == "Team Details":
    st.write("## Meet Our Team:")
    display_team_members(team_members)
elif app_mode == "Extract Product Details":
    st.write("## Extract Product Details Using OCR and Q&A")
    # The URL is collected from the user but not used by the code below.
    post_url = st.text_input("Enter Post URL:")
    uploaded_files = st.file_uploader(
        "Upload Product Images",
        type=["jpeg", "png", "jpg"],
        accept_multiple_files=True,
    )
    user_question = st.text_input("Ask a question about the extracted details:")

    if uploaded_files:
        st.write("### Uploaded Images:")
        simulate_progress()
        for uploaded_image in uploaded_files:
            image = Image.open(uploaded_image)
            st.image(image, use_column_width=True)
            simulate_progress()

            # Perform OCR
            st.write("Extracting text from image...")
            # Hand OCR a plain 3-channel array: PNG uploads may be palette,
            # grayscale+alpha or RGBA, which np.array() would pass through as
            # index maps or extra channels.
            result = ocr.ocr(np.array(image.convert("RGB")), cls=True)
            # The per-image entry can be None/empty when no text is detected;
            # treat that as "no lines" instead of iterating over None.
            detections = result[0] if result and result[0] else []
            # Each detection is indexed as [1][0] to reach the recognised text.
            extracted_text = " ".join(line[1][0] for line in detections)
            st.write("Extracted Text:")
            st.text(extracted_text)

            # Use Q&A model
            if user_question:
                st.write("### Answer to your question:")
                if extracted_text:
                    answer = answer_question(extracted_text, user_question)
                else:
                    # Nothing to search: skip the model rather than answer
                    # from an empty context.
                    answer = "No text was detected in this image."
                st.write(answer)