# shopping-intent / app.py
# (Hugging Face Spaces header preserved from the scraped page:
#  dejanseo — "Update app.py" — commit 327b650 verified — 3.79 kB)
import base64
from io import BytesIO
from urllib.parse import urljoin

import numpy as np
import pandas as pd
import requests
import streamlit as st
import tensorflow as tf
from bs4 import BeautifulSoup
from PIL import Image
def load_model(model_path):
    """Load a TFLite model from disk, ready for inference.

    Args:
        model_path: Filesystem path to the ``.tflite`` model file.

    Returns:
        A ``tf.lite.Interpreter`` with its tensors already allocated.
    """
    tflite_interpreter = tf.lite.Interpreter(model_path=model_path)
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter
def preprocess_image(image, input_size):
    """Turn a PIL image into a normalized float32 batch tensor.

    Args:
        image: PIL image in any mode; converted to RGB first.
        input_size: Target side length; the image is resized to
            ``(input_size, input_size)``.

    Returns:
        ``np.ndarray`` of shape ``(1, input_size, input_size, 3)`` with
        values scaled to ``[0, 1]``.
    """
    rgb = image.convert('RGB').resize((input_size, input_size))
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0  # scale to [0, 1]
    return pixels[np.newaxis, ...]  # add the batch dimension
def run_inference(interpreter, input_data):
    """Run one forward pass and read both output heads of the model.

    Args:
        interpreter: A ``tf.lite.Interpreter`` with tensors allocated.
        input_data: Batch tensor matching the model's input spec.

    Returns:
        Tuple ``(shopping_intent, sensitive)`` — the model's first and
        second output tensors, in that order.
    """
    in_index = interpreter.get_input_details()[0]['index']
    out_details = interpreter.get_output_details()
    interpreter.set_tensor(in_index, input_data)
    interpreter.invoke()
    shopping_intent = interpreter.get_tensor(out_details[0]['index'])
    sensitive = interpreter.get_tensor(out_details[1]['index'])
    return shopping_intent, sensitive
def fetch_images_from_url(url, timeout=10):
    """Fetch a page and return the absolute URLs of all ``<img>`` tags.

    Args:
        url: Page URL to scrape.
        timeout: Seconds to wait for the HTTP response (new parameter,
            defaults to 10 so an unresponsive host no longer hangs the
            app indefinitely).

    Returns:
        List of absolute image URLs. Relative ``src`` values (e.g.
        ``/img/x.png``) are resolved against the page URL — the original
        returned them verbatim, which made the later per-image download
        in ``main`` fail for every relative path.

    Raises:
        requests.RequestException: On network failure or timeout.
    """
    response = requests.get(url, timeout=timeout)
    soup = BeautifulSoup(response.content, 'html.parser')
    return [urljoin(url, img['src']) for img in soup.find_all('img') if 'src' in img.attrs]
def image_to_base64(image):
    """Encode a PIL image as a base64 PNG string (no data-URL prefix).

    Args:
        image: PIL image to serialize.

    Returns:
        ASCII ``str`` holding the base64-encoded PNG bytes.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
def main():
    """Streamlit entry point: scrape a user-supplied URL and classify every image on it."""
    st.set_page_config(layout="wide")
    st.title("Shopping Intent Classification - SEO by DEJAN")
    st.write("Enter a URL to fetch and classify all images on the page. Javascript-based website scraping currently unsupported.")

    model_path = "model.tflite"
    url = st.text_input("Enter URL")

    if url:
        img_urls = fetch_images_from_url(url)
        if img_urls:
            st.write(f"Found {len(img_urls)} images")

            interpreter = load_model(model_path)
            input_details = interpreter.get_input_details()
            input_size = input_details[0]['shape'][1]  # assuming square input

            data = []
            errors = []
            for img_url in img_urls:
                try:
                    # Timeout so one dead image host can't hang the whole run.
                    response = requests.get(img_url, timeout=10)
                    # Treat HTTP errors as failures instead of handing an
                    # error page's body to Image.open.
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                    input_data = preprocess_image(image, input_size)
                    output_data_shopping_intent, output_data_sensitive = run_inference(interpreter, input_data)

                    # Shrink in place and embed as a data URL for the table.
                    image.thumbnail((100, 100))
                    thumbnail_data_url = f"data:image/png;base64,{image_to_base64(image)}"

                    data.append({
                        'Thumbnail': thumbnail_data_url,
                        'URL': img_url,
                        'Shopping Intent': output_data_shopping_intent.flatten().tolist(),
                        'Sensitivity': output_data_sensitive.flatten().tolist()
                    })
                except Exception as e:
                    # Best-effort: record the failure and keep going with the rest.
                    errors.append(f"Could not process image {img_url}: {e}")

            if data:
                st.dataframe(pd.DataFrame(data))  # simple tabular display
            else:
                # Original rendered an empty table here with no explanation.
                st.warning("No images could be processed.")

            # Failures are tucked into an expander so they don't drown the results.
            if errors:
                with st.expander(f"Could not process {len(errors)} images"):
                    for error in errors:
                        st.write(error)
        else:
            # Original showed nothing at all when the page had no <img> tags.
            st.warning("No images found on the page.")


if __name__ == "__main__":
    main()