Maverick98 commited on
Commit
43bc4d3
·
verified ·
1 Parent(s): 78e6f94

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +163 -0
model.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ import torch
3
+ import json
4
+ import requests
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ import urllib.request
8
+ from torchvision import models
9
+ import torch.nn as nn
10
+
11
+ schema ={
12
+ "inputs": [
13
+ {
14
+ "name": "image",
15
+ "type": "image",
16
+ "description": "The image file to classify."
17
+ },
18
+ {
19
+ "name": "title",
20
+ "type": "string",
21
+ "description": "The text title associated with the image."
22
+ }
23
+ ],
24
+ "outputs": [
25
+ {
26
+ "name": "label",
27
+ "type": "string",
28
+ "description": "Predicted class label."
29
+ },
30
+ {
31
+ "name": "probability",
32
+ "type": "float",
33
+ "description": "Prediction confidence score."
34
+ }
35
+ ]
36
+ }
37
+
38
+
39
+ # --- Define the Model ---
40
+ class FineGrainedClassifier(nn.Module):
41
+ def __init__(self, num_classes=434): # Updated to 434 classes
42
+ super(FineGrainedClassifier, self).__init__()
43
+ self.image_encoder = models.resnet50(pretrained=True)
44
+ self.image_encoder.fc = nn.Identity()
45
+ self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')
46
+ self.classifier = nn.Sequential(
47
+ nn.Linear(2048 + 768, 1024),
48
+ nn.BatchNorm1d(1024),
49
+ nn.ReLU(),
50
+ nn.Dropout(0.3),
51
+ nn.Linear(1024, 512),
52
+ nn.BatchNorm1d(512),
53
+ nn.ReLU(),
54
+ nn.Dropout(0.3),
55
+ nn.Linear(512, num_classes) # Updated to 434 classes
56
+ )
57
+
58
+ def forward(self, image, input_ids, attention_mask):
59
+ image_features = self.image_encoder(image)
60
+ text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
61
+ text_features = text_output.last_hidden_state[:, 0, :]
62
+ combined_features = torch.cat((image_features, text_features), dim=1)
63
+ output = self.classifier(combined_features)
64
+ return output
65
+
66
+ # --- Data Augmentation Setup ---
67
+ transform = transforms.Compose([
68
+ transforms.Resize((224, 224)),
69
+ transforms.RandomHorizontalFlip(),
70
+ transforms.RandomRotation(15),
71
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
72
+ transforms.ToTensor(),
73
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
74
+ ])
75
+
76
+ # # Load the label-to-class mapping from your Hugging Face repository
77
+ # label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
78
+ # label_to_class = requests.get(label_map_url).json()
79
+
80
+ # Load your custom model from Hugging Face
81
+ model = FineGrainedClassifier(num_classes=len(label_to_class))
82
+ checkpoint_url = f"https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
83
+ checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
84
+
85
+ # Strip the "module." prefix from the keys in the state_dict if they exist
86
+ # Clean up the state dictionary
87
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
88
+ new_state_dict = {}
89
+ for k, v in state_dict.items():
90
+ if k.startswith("module."):
91
+ new_key = k[7:] # Remove "module." prefix
92
+ else:
93
+ new_key = k
94
+
95
+ # Check if the new_key exists in the model's state_dict, only add if it does
96
+ if new_key in model.state_dict():
97
+ new_state_dict[new_key] = v
98
+
99
+ model.load_state_dict(new_state_dict)
100
+
101
+ # Load the tokenizer from Jina
102
+ tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
103
+
104
+ # def load_image(image_path_or_url):
105
+ # if isinstance(image_path_or_url, str) and image_path_or_url.startswith("http"):
106
+ # with urllib.request.urlopen(image_path_or_url) as url:
107
+ # image = Image.open(url).convert('RGB')
108
+ # else:
109
+ # image = Image.open(image_path_or_url).convert('RGB')
110
+
111
+ # image = transform(image)
112
+ # image = image.unsqueeze(0) # Add batch dimension
113
+ # return image
114
+
115
+ # def predict(image_path_or_file, title, threshold=0.4):
116
+
117
+ def inference(inputs):
118
+ image = inputs.get("image")
119
+ title = inputs.get("title")
120
+ if not isinstance(title, str):
121
+ return {"error": "Title must be a string."}
122
+
123
+ if not isinstance(image, (Image.Image, torch.Tensor)):
124
+ return {"error": "Image must be a valid image file or a tensor."}
125
+
126
+ threshold = 0.4
127
+ # Validation: Check if the title is empty or has fewer than 3 words
128
+ if not title or len(title.split()) < 3:
129
+ raise gr.Error("Title must be at least 3 words long. Please provide a valid title.")
130
+
131
+ # Preprocess the image
132
+ image = load_image(image_path_or_file)
133
+
134
+ # Tokenize title
135
+ title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
136
+ input_ids = title_encoding['input_ids']
137
+ attention_mask = title_encoding['attention_mask']
138
+
139
+ # Predict
140
+ model.eval()
141
+ with torch.no_grad():
142
+ output = model(image, input_ids=input_ids, attention_mask=attention_mask)
143
+ probabilities = torch.nn.functional.softmax(output, dim=1)
144
+ top3_probabilities, top3_indices = torch.topk(probabilities, 3, dim=1)
145
+
146
+ # Map indices to class names (Assuming you have a mapping)
147
+ with open("label_to_class.json", "r") as f:
148
+ label_to_class = json.load(f)
149
+
150
+ # Map the top 3 indices to class names
151
+ top3_classes = [label_to_class[str(idx.item())] for idx in top3_indices[0]]
152
+
153
+ # Check if the highest probability is below the threshold
154
+ if top3_probabilities[0][0].item() < threshold:
155
+ top3_classes.insert(0, "Others")
156
+ top3_probabilities = torch.cat((torch.tensor([[1.0 - top3_probabilities[0][0].item()]]), top3_probabilities), dim=1)
157
+
158
+ # Prepare the output as a dictionary
159
+ results = {}
160
+ for i in range(len(top3_classes)):
161
+ results[top3_classes[i]] = top3_probabilities[0][i].item()
162
+
163
+ return results