# louiecerv's picture
# sync to remote
# c0f9b0f
import streamlit as st
import torch
import numpy as np
from PIL import Image
import pickle
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from datasets import load_dataset
# Hugging Face Hub location of the pickled classifier
MODEL_REPO_ID = "louiecerv/amer_sign_lang_neuralnet"
MODEL_FILENAME = "trained_model.pkl"  # name of the model file inside that repo

# Hugging Face dataset to browse; only its "train" split is loaded
DATASET_NAME = "louiecerv/american_sign_language"
dataset = load_dataset(path=DATASET_NAME, split="train")
def preprocess_image(image: Image.Image) -> tuple[torch.Tensor, Image.Image]:
    """
    Preprocess an image for the classifier and build a displayable preview.

    The image is converted to single-channel grayscale, resized to 28x28,
    converted to a tensor and normalized with mean 0.5 / std 0.5.

    Args:
        image (Image.Image): The input PIL image.

    Returns:
        tuple[torch.Tensor, Image.Image]: The normalized image tensor and a
        PIL image holding the same pixels mapped back to the 0-255 range for
        display. If anything fails, the error is reported with ``st.error``
        and ``(None, None)`` is returned instead.
    """
    try:
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])
        tensor_image = transform(image)

        # Build the preview from the tensor. squeeze() drops the size-1
        # channel axis here; no batch axis has been added at this point.
        preview = tensor_image.squeeze().cpu().numpy()
        preview = (preview * 0.5 + 0.5) * 255  # undo Normalize(0.5, 0.5), scale to 0-255
        # Round to nearest and clip before casting: astype(uint8) on its own
        # truncates (99.99999 -> 99) and wraps values outside 0-255.
        preview = np.clip(np.rint(preview), 0, 255).astype(np.uint8)
        processed_image_pil = Image.fromarray(preview)
        return tensor_image, processed_image_pil
    except Exception as e:
        # Best-effort UI behavior: show the problem in the app and let the
        # caller skip prediction when it receives (None, None).
        st.error(f"Error preprocessing image: {e}")
        return None, None
def load_model(repo_id: str, filename: str) -> torch.nn.Module:
    """
    Fetch a pickled model file from the Hugging Face Hub and unpickle it.

    NOTE: ``pickle.load`` can execute arbitrary code contained in the file,
    so this must only be pointed at repositories that are trusted.

    Args:
        repo_id (str): The repository ID that hosts the model file.
        filename (str): The name of the pickled model file in that repository.

    Returns:
        torch.nn.Module: The unpickled object, or None when the download or
        unpickling fails (the error is reported with ``st.error``).
    """
    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(local_path, "rb") as handle:
            return pickle.load(handle)
    except Exception as exc:
        st.error(f"Error loading model: {exc}")
    return None
def make_prediction(model: torch.nn.Module, image_tensor: torch.Tensor) -> str:
    """
    Run the model on a preprocessed image tensor and return the predicted letter.

    Args:
        model (torch.nn.Module): The loaded model; it is switched to eval mode.
        image_tensor (torch.Tensor): The preprocessed image tensor. A
            3-dimensional tensor gets a leading batch axis before inference.

    Returns:
        str: The letter obtained by offsetting 'A' with the index of the
        largest model output, or None when inference fails (the error is
        reported with ``st.error``).
    """
    try:
        model.eval()
        with torch.no_grad():
            # A 3-D input has no batch axis yet, so prepend one
            needs_batch_axis = image_tensor.ndim == 3
            batch = image_tensor.unsqueeze(0) if needs_batch_axis else image_tensor
            scores = model(batch)
        # argmax without a dim works on the flattened output
        class_index = scores.argmax().item()
        return chr(ord('A') + class_index)
    except Exception as err:
        st.error(f"Error making prediction: {err}")
        return None
def tensor_to_image(pixel_list):
    """
    Convert normalized pixel values into a displayable 28x28 image.

    Args:
        pixel_list: Array-like holding 784 pixel values (anything that
            reshapes to 28x28). The values are assumed to have been
            normalized with mean=0.5 and std=0.5.

    Returns:
        Image.Image: A PIL image built from the resulting 28x28 uint8 array.
    """
    pixels = np.asarray(pixel_list).reshape(28, 28)
    pixels = (pixels * 0.5 + 0.5) * 255  # undo the assumed mean=0.5, std=0.5 normalization
    # Round to nearest before the cast; astype(uint8) alone truncates, so a
    # value such as 99.99999 would be shown as 99 instead of 100.
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
# Streamlit App
st.title("American Sign Language App")

# A sidebar radio button acts as the tab switcher
tabs = ["Dataset", "Prediction"]
selected_tab = st.sidebar.radio("Select Tab", tabs)

if selected_tab == "Dataset":
    st.header("Dataset")
    st.write("Displaying the first 20 images from the dataset.")

    # Grid layout
    cols = 5  # Number of columns
    rows = 4  # Number of rows
    num_images = cols * rows

    # Slice the dataset once and read both columns from the same batch,
    # instead of materializing the same rows twice.
    batch = dataset[:num_images]
    image_list = batch["pixel_values"]
    labels = batch["label"]

    # Display images using Streamlit columns, one st.columns() row at a time
    for row in range(rows):
        columns = st.columns(cols)
        for col in range(cols):
            index = row * cols + col
            if index >= len(image_list):
                # The slice may hold fewer rows than the grid has cells
                break
            image = tensor_to_image(image_list[index])
            columns[col].image(
                image,
                caption=f"Label: {chr(labels[index] + ord('A'))}",
                use_container_width=True,
            )

elif selected_tab == "Prediction":
    st.header("Prediction")
    st.write("Upload an image of an ASL letter.")

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Load the upload; converting to RGB gives preprocessing a consistent input mode
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image.", use_container_width=True)

        image_tensor, processed_image_pil = preprocess_image(image)

        # Each helper reports its own error and returns None on failure,
        # so every following step only runs when the previous one succeeded.
        if image_tensor is not None and processed_image_pil is not None:
            st.image(processed_image_pil, caption="Preprocessed Image.", use_container_width=True)

            # Load the model
            model = load_model(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)

            if model is not None:
                # Make a prediction
                predicted_letter = make_prediction(model, image_tensor)

                if predicted_letter is not None:
                    st.write(f"Predicted Letter: {predicted_letter}")