File size: 5,490 Bytes
85b975e a6fcafd ffb04af 73d50df 77e4539 ffb04af 7d27275 ffb04af 7d27275 73d50df 0c2a5d3 ffb04af 7d27275 ffb04af 73d50df 7d27275 70548b8 f97cb5f ffb04af 70548b8 95cdd7b 7d27275 a6fcafd ffb04af 528548d c1c81f5 7d27275 528548d 7d27275 c1c81f5 73d50df ffb04af 85b975e ffb04af 85b975e ffb04af 73d50df c7f0278 528548d 73d50df ffb04af 73d50df ffb04af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, Trainer, TrainingArguments
# Configure logging
logging.basicConfig(level=logging.DEBUG)

# Load the pre-trained model and feature extractor
model_name = "google/vit-base-patch16-224"
logging.info("Loading image processor and model...")
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Load or initialize the feedback data (one row per saved sketch:
# the PNG path and the label text the user typed).
feedback_data_path = "feedback_data.csv"
if os.path.exists(feedback_data_path):
    # Read every column as text and disable NA parsing, so labels such as
    # "NA", "None" or "null" (and numeric ids) round-trip unchanged instead
    # of being turned into NaN or ints.
    feedback_data = pd.read_csv(feedback_data_path, dtype=str, keep_default_na=False)
else:
    feedback_data = pd.DataFrame(columns=["image_path", "correct_label"])

# Directory where save_feedback writes the submitted sketches
os.makedirs("images", exist_ok=True)
# Define the prediction function
def predict(image):
    """Classify a sketch and return the model's top predictions.

    Args:
        image: Value produced by ``gr.Sketchpad``. It may be a dict whose
            ``"composite"`` entry holds the flattened drawing, or the image
            array itself. ``None`` means there is nothing to classify.

    Returns:
        A dict mapping label name -> probability (float) for up to the top 3
        classes, an empty dict when there is no image, or
        ``{"error": message}`` if anything goes wrong.
    """
    try:
        logging.info("Received image of type: %s", type(image))
        # Newer Sketchpad values are dicts; the drawing lives under "composite"
        # (the same key save_feedback reads).
        if isinstance(image, dict):
            image = image.get("composite")
        if image is None:
            return {}

        array = np.asarray(image).astype("uint8")
        logging.debug("Image array shape: %s", array.shape)
        # Let Pillow infer the mode from the array shape (L / RGB / RGBA)
        # rather than forcing 'RGBA', which misreads non-4-channel input.
        pil_image = Image.fromarray(array).convert("RGB")

        logging.info("Processing image...")
        inputs = feature_extractor(images=pil_image, return_tensors="pt")
        # retrain_model's Trainer may move the model to an accelerator;
        # keep the inputs on whatever device the model is on.
        inputs = inputs.to(model.device)
        model.eval()
        with torch.no_grad():
            logits = model(**inputs).logits

        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Never ask topk for more classes than the model has.
        k = min(3, probs.shape[-1])
        top_probs, top_idxs = probs.topk(k, dim=-1)
        top_probs = top_probs[0].cpu().tolist()
        top_idxs = top_idxs[0].cpu().tolist()

        result = {
            model.config.id2label[idx]: float(prob)
            for idx, prob in zip(top_idxs, top_probs)
        }
        logging.info("Prediction successful.")
        return result
    except Exception as e:
        logging.exception("Error during prediction: %s", e)
        return {"error": str(e)}
# Save feedback and retrain if necessary
def save_feedback(image, correct_label):
    """Persist a user-labelled sketch and retrain every 5 collected samples.

    The drawing is saved as ``images/<row>.png`` and a row is appended to the
    global ``feedback_data`` frame, which is rewritten to ``feedback_data_path``.

    Args:
        image: Value produced by ``gr.Sketchpad``. It may be a dict whose
            ``"composite"`` entry holds the drawing, or the image array itself.
        correct_label: Label text typed by the user; surrounding whitespace
            is stripped.

    Returns:
        A status message string for the Gradio ``"text"`` output.
    """
    global feedback_data
    try:
        label = "" if correct_label is None else str(correct_label).strip()
        if not label:
            return "Please provide the correct label."
        drawing = image.get("composite") if isinstance(image, dict) else image
        if drawing is None:
            return "Please draw something before submitting feedback."

        image_np = np.asarray(drawing).astype("uint8")
        # Let Pillow infer the mode from the array shape instead of forcing 'RGBA'.
        image_pil = Image.fromarray(image_np).convert("RGB")
        image_path = f"images/{len(feedback_data)}.png"
        image_pil.save(image_path)

        # DataFrame.append was removed in pandas 2.0; concatenate a one-row frame.
        new_row = pd.DataFrame([{"image_path": image_path, "correct_label": label}])
        if feedback_data.empty:
            feedback_data = new_row
        else:
            feedback_data = pd.concat([feedback_data, new_row], ignore_index=True)
        feedback_data.to_csv(feedback_data_path, index=False)

        # Retrain each time the number of collected samples hits a multiple of 5.
        if len(feedback_data) % 5 != 0:
            return "Feedback saved!"
        # retrain_model logs its own errors; an explicit False signals failure.
        if retrain_model(feedback_data) is False:
            return "Feedback saved, but retraining failed (see logs)."
        return "Feedback saved and model retrained!"
    except Exception as e:
        logging.exception("Error saving feedback: %s", e)
        return f"Error saving feedback: {e}"
# Retrain the model with the feedback data
def _build_label_lookup(id2label):
    """Return a case-insensitive map from label names to class ids.

    Both the full label name and each comma-separated part of it become keys,
    so a label name that lists several synonyms can be matched by any of them.
    When two labels share a key, the first one encountered wins.
    """
    lookup = {}
    for idx, name in id2label.items():
        lookup.setdefault(str(name).strip().casefold(), int(idx))
        for part in str(name).split(","):
            part = part.strip().casefold()
            if part:
                lookup.setdefault(part, int(idx))
    return lookup


def _resolve_label_id(label, lookup, num_labels):
    """Map one user-supplied label to a class id, or return None if unusable.

    A numeric label is treated as a class index and must lie in
    ``[0, num_labels)``. Any other text is looked up case-insensitively in
    ``lookup`` (as built by ``_build_label_lookup``).
    """
    text = str(label).strip()
    if not text:
        return None
    try:
        idx = int(text)
    except ValueError:
        return lookup.get(text.casefold())
    return idx if 0 <= idx < num_labels else None


def _make_training_args():
    """Build the TrainingArguments used for feedback fine-tuning."""
    # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
    # releases; pass whichever keyword the installed version defines.
    fields = getattr(TrainingArguments, "__dataclass_fields__", {})
    strategy_key = "eval_strategy" if "eval_strategy" in fields else "evaluation_strategy"
    return TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        save_strategy="epoch",
        save_total_limit=2,
        remove_unused_columns=False,
        **{strategy_key: "epoch"},
    )


def retrain_model(feedback_data):
    """Fine-tune the global ``model`` on the collected feedback samples.

    Each row's ``correct_label`` is resolved to a class id of the current
    model, either as a numeric index or by (case-insensitive) label name.
    Rows whose label cannot be resolved, or whose image cannot be loaded, are
    skipped with a warning. On success, the model and feature extractor are
    saved to ``./fine_tuned_model``.

    Args:
        feedback_data: DataFrame with ``image_path`` and ``correct_label`` columns.

    Returns:
        True if training ran and the model was saved, False otherwise. Errors
        are logged rather than raised, so a failed retrain never breaks the UI.
    """
    try:
        logging.info("Retraining the model with feedback data...")
        id2label = model.config.id2label
        lookup = _build_label_lookup(id2label)

        images, labels = [], []
        for path, raw_label in zip(feedback_data["image_path"], feedback_data["correct_label"]):
            label_id = _resolve_label_id(raw_label, lookup, len(id2label))
            if label_id is None:
                logging.warning("Skipping feedback %s: unknown label %r", path, raw_label)
                continue
            try:
                with Image.open(path) as img:
                    images.append(img.convert("RGB"))
            except OSError as err:
                logging.warning("Skipping feedback %s: cannot load image (%s)", path, err)
                continue
            labels.append(label_id)

        # train_test_split needs at least one sample on each side of the split.
        if len(images) < 2:
            logging.warning("Not enough usable feedback samples (%d) to retrain.", len(images))
            return False

        dataset = Dataset.from_dict({"image": images, "label": labels})
        dataset = dataset.train_test_split(test_size=0.1)

        def preprocess(examples):
            # Applied lazily per batch by with_transform.
            inputs = feature_extractor(images=examples["image"], return_tensors="pt")
            inputs["labels"] = [int(lbl) for lbl in examples["label"]]
            return inputs

        dataset = dataset.with_transform(preprocess)

        trainer = Trainer(
            model=model,
            args=_make_training_args(),
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
        )
        trainer.train()

        model.save_pretrained("./fine_tuned_model")
        feature_extractor.save_pretrained("./fine_tuned_model")
        logging.info("Model retrained and saved successfully.")
        return True
    except Exception as e:
        logging.exception("Error during model retraining: %s", e)
        return False
# Build the two Gradio front-ends: live classification of a sketch, and
# collection of labelled sketches that feed back into retraining.
predict_interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=gr.JSON(),
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
    live=True,
)

feedback_interface = gr.Interface(
    fn=save_feedback,
    inputs=[
        gr.Sketchpad(),
        gr.Textbox(label="Correct Label"),
    ],
    outputs="text",
    title="Save Feedback",
    description="Draw something and provide the correct label to improve the model.",
)

# Combine both interfaces into one tabbed app and launch it.
demo = gr.TabbedInterface(
    [predict_interface, feedback_interface],
    ["Predict", "Provide Feedback"],
)
demo.launch(share=True)
|