import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import requests
from io import BytesIO
from urllib.parse import urljoin
from bs4 import BeautifulSoup
import pandas as pd
import base64

def load_model(model_path):
    # Load the TFLite model and allocate its tensors once up front
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter

def preprocess_image(image, input_size):
    # Convert to RGB, resize to the model's input resolution, add a batch axis
    image = image.convert('RGB')
    image = image.resize((input_size, input_size))
    image_np = np.array(image, dtype=np.float32)
    image_np = np.expand_dims(image_np, axis=0)
    image_np = image_np / 255.0  # Normalize to [0, 1]
    return image_np
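
# Note: preprocess_image assumes a float32 input tensor scaled to [0, 1]. A
# quantized export (e.g. uint8 input) would instead need the dtype and scaling
# reported by input_details[0]['dtype'] and its quantization parameters.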

def run_inference(interpreter, input_data):
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    # Assumes output index 0 is the shopping-intent head and index 1 the sensitivity head
    output_data_shopping_intent = interpreter.get_tensor(output_details[0]['index'])
    output_data_sensitive = interpreter.get_tensor(output_details[1]['index'])
    return output_data_shopping_intent, output_data_sensitive
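
# Hypothetical one-off check: TFLite assigns no semantics to output ordering,
# so during development print the output tensor details once to confirm which
# index is the shopping-intent head and which is the sensitivity head:
#   for d in load_model("model.tflite").get_output_details():
#       print(d['index'], d['name'], d['shape'])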

def fetch_images_from_url(url):
    response = requests.get(url, timeout=10)
    soup = BeautifulSoup(response.content, 'html.parser')
    img_tags = soup.find_all('img')
    # Resolve relative src values against the page URL and skip inline data: URIs
    img_urls = [urljoin(url, img['src']) for img in img_tags
                if 'src' in img.attrs and not img['src'].startswith('data:')]
    return img_urls

def image_to_base64(image):
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def main():
    st.set_page_config(layout="wide")
    st.title("Shopping Intent Classification - SEO by DEJAN")
    st.write("Enter a URL to fetch and classify all images on the page. JavaScript-rendered websites are currently unsupported.")

    model_path = "model.tflite"
    url = st.text_input("Enter URL")
    if url:
        img_urls = fetch_images_from_url(url)
        if img_urls:
            st.write(f"Found {len(img_urls)} images")

            interpreter = load_model(model_path)
            input_details = interpreter.get_input_details()
            input_shape = input_details[0]['shape']
            input_size = input_shape[1]  # assuming square input

            data = []
            errors = []
            for img_url in img_urls:
                try:
                    response = requests.get(img_url, timeout=10)
                    image = Image.open(BytesIO(response.content))
                    input_data = preprocess_image(image, input_size)
                    output_data_shopping_intent, output_data_sensitive = run_inference(interpreter, input_data)

                    # Shrink the image and encode it as a Base64 data URL for display
                    image.thumbnail((100, 100))
                    thumbnail_base64 = image_to_base64(image)
                    thumbnail_data_url = f"data:image/png;base64,{thumbnail_base64}"

                    data.append({
                        'Thumbnail': thumbnail_data_url,
                        'URL': img_url,
                        'Shopping Intent': output_data_shopping_intent.flatten().tolist(),
                        'Sensitivity': output_data_sensitive.flatten().tolist()
                    })
                except Exception as e:
                    errors.append(f"Could not process image {img_url}: {e}")
            # Convert data to a DataFrame and render the Thumbnail column as images
            df = pd.DataFrame(data)
            st.dataframe(
                df,
                column_config={"Thumbnail": st.column_config.ImageColumn("Thumbnail")},
                hide_index=True,
            )

            # Display errors in an expandable section
            if errors:
                with st.expander(f"Could not process {len(errors)} images"):
                    for error in errors:
                        st.write(error)

if __name__ == "__main__":
    main()
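
# To try this locally (a sketch, assuming the script is saved as app.py with
# model.tflite alongside it, and streamlit, tensorflow, numpy, pillow,
# requests, beautifulsoup4 and pandas installed):
#   streamlit run app.py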