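# Gradio drawing classifier with a simple human-feedback loop: a user sketches
# an image, a pre-trained ViT model predicts the top-3 classes, and corrections
# are saved to CSV and periodically used to fine-tune the model with the
# Hugging Face Trainer.
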
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, Trainer, TrainingArguments
from PIL import Image
import numpy as np
import pandas as pd
import os
import logging
from datasets import Dataset

# Configure logging
logging.basicConfig(level=logging.DEBUG)

# Load the pre-trained model and feature extractor
model_name = "google/vit-base-patch16-224"
logging.info("Loading image processor and model...")
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Load or initialize the feedback data
feedback_data_path = "feedback_data.csv"
if os.path.exists(feedback_data_path):
    feedback_data = pd.read_csv(feedback_data_path)
else:
    feedback_data = pd.DataFrame(columns=["image_path", "correct_label"])

# Directory to save images
os.makedirs("images", exist_ok=True)

# Define the prediction function
def predict(image):
    try:
        logging.info("Received image of type: %s", type(image))
        logging.debug("Image content: %s", image)

        # Recent Gradio sketchpads return a dict; use the flattened "composite" layer.
        if isinstance(image, dict):
            image = image["composite"]

        # Convert to a NumPy array and then to a PIL image; let PIL infer the
        # mode (the canvas is usually RGBA) before converting to RGB.
        image = np.array(image).astype('uint8')
        image = Image.fromarray(image).convert('RGB')

        logging.info("Processing image...")
        inputs = feature_extractor(images=image, return_tensors="pt")
        # Inference only: no gradients are needed for a forward pass.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        top_probs, top_idxs = probs.topk(3, dim=-1)
        top_probs = top_probs.numpy()[0]
        top_idxs = top_idxs.numpy()[0]
        top_classes = [model.config.id2label[idx] for idx in top_idxs]
        result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
        logging.info("Prediction successful.")
        return result
    except Exception as e:
        logging.error("Error during prediction: %s", e)
        return {"error": str(e)}

# Save feedback and retrain if necessary
def save_feedback(image, correct_label):
    global feedback_data
    try:
        image_np = np.array(image['composite']).astype('uint8')
        image_pil = Image.fromarray(image_np).convert('RGB')
        image_path = f"images/{len(feedback_data)}.png"
        image_pil.save(image_path)

        # Append the feedback row (DataFrame.append was removed in pandas 2.0).
        new_row = pd.DataFrame([{"image_path": image_path, "correct_label": correct_label}])
        feedback_data = pd.concat([feedback_data, new_row], ignore_index=True)
        feedback_data.to_csv(feedback_data_path, index=False)

        # Retrain after every 5 pieces of feedback
        retrained = len(feedback_data) % 5 == 0
        if retrained:
            retrain_model(feedback_data)

        return "Feedback saved and model retrained!" if retrained else "Feedback saved!"
    except Exception as e:
        logging.error("Error saving feedback: %s", e)
        return {"error": str(e)}

# Retrain the model with the feedback data
def retrain_model(feedback_data):
    try:
        logging.info("Retraining the model with feedback data...")
        
        # Load images and labels into a Hugging Face dataset
        def load_image(file_path):
            return Image.open(file_path).convert("RGB")

        dataset_dict = {
            "image": [load_image(f) for f in feedback_data["image_path"]],
            "label": [str(lbl) for lbl in feedback_data["correct_label"].tolist()]
        }

        dataset = Dataset.from_dict(dataset_dict)
        dataset = dataset.train_test_split(test_size=0.1)

        # Preprocess the dataset
        def preprocess(examples):
            inputs = feature_extractor(images=examples["image"], return_tensors="pt")
            # Labels are assumed to be numeric class indices consistent with
            # model.config.id2label; free-text labels would need mapping first.
            inputs["labels"] = [int(lbl) for lbl in examples["label"]]
            return inputs

        dataset = dataset.with_transform(preprocess)

        # Set up the training arguments
        training_args = TrainingArguments(
            output_dir="./results",
            evaluation_strategy="epoch",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            num_train_epochs=3,
            save_strategy="epoch",
            save_total_limit=2,
            remove_unused_columns=False,
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
        )

        # Train the model
        trainer.train()

        # Save the model
        model.save_pretrained("./fine_tuned_model")
        feature_extractor.save_pretrained("./fine_tuned_model")
        logging.info("Model retrained and saved successfully.")
    except Exception as e:
        logging.error("Error during model retraining: %s", e)

# Create the Gradio interfaces
predict_interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=gr.JSON(),
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
    live=True
)

feedback_interface = gr.Interface(
    fn=save_feedback,
    inputs=[gr.Sketchpad(), gr.Textbox(label="Correct Label")],
    outputs="text",
    title="Save Feedback",
    description="Draw something and provide the correct label to improve the model."
)

# Launch the interfaces together
gr.TabbedInterface(
    [predict_interface, feedback_interface],
    ["Predict", "Provide Feedback"]
).launch(share=True)