# PushkarA07: "Update app.py" (2994578)
import matplotlib.image as mpimg
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import pickle
from sklearn.preprocessing import LabelEncoder
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pre-trained model onto DEVICE so that the weights and the inputs
# used at inference time end up on the same device.
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# 'model.pt' must come from a trusted source.
model = torch.load('model.pt', map_location=torch.device(DEVICE))
model.eval()

# Preprocessing applied to every image: resize to 64x64, convert to a tensor
# and normalise with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the label encoder used to map class ids back to characters.
# NOTE: pickle.load can execute arbitrary code; 'label_encoder.pkl' must come
# from a trusted source.
with open('label_encoder.pkl', 'rb') as f:
    lbl_enc = pickle.load(f)
def decode_predictions(preds, encoder):
    """Greedily decode raw model scores into one string per sample.

    Args:
        preds: 3-D tensor of scores. Axis 2 holds the per-class scores,
            axis 1 selects the sample and axis 0 the position within the
            sample (presumably (timesteps, batch, classes) -- confirm
            against the model).
        encoder: Object whose ``inverse_transform([class_id])`` returns a
            sequence holding the string label of that class id, e.g. the
            fitted label encoder.

    Returns:
        A list with one string per sample. At every position the
        highest-scoring class is taken; class 0 is rendered as "-" and any
        other class ``c`` as the encoder's label for ``c - 1``.
    """
    # softmax is monotonic, so it cannot change the arg-max; take the
    # arg-max of the raw scores directly.
    class_ids = torch.argmax(preds.permute(1, 0, 2), 2).detach().cpu().numpy()

    # Decode every distinct class id only once instead of once per position.
    # Class 0 is rendered as "-" (presumably the CTC blank symbol).
    labels = {0: "-"}

    def label_for(class_id):
        if class_id not in labels:
            # Ids are shifted by one because 0 is reserved for "-".
            labels[class_id] = encoder.inverse_transform([class_id - 1])[0]
        return labels[class_id]

    return ["".join(label_for(int(c)) for c in row) for row in class_ids]
def predict_function(model, data):
    """Run a single gradient-free forward pass of ``model`` on ``data``.

    Args:
        model: Module called as ``model(**data)``; it must return a 2-tuple
            whose first item is the batch of predictions.
        data: Mapping of keyword name -> tensor. The mapping itself is not
            modified.

    Returns:
        A one-element list holding the predictions returned by the model.
    """
    model.eval()

    # Send the inputs to wherever the model's weights live so both are on
    # the same device; fall back to DEVICE for a model without parameters.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else DEVICE

    with torch.no_grad():
        # Build a new dict instead of overwriting the caller's entries.
        inputs = {name: tensor.to(target) for name, tensor in data.items()}
        batch_preds, _ = model(**inputs)
    return [batch_preds]
def clean_decoded_predictions(unclean_predictions, *, expected_length=10):
    """Strip "-" symbols from a decoded prediction and collapse repeats.

    Args:
        unclean_predictions: Iterable of characters (typically a string
            produced by ``decode_predictions``) that may contain "-".
        expected_length: Length at which the "-"-stripped text is accepted
            as is. Defaults to 10, the value this function has always used.
            Pass ``None`` to always collapse repeats.

    Returns:
        The text without "-". If its length equals ``expected_length`` it is
        returned unchanged; otherwise runs of identical consecutive
        characters are collapsed to a single character first.
    """
    stripped = "".join(ch for ch in unclean_predictions if ch != "-")
    if len(stripped) == expected_length:
        return stripped

    # Keep a character only when it differs from the one right before it.
    return "".join(
        ch
        for pos, ch in enumerate(stripped)
        if pos == 0 or ch != stripped[pos - 1]
    )
def predict_captcha(model, image_path):
    """Predict the text of a single captcha image.

    Args:
        model: Captcha model, forwarded to ``predict_function``.
        image_path: Either an already preprocessed image batch (a tensor,
            which is what the Streamlit app passes), a path string, or a
            sequence whose first item is the path of the image file.

    Returns:
        The cleaned, decoded text of the first image in the batch.
    """
    if isinstance(image_path, torch.Tensor):
        # Already transformed and batched by the caller.
        batch = image_path
    else:
        path = image_path if isinstance(image_path, str) else image_path[0]
        with Image.open(path) as img:
            # Show the image being decoded, titled with its file name.
            plt.figure(figsize=(15, 5))
            plt.title(path.split("/")[-1])
            plt.imshow(img)
            # Apply the same preprocessing the app uses and add a batch axis.
            batch = transform(img).unsqueeze(0)

    # predict_function needs a mapping of keyword arguments for the model.
    # NOTE(review): assumes the model's forward() takes the batch as the
    # `images` keyword -- confirm against the model class.
    valid_preds = predict_function(model, {"images": batch})

    # predict_function returns a one-element list; decode the tensor inside.
    current_preds = decode_predictions(valid_preds[0], lbl_enc)
    return clean_decoded_predictions(current_preds[0])
# Define the Streamlit app
def app():
    """Render the Streamlit page: upload a captcha image, show its text."""
    st.title("Captcha Breaker Project")
    st.write("by - Pushkar Ambastha")
    st.write("Upload an image of a captcha to recognize the text")

    # Allow the user to upload an image; nothing to do until one is provided.
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Load the image and apply the model's preprocessing.
    img = transform(Image.open(uploaded_file))

    # Make a prediction with the model on a batch of one image.
    with torch.no_grad():
        prediction = predict_captcha(model, img.unsqueeze(0))

    # predict_captcha already returns the decoded characters, so the text is
    # shown as is; converting each character with chr(int(...)) would raise
    # on letters and turn digits into control characters.
    captcha_text = prediction
    st.write(f"The captcha text is: {captcha_text}")