# Sketch / app.py — file view captured from a Hugging Face Space.
# Last commit 85b975e ("Update app.py", verified) by Jangai; file size 5.49 kB.
import inspect
import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, Trainer, TrainingArguments
# Configure logging: DEBUG-level output for the whole process.
logging.basicConfig(level=logging.DEBUG)

# Load the pre-trained ViT checkpoint together with its matching preprocessor.
model_name = "google/vit-base-patch16-224"
logging.info("Loading image processor and model...")
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Feedback log: reuse the CSV written by a previous run when it exists,
# otherwise start from an empty frame with the two columns save_feedback() fills.
feedback_data_path = "feedback_data.csv"
feedback_data = (
    pd.read_csv(feedback_data_path)
    if os.path.exists(feedback_data_path)
    else pd.DataFrame(columns=["image_path", "correct_label"])
)

# Sketches submitted as feedback are written into this directory.
os.makedirs("images", exist_ok=True)
# Define the prediction function
def _sketch_to_rgb(image):
    """Turn a Sketchpad value into an RGB PIL image, or return None if it is empty.

    The value may be a dict carrying the flattened drawing under ``"composite"``
    (the layout ``save_feedback`` relies on) or a bare array-like image.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        return None
    pixels = np.array(image).astype('uint8')
    # Let PIL infer the mode from the array shape (HxWx4 -> RGBA, HxWx3 -> RGB,
    # HxW -> L); forcing 'RGBA' only works for 4-channel input.
    # NOTE(review): convert('RGB') discards alpha without compositing, so
    # transparent pixels keep whatever RGB values they store — confirm the
    # canvas background is opaque.
    return Image.fromarray(pixels).convert('RGB')


def predict(image):
    """Classify a sketch and return its most likely labels.

    Args:
        image: value produced by ``gr.Sketchpad`` (a dict with ``"composite"`` or
            an array-like image); ``None`` is treated as an empty canvas.

    Returns:
        ``{label: probability}`` for the top (at most 3) classes, ``{}`` when there
        is nothing to classify, or ``{"error": message}`` if anything fails.
    """
    try:
        logging.info("Received image of type: %s", type(image))
        logging.debug("Image content: %s", image)
        rgb_image = _sketch_to_rgb(image)
        if rgb_image is None:
            # Nothing drawn (presumably a cleared canvas under live=True).
            return {}

        logging.info("Processing image...")
        inputs = feature_extractor(images=rgb_image, return_tensors="pt")
        # retrain_model() passes this same global model to Trainer, which may move
        # it to an accelerator and leave it in training mode; follow its device
        # and force eval mode before running inference.
        inputs = inputs.to(model.device)
        model.eval()
        # Inference only: no need to build the autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Never ask topk for more classes than the head has.
        k = min(3, probs.shape[-1])
        top_probs, top_idxs = probs.topk(k, dim=-1)
        result = {
            model.config.id2label[idx]: float(p)
            for idx, p in zip(top_idxs[0].tolist(), top_probs[0].tolist())
        }
        logging.info("Prediction successful.")
        return result
    except Exception as e:
        # Surface the failure to the JSON output, keeping the traceback in the log.
        logging.exception("Error during prediction: %s", e)
        return {"error": str(e)}
# Save feedback and retrain if necessary
def save_feedback(image, correct_label):
    """Store a labelled sketch; retrain whenever the feedback count hits a multiple of 5.

    The drawing is written to ``images/<n>.png`` and a row is appended to the
    global ``feedback_data``, which is also persisted to ``feedback_data_path``.

    Args:
        image: Sketchpad value — a dict whose ``"composite"`` entry holds the
            drawing (a bare array-like image is accepted as well).
        correct_label: text from the "Correct Label" box. ``retrain_model`` later
            turns it into a class id, so it presumably has to be an id or a class
            name the model knows — confirm.

    Returns:
        A status string on success, otherwise ``{"error": message}``.
    """
    global feedback_data
    try:
        composite = image.get("composite") if isinstance(image, dict) else image
        if composite is None:
            return {"error": "No drawing provided."}
        label = "" if correct_label is None else str(correct_label).strip()
        if not label:
            # A blank label would be stored as an empty CSV cell that can never
            # be used for retraining.
            return {"error": "Please provide the correct label."}

        image_np = np.array(composite).astype('uint8')
        # Let PIL infer the mode from the array shape; forcing 'RGBA' only works
        # for 4-channel input.
        image_pil = Image.fromarray(image_np).convert('RGB')
        image_path = f"images/{len(feedback_data)}.png"
        image_pil.save(image_path)

        # Add the feedback to the DataFrame. DataFrame.append was removed in
        # pandas 2.0, so build a one-row frame and concatenate it instead
        # (concat with the empty initial frame is skipped; recent pandas warns on it).
        new_row = pd.DataFrame([{"image_path": image_path, "correct_label": label}])
        if feedback_data.empty:
            feedback_data = new_row
        else:
            feedback_data = pd.concat([feedback_data, new_row], ignore_index=True)
        feedback_data.to_csv(feedback_data_path, index=False)

        # Retrain each time the total number of feedback rows reaches a multiple of 5.
        should_retrain = len(feedback_data) % 5 == 0
        if should_retrain:
            retrain_model(feedback_data)
        return "Feedback saved and model retrained!" if should_retrain else "Feedback saved!"
    except Exception as e:
        logging.exception("Error saving feedback: %s", e)
        return {"error": str(e)}
def _resolve_label_id(raw_label, name_to_id, num_labels):
    """Map one stored feedback label to an integer class id of the model.

    Accepts an integer id (also float-formatted, e.g. ``"3.0"``, which ``str()``
    yields if pandas parsed the CSV column as floats) or a class name present in
    ``name_to_id`` (whose keys are casefolded names).

    Raises:
        ValueError: if the label is neither a known class name nor an id in
            ``[0, num_labels)``.
    """
    text = str(raw_label).strip()
    try:
        number = float(text)
    except ValueError:
        number = None
    if number is not None and number.is_integer():
        label_id = int(number)
    elif text.casefold() in name_to_id:
        label_id = name_to_id[text.casefold()]
    else:
        raise ValueError(f"unknown label {raw_label!r}; expected a class id or class name")
    if not 0 <= label_id < num_labels:
        raise ValueError(f"label id {label_id} is outside [0, {num_labels})")
    return label_id


# Retrain the model with the feedback data
def retrain_model(feedback_data):
    """Fine-tune the global ``model`` in place on the collected feedback, then save it.

    Best effort: every failure is logged rather than raised. Rows whose label
    cannot be resolved to a class id are skipped with a warning, so a single bad
    entry does not block all future retraining.

    Args:
        feedback_data: DataFrame with ``"image_path"`` and ``"correct_label"`` columns.
    """
    try:
        logging.info("Retraining the model with feedback data...")

        # Resolve labels first so unusable rows never reach the dataset.
        name_to_id = {
            str(name).casefold(): int(idx)
            for name, idx in (model.config.label2id or {}).items()
        }
        num_labels = model.config.num_labels
        image_paths, label_ids = [], []
        for path, raw_label in zip(feedback_data["image_path"], feedback_data["correct_label"]):
            try:
                label_id = _resolve_label_id(raw_label, name_to_id, num_labels)
            except ValueError as err:
                logging.warning("Skipping feedback %s: %s", path, err)
                continue
            image_paths.append(path)
            label_ids.append(label_id)
        if not image_paths:
            logging.warning("No usable feedback rows; skipping retraining.")
            return

        # Load images and labels into a Hugging Face dataset
        def load_image(file_path):
            # The context manager closes the file once the RGB copy is made.
            with Image.open(file_path) as img:
                return img.convert("RGB")

        dataset = Dataset.from_dict({
            "image": [load_image(p) for p in image_paths],
            "label": label_ids,
        })

        # train_test_split needs at least one row on each side.
        if len(dataset) >= 2:
            splits = dataset.train_test_split(test_size=0.1)
            train_dataset, eval_dataset = splits["train"], splits["test"]
        else:
            train_dataset, eval_dataset = dataset, None

        # Preprocess lazily, batch by batch, when Trainer reads the data.
        def preprocess(examples):
            inputs = feature_extractor(images=examples["image"], return_tensors="pt")
            inputs["labels"] = [int(lbl) for lbl in examples["label"]]
            return inputs

        train_dataset = train_dataset.with_transform(preprocess)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.with_transform(preprocess)

        # Newer transformers renamed `evaluation_strategy` to `eval_strategy` and
        # later dropped the old name; pass whichever keyword this version accepts.
        accepted = inspect.signature(TrainingArguments).parameters
        strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
        training_args = TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            num_train_epochs=3,
            save_strategy="epoch",
            save_total_limit=2,
            # Keep the "image"/"label" columns so the transform above can see them.
            remove_unused_columns=False,
            **{strategy_key: "epoch" if eval_dataset is not None else "no"},
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        trainer.train()
        # predict() reuses this global model; hand it back in inference mode.
        model.eval()

        # Save the model
        model.save_pretrained("./fine_tuned_model")
        feature_extractor.save_pretrained("./fine_tuned_model")
        logging.info("Model retrained and saved successfully.")
    except Exception as e:
        logging.exception("Error during model retraining: %s", e)
# Create the Gradio interfaces.
# Tab 1: live classification — with live=True, predict() re-runs as the sketch changes.
predict_interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=gr.JSON(),
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
    live=True,
)

# Tab 2: a sketch plus its label is passed to save_feedback().
feedback_inputs = [gr.Sketchpad(), gr.Textbox(label="Correct Label")]
feedback_interface = gr.Interface(
    fn=save_feedback,
    inputs=feedback_inputs,
    outputs="text",
    title="Save Feedback",
    description="Draw something and provide the correct label to improve the model.",
)

# Launch both interfaces together as tabs of one app.
tab_names = ["Predict", "Provide Feedback"]
demo = gr.TabbedInterface([predict_interface, feedback_interface], tab_names)
demo.launch(share=True)