import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)


# Emit DEBUG-level records so each request and retraining run can be traced.
logging.basicConfig(level=logging.DEBUG)

# Checkpoint used for both inference and feedback-driven fine-tuning.
model_name = "google/vit-base-patch16-224"
logging.info("Loading image processor and model...")
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Feedback rows persist in a CSV; start from an empty table when none exists.
feedback_data_path = "feedback_data.csv"
feedback_data = (
    pd.read_csv(feedback_data_path)
    if os.path.exists(feedback_data_path)
    else pd.DataFrame(columns=["image_path", "correct_label"])
)

# Directory that receives the sketches saved alongside the feedback rows.
os.makedirs("images", exist_ok=True)


def predict(image):
    """Classify a sketch and return its most likely labels.

    Args:
        image: The sketchpad payload. Accepted forms are a dict whose
            ``"composite"`` entry holds the drawing (as Gradio's image editor
            components may deliver), a NumPy-compatible array, or a PIL
            image. ``None`` (an empty canvas) is reported as an error.

    Returns:
        A dict mapping up to three class names to their probabilities, or
        ``{"error": message}`` when the input is missing or prediction fails.
    """
    try:
        logging.info("Received image of type: %s", type(image))
        logging.debug("Image content: %s", image)

        # Unwrap the editor payload so dict and plain-array inputs both work.
        if isinstance(image, dict):
            image = image.get("composite")
        if image is None:
            return {"error": "No image provided."}

        # Let Pillow infer the mode from the array shape instead of assuming
        # RGBA, which fails for RGB or grayscale input.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.asarray(image).astype("uint8"))
        # NOTE(review): convert("RGB") drops the alpha channel rather than
        # compositing onto a background - confirm this suits transparent canvases.
        image = image.convert("RGB")

        logging.info("Processing image...")
        inputs = feature_extractor(images=image, return_tensors="pt")
        # Training may have moved the model off the CPU; follow it.
        inputs = inputs.to(model.device)
        with torch.no_grad():
            logits = model(**inputs).logits

        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        # Never request more entries than the model has classes.
        top_probs, top_idxs = probs.topk(min(3, probs.shape[-1]))
        result = {
            model.config.id2label[int(idx)]: float(prob)
            for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
        }
        logging.info("Prediction successful.")
        return result
    except Exception as e:
        logging.error("Error during prediction: %s", e)
        return {"error": str(e)}


def save_feedback(image, correct_label):
    """Store a labelled sketch and retrain after every fifth feedback row.

    The sketch is written to ``images/<row index>.png``, a row is appended to
    the module-level ``feedback_data`` table, and the table is rewritten to
    ``feedback_data_path``.

    Args:
        image: The sketchpad payload: a dict whose ``"composite"`` entry holds
            the drawing, a NumPy-compatible array, or a PIL image.
        correct_label: The label the user assigns to the drawing.

    Returns:
        A status string for the text output; failures are reported as
        ``"Error: ..."`` rather than raised.
    """
    global feedback_data
    try:
        sketch = image.get("composite") if isinstance(image, dict) else image
        if sketch is None:
            return "Error: no drawing provided."

        label = "" if correct_label is None else str(correct_label).strip()
        if not label:
            return "Error: please provide the correct label."

        # Let Pillow infer the mode from the array shape instead of assuming RGBA.
        if not isinstance(sketch, Image.Image):
            sketch = Image.fromarray(np.asarray(sketch).astype("uint8"))
        image_pil = sketch.convert("RGB")

        image_path = f"images/{len(feedback_data)}.png"
        image_pil.save(image_path)

        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        new_row = pd.DataFrame(
            [{"image_path": image_path, "correct_label": label}]
        )
        feedback_data = pd.concat([feedback_data, new_row], ignore_index=True)
        feedback_data.to_csv(feedback_data_path, index=False)

        if len(feedback_data) % 5 == 0:
            retrain_model(feedback_data)
            return "Feedback saved and model retrained!"
        return "Feedback saved!"
    except Exception as e:
        logging.error("Error saving feedback: %s", e)
        # The interface output is plain text, so report the failure as a string.
        return f"Error: {e}"


def _label_to_class_id(raw_label, label2id):
    """Resolve a user-supplied label to a class id, or None when unknown.

    Accepts a decimal class id, an exact class name, or (case-insensitively)
    a class name or any of its comma-separated synonyms.
    """
    text = str(raw_label).strip()
    if text.isdecimal():
        class_id = int(text)
        return class_id if class_id in set(label2id.values()) else None
    if text in label2id:
        return label2id[text]
    wanted = text.lower()
    for name, class_id in label2id.items():
        synonyms = [part.strip().lower() for part in str(name).split(",")]
        if wanted == str(name).lower() or wanted in synonyms:
            return class_id
    return None


def retrain_model(feedback_data):
    """Fine-tune the module-level model on the collected feedback sketches.

    Rows whose label cannot be resolved to one of the model's classes, or
    whose image file is missing, are skipped with a warning. On success the
    model and feature extractor are saved to ``./fine_tuned_model``.

    Args:
        feedback_data: Table with ``image_path`` and ``correct_label`` columns.

    Returns:
        None. Errors are logged, not raised, so a failed retraining run does
        not break the feedback request that triggered it.
    """
    try:
        logging.info("Retraining the model with feedback data...")

        label2id = model.config.label2id
        images, labels = [], []
        for path, raw_label in zip(
            feedback_data["image_path"], feedback_data["correct_label"]
        ):
            class_id = _label_to_class_id(raw_label, label2id)
            if class_id is None:
                logging.warning("Skipping %s: unknown label %r", path, raw_label)
                continue
            if not os.path.exists(path):
                logging.warning("Skipping %s: image file is missing", path)
                continue
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
            labels.append(int(class_id))

        # A train/test split needs at least one sample on each side.
        if len(images) < 2:
            logging.warning(
                "Not enough usable feedback samples to retrain (%d).", len(images)
            )
            return

        dataset = Dataset.from_dict({"image": images, "label": labels})
        dataset = dataset.train_test_split(test_size=0.1)

        def preprocess(examples):
            # Applied lazily per batch: pixels from the extractor, ids as labels.
            inputs = feature_extractor(images=examples["image"], return_tensors="pt")
            inputs["labels"] = [int(lbl) for lbl in examples["label"]]
            return inputs

        dataset = dataset.with_transform(preprocess)

        # Newer Transformers releases renamed `evaluation_strategy` to
        # `eval_strategy`; pick whichever name this installation accepts.
        known_fields = getattr(TrainingArguments, "__dataclass_fields__", {})
        strategy_key = (
            "eval_strategy" if "eval_strategy" in known_fields else "evaluation_strategy"
        )
        training_args = TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            num_train_epochs=3,
            save_strategy="epoch",
            save_total_limit=2,
            # Keep the raw "image"/"label" columns available to the transform.
            remove_unused_columns=False,
            **{strategy_key: "epoch"},
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
        )
        trainer.train()

        # The same model object serves predictions, so leave it in eval mode.
        model.eval()

        model.save_pretrained("./fine_tuned_model")
        feature_extractor.save_pretrained("./fine_tuned_model")
        logging.info("Model retrained and saved successfully.")
    except Exception as e:
        logging.error("Error during model retraining: %s", e)


# Tab 1: live classification of whatever is drawn on the sketchpad.
predict_interface = gr.Interface(
    predict,
    gr.Sketchpad(),
    gr.JSON(),
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
    live=True,
)

# Tab 2: submit a drawing together with its correct label.
feedback_interface = gr.Interface(
    save_feedback,
    [gr.Sketchpad(), gr.Textbox(label="Correct Label")],
    "text",
    title="Save Feedback",
    description="Draw something and provide the correct label to improve the model.",
)

# Serve both tabs; share=True also requests a public Gradio link.
gr.TabbedInterface(
    interface_list=[predict_interface, feedback_interface],
    tab_names=["Predict", "Provide Feedback"],
).launch(share=True)