sync to remote
Browse files
app.py
CHANGED
@@ -1,137 +1,153 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
-
import
|
4 |
-
import pickle
|
5 |
-
from torch.utils.data import Dataset, DataLoader
|
6 |
-
from torchvision import transforms
|
7 |
from PIL import Image
|
8 |
-
import
|
9 |
-
import
|
10 |
-
|
11 |
-
|
12 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
-
st.write(f"Enabled GPU = {torch.cuda.is_available()}")
|
14 |
|
|
|
15 |
MODEL_REPO_ID = "louiecerv/amer_sign_lang_neuralnet"
|
16 |
-
|
17 |
-
|
18 |
-
# Load dataset from Hugging Face
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
self.relu1 = nn.ReLU()
|
36 |
-
self.fc2 = nn.Linear(hidden_size, hidden_size)
|
37 |
-
self.relu2 = nn.ReLU()
|
38 |
-
self.fc3 = nn.Linear(hidden_size, num_classes)
|
39 |
-
|
40 |
-
def forward(self, x):
|
41 |
-
x = self.flatten(x)
|
42 |
-
x = self.fc1(x)
|
43 |
-
x = self.relu1(x)
|
44 |
-
x = self.fc2(x)
|
45 |
-
x = self.relu2(x)
|
46 |
-
x = self.fc3(x)
|
47 |
-
return x
|
48 |
-
|
49 |
-
# Load pre-trained model from Hugging Face (pickle file)
|
50 |
-
@st.cache_resource
|
51 |
-
def load_model():
|
52 |
-
url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/trained_model.pkl"
|
53 |
-
response = requests.get(url)
|
54 |
-
with open("trained_model.pkl", "wb") as f:
|
55 |
-
f.write(response.content)
|
56 |
-
with open("trained_model.pkl", "rb") as f:
|
57 |
-
model = pickle.load(f)
|
58 |
-
model.to(device)
|
59 |
-
model.eval() # Set model to evaluation mode
|
60 |
-
return model
|
61 |
-
|
62 |
-
model = load_model()
|
63 |
-
|
64 |
-
# Custom dataset class
|
65 |
-
class ASLDataset(Dataset):
|
66 |
-
def __init__(self, data):
|
67 |
-
self.data = data
|
68 |
-
self.transform = transforms.Compose([
|
69 |
transforms.Grayscale(num_output_channels=1),
|
70 |
transforms.Resize((28, 28)),
|
71 |
transforms.ToTensor(),
|
72 |
-
transforms.Normalize(mean=
|
73 |
])
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
# Streamlit App
|
101 |
-
st.title("American Sign Language
|
102 |
-
|
103 |
-
#
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
st.header("Dataset
|
109 |
-
st.write("Displaying
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
if uploaded_file is not None:
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
+
import numpy as np
|
|
|
|
|
|
|
4 |
from PIL import Image
|
5 |
+
import pickle
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
from datasets import load_dataset
|
|
|
|
|
9 |
|
10 |
+
# Model repository ID
MODEL_REPO_ID = "louiecerv/amer_sign_lang_neuralnet"
MODEL_FILENAME = "trained_model.pkl"  # The filename of your model on Hugging Face

# Load dataset from Hugging Face
DATASET_NAME = "louiecerv/american_sign_language"  # Replace with your dataset name


@st.cache_resource
def _load_train_split(name: str):
    """
    Load the "train" split of a Hugging Face dataset once per process.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded dataset avoids repeating the load on each rerun.

    Args:
        name (str): The dataset repository name on Hugging Face.

    Returns:
        The object returned by ``load_dataset(name, split="train")``.
    """
    return load_dataset(name, split="train")


dataset = _load_train_split(DATASET_NAME)
|
17 |
+
|
18 |
+
def preprocess_image(image: Image.Image) -> tuple[torch.Tensor, Image.Image]:
    """
    Preprocess an image for the model and build a displayable preview of it.

    The image is converted to single-channel grayscale, resized to 28x28,
    converted to a tensor, and normalized with mean=0.5 and std=0.5.

    Args:
        image (Image.Image): The input PIL image.

    Returns:
        tuple[torch.Tensor, Image.Image]: The normalized image tensor and a
        PIL image rebuilt from that tensor for display. If any step fails,
        the error is shown with ``st.error`` and ``(None, None)`` is returned.
    """
    # Shared by the forward normalization and the inverse used for the preview,
    # so the two cannot drift apart.
    mean = 0.5
    std = 0.5

    try:
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        tensor_image = transform(image)

        # Rebuild a displayable image from the tensor: drop the size-1 channel
        # dimension, then undo the normalization and rescale to 0..255.
        preview = tensor_image.squeeze().cpu().numpy()
        preview = (preview * std + mean) * 255
        # Round and clip before the cast: a bare astype(uint8) truncates, so a
        # value such as 50.99999 (float round-off) would become 50, and any
        # value outside 0..255 would wrap around.
        preview = np.clip(np.rint(preview), 0, 255).astype(np.uint8)
        processed_image_pil = Image.fromarray(preview)

        return tensor_image, processed_image_pil
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the app.
        st.error(f"Error preprocessing image: {e}")
        return None, None
|
48 |
+
|
49 |
+
@st.cache_resource
def _load_model_cached(repo_id: str, filename: str) -> torch.nn.Module:
    """
    Download and unpickle the model, cached per (repo_id, filename).

    Streamlit re-executes the script on every interaction; caching keeps the
    unpickled model in memory instead of reloading it each time. This helper
    raises on failure rather than returning None, so a failed download is
    never stored as the cached result.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY: pickle.load can execute arbitrary code contained in the file.
    # Only point this at repositories you trust.
    with open(model_path, "rb") as f:
        return pickle.load(f)


def load_model(repo_id: str, filename: str) -> torch.nn.Module:
    """
    Load the model from Hugging Face Hub.

    Args:
        repo_id (str): The repository ID of the model.
        filename (str): The filename of the model.

    Returns:
        torch.nn.Module: The loaded model, or None if downloading or
        unpickling failed (the error is shown with ``st.error``).
    """
    try:
        return _load_model_cached(repo_id, filename)
    except Exception as e:
        # UI boundary: surface the failure in the page and let the caller
        # check for None.
        st.error(f"Error loading model: {e}")
        return None
|
68 |
+
|
69 |
+
def make_prediction(model: torch.nn.Module, image_tensor: torch.Tensor) -> str:
    """
    Make a prediction using the loaded model and the preprocessed image tensor.

    Args:
        model (torch.nn.Module): The loaded model.
        image_tensor (torch.Tensor): The preprocessed image tensor. A 3-D
            tensor gets a leading batch dimension added. If a batch with
            several samples is passed, only the first sample is classified.

    Returns:
        str: The predicted letter, computed as ``chr(class_index + ord('A'))``.
        If anything fails, the error is shown with ``st.error`` and None is
        returned.
    """
    try:
        model.eval()
        with torch.no_grad():
            # Add batch dimension if not already present
            if image_tensor.dim() == 3:
                image_tensor = image_tensor.unsqueeze(0)

            # Run the input on whatever device the model's weights live on;
            # a model without parameters leaves the tensor where it is.
            first_param = next(model.parameters(), None)
            if first_param is not None:
                image_tensor = image_tensor.to(first_param.device)

            prediction = model(image_tensor)

            # Take the argmax over the class scores of the first sample only.
            # An argmax over the flattened (batch, classes) output would yield
            # an index beyond the class range whenever a later sample holds
            # the largest score.
            scores = prediction[0] if prediction.dim() > 1 else prediction
            predicted_class = torch.argmax(scores).item()

        predicted_letter = chr(predicted_class + ord('A'))
        return predicted_letter
    except Exception as e:
        # UI boundary: report the failure and let the caller check for None.
        st.error(f"Error making prediction: {e}")
        return None
|
93 |
+
|
94 |
+
def tensor_to_image(pixel_list):
    """
    Build a 28x28 PIL image from a sequence of 784 pixel values.

    The values are treated as normalized with mean=0.5 and std=0.5 (an
    assumption about the dataset, not checked here): the normalization is
    undone, the result is scaled to 0..255, clipped, and cast to uint8.
    """
    normalized = np.array(pixel_list).reshape((28, 28))
    # Invert the normalization, then stretch to the 8-bit range.
    rescaled = 255 * (0.5 + normalized * 0.5)
    as_bytes = rescaled.clip(0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
|
100 |
|
101 |
# Streamlit App
st.title("American Sign Language App")

# Create tabs (a sidebar radio selects which section of the page is rendered)
tabs = ["Dataset", "Prediction"]
selected_tab = st.sidebar.radio("Select Tab", tabs)

if selected_tab == "Dataset":
    st.header("Dataset")
    st.write("Displaying the first 20 images from the dataset.")

    # Grid layout
    cols = 5  # Number of columns
    rows = 4  # Number of rows
    num_images = cols * rows

    # Slice the dataset once and read both columns from the same batch,
    # instead of materializing the slice separately for each column.
    batch = dataset[:num_images]
    image_list = batch["pixel_values"]
    labels = batch["label"]

    # Display images row by row. Stepping over what was actually returned
    # (rather than a fixed rows x cols index range) avoids an IndexError when
    # fewer than num_images samples are available.
    for start in range(0, len(image_list), cols):
        columns = st.columns(cols)
        row_pixels = image_list[start:start + cols]
        row_labels = labels[start:start + cols]
        for column, pixels, label in zip(columns, row_pixels, row_labels):
            image = tensor_to_image(pixels)
            column.image(image, caption=f"Label: {chr(label + ord('A'))}", use_container_width=True)

elif selected_tab == "Prediction":
    st.header("Prediction")
    st.write("Upload an image of an ASL letter.")

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Load and preprocess the image
        image = Image.open(uploaded_file).convert("RGB")  # Ensure RGB for consistent processing
        st.image(image, caption="Uploaded Image.", use_container_width=True)
        image_tensor, processed_image_pil = preprocess_image(image)

        # Each helper reports its own error and returns None on failure, so
        # every later step only runs if the previous one succeeded.
        if image_tensor is not None and processed_image_pil is not None:
            st.image(processed_image_pil, caption="Preprocessed Image.", use_container_width=True)  # Display processed image

            # Load the model
            model = load_model(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)

            if model is not None:
                # Make a prediction
                predicted_letter = make_prediction(model, image_tensor)

                if predicted_letter is not None:
                    st.write(f"Predicted Letter: {predicted_letter}")
|