"""Gradio drawing classifier with user feedback and periodic fine-tuning.

Two tabs are exposed: one classifies a sketch with a pre-trained ViT model,
the other stores a sketch together with a user-supplied label and fine-tunes
the model every fifth feedback entry.
"""

import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)

# Configure logging
logging.basicConfig(level=logging.DEBUG)

# Load the pre-trained model and feature extractor
model_name = "google/vit-base-patch16-224"
logging.info("Loading image processor and model...")
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Load or initialize the feedback data
feedback_data_path = "feedback_data.csv"
if os.path.exists(feedback_data_path):
    feedback_data = pd.read_csv(feedback_data_path)
else:
    feedback_data = pd.DataFrame(columns=["image_path", "correct_label"])

# Directory to save images
os.makedirs("images", exist_ok=True)


def _to_rgb_image(image):
    """Convert a sketchpad payload to an RGB PIL image.

    Accepts a dict carrying a "composite" entry, a NumPy-compatible array, or
    a PIL image, so both tabs handle the sketchpad value the same way.

    Raises:
        ValueError: if no image data is present.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise ValueError("No image provided")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    pixels = np.asarray(image).astype("uint8")
    # Let Pillow infer the mode from the array shape instead of forcing
    # 'RGBA', which fails for grayscale or 3-channel input.
    return Image.fromarray(pixels).convert("RGB")


def _label_to_id(label):
    """Map a feedback label to a class id of the model, or None if unknown.

    The label may be a class name found in ``model.config.label2id`` or the
    decimal string of a class id present in ``model.config.id2label``.
    """
    text = str(label).strip()
    label2id = model.config.label2id or {}
    if text in label2id:
        return int(label2id[text])
    try:
        class_id = int(text)
    except ValueError:
        return None
    return class_id if class_id in model.config.id2label else None


# Define the prediction function
def predict(image):
    """Classify a drawing and return the top predictions.

    Args:
        image: Sketchpad value (dict with "composite", array, or PIL image).

    Returns:
        A dict mapping up to three class names to their probabilities, or
        ``{"error": message}`` if anything goes wrong.
    """
    try:
        logging.info("Received image of type: %s", type(image))
        logging.debug("Image content: %s", image)

        pil_image = _to_rgb_image(image)

        logging.info("Processing image...")
        inputs = feature_extractor(images=pil_image, return_tensors="pt")
        # The Trainer may have moved the model to an accelerator.
        inputs = inputs.to(model.device)
        with torch.no_grad():  # inference only; no autograd graph needed
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        top_probs, top_idxs = probs.topk(min(3, probs.numel()))
        result = {
            model.config.id2label[int(idx)]: float(prob)
            for prob, idx in zip(top_probs, top_idxs)
        }

        logging.info("Prediction successful.")
        return result
    except Exception as e:
        logging.exception("Error during prediction: %s", e)
        return {"error": str(e)}


# Save feedback and retrain if necessary
def save_feedback(image, correct_label):
    """Persist a labelled drawing and retrain after every fifth entry.

    Args:
        image: Sketchpad value (dict with "composite", array, or PIL image).
        correct_label: Label text entered by the user.

    Returns:
        A status message, or ``{"error": message}`` on failure.
    """
    global feedback_data
    try:
        image_pil = _to_rgb_image(image)
        image_path = f"images/{len(feedback_data)}.png"
        image_pil.save(image_path)

        # Add the feedback to the DataFrame. DataFrame.append was removed in
        # pandas 2.0, so build a one-row frame and concatenate instead.
        new_row = pd.DataFrame(
            [{"image_path": image_path, "correct_label": correct_label}]
        )
        if feedback_data.empty:
            feedback_data = new_row
        else:
            feedback_data = pd.concat([feedback_data, new_row], ignore_index=True)
        feedback_data.to_csv(feedback_data_path, index=False)

        # Retrain if we have collected 5 new feedbacks
        if len(feedback_data) % 5 == 0:
            if retrain_model(feedback_data):
                return "Feedback saved and model retrained!"
            return "Feedback saved, but retraining failed; see the logs."
        return "Feedback saved!"
    except Exception as e:
        logging.exception("Error saving feedback: %s", e)
        return {"error": str(e)}


# Retrain the model with the feedback data
def retrain_model(feedback_data):
    """Fine-tune the global model on the collected feedback.

    Rows whose label cannot be mapped to one of the model's classes are
    skipped with a warning.

    Args:
        feedback_data: DataFrame with "image_path" and "correct_label" columns.

    Returns:
        True if training completed and the model was saved, False otherwise.
    """
    try:
        logging.info("Retraining the model with feedback data...")

        # Load images and labels into a Hugging Face dataset
        def load_image(file_path):
            # Close the file handle once the pixels have been copied.
            with Image.open(file_path) as img:
                return img.convert("RGB")

        images = []
        label_ids = []
        rows = zip(feedback_data["image_path"], feedback_data["correct_label"])
        for file_path, raw_label in rows:
            label_id = _label_to_id(raw_label)
            if label_id is None:
                logging.warning(
                    "Skipping %s: label %r is not a known class", file_path, raw_label
                )
                continue
            images.append(load_image(file_path))
            label_ids.append(label_id)

        # A train/test split needs at least one example on each side.
        if len(images) < 2:
            logging.warning("Not enough usable feedback to retrain the model.")
            return False

        dataset = Dataset.from_dict({"image": images, "label": label_ids})
        dataset = dataset.train_test_split(test_size=0.1)

        # Preprocess the dataset
        def preprocess(examples):
            inputs = feature_extractor(images=examples["image"], return_tensors="pt")
            inputs["labels"] = [int(lbl) for lbl in examples["label"]]
            return inputs

        dataset = dataset.with_transform(preprocess)

        # Set up the training arguments
        # NOTE(review): newer transformers releases rename
        # `evaluation_strategy` to `eval_strategy` - confirm against the
        # installed version.
        training_args = TrainingArguments(
            output_dir="./results",
            evaluation_strategy="epoch",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            num_train_epochs=3,
            save_strategy="epoch",
            save_total_limit=2,
            remove_unused_columns=False,
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
        )

        # Train the model
        trainer.train()

        # Save the model
        model.save_pretrained("./fine_tuned_model")
        feature_extractor.save_pretrained("./fine_tuned_model")
        logging.info("Model retrained and saved successfully.")
        return True
    except Exception as e:
        logging.exception("Error during model retraining: %s", e)
        return False
    finally:
        # Training switches the model to train mode; restore eval mode so
        # later predictions run without dropout.
        model.eval()


# Create the Gradio interfaces
predict_interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=gr.JSON(),
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
    live=True,
)

feedback_interface = gr.Interface(
    fn=save_feedback,
    inputs=[gr.Sketchpad(), gr.Textbox(label="Correct Label")],
    outputs="text",
    title="Save Feedback",
    description="Draw something and provide the correct label to improve the model.",
)

# Launch the interfaces together
gr.TabbedInterface(
    [predict_interface, feedback_interface],
    ["Predict", "Provide Feedback"],
).launch(share=True)