|
import json
import urllib.request
from functools import lru_cache

import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models
from torchvision import transforms
from transformers import AutoModel, AutoTokenizer
|
|
|
# Declarative description of the fields this module accepts and produces.
# NOTE(review): `inference` in this file returns a mapping of class name to
# probability rather than separate "label"/"probability" fields -- confirm
# which shape consumers of this schema actually expect.
schema = {
    "inputs": [
        {
            "name": "image",
            "type": "image",
            "description": "The image file to classify.",
        },
        {
            "name": "title",
            "type": "string",
            "description": "The text title associated with the image.",
        },
    ],
    "outputs": [
        {
            "name": "label",
            "type": "string",
            "description": "Predicted class label.",
        },
        {
            "name": "probability",
            "type": "float",
            "description": "Prediction confidence score.",
        },
    ],
}
|
|
|
|
|
|
|
class FineGrainedClassifier(nn.Module):
    """Classify an (image, title) pair by fusing image and text features.

    A ResNet-50 with its final ``fc`` layer replaced by an identity supplies
    the image features, and the first-token hidden state of the text encoder
    supplies the text features. Both are concatenated and fed to an MLP head
    whose first layer expects 2048 + 768 inputs.
    """

    def __init__(self, num_classes=434):
        """Build both encoders and the classification head.

        Args:
            num_classes: Number of logits produced by the final linear layer.
        """
        super().__init__()

        # Replace the stock classification layer with a pass-through so the
        # encoder returns features instead of its own class logits.
        self.image_encoder = models.resnet50(pretrained=True)
        self.image_encoder.fc = nn.Identity()

        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')

        self.classifier = nn.Sequential(
            nn.Linear(2048 + 768, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return unnormalised class logits for a batch.

        Args:
            image: Batch of image tensors passed straight to the image
                encoder (assumes (N, 3, H, W) -- TODO confirm with caller).
            input_ids: Token ids forwarded to the text encoder.
            attention_mask: Attention mask forwarded to the text encoder.

        Returns:
            The output of the classification head; no softmax is applied.
        """
        image_features = self.image_encoder(image)

        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the title embedding.
        text_features = text_output.last_hidden_state[:, 0, :]

        # Image features come first, text features second, along dim 1.
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)
|
|
|
|
|
# Preprocessing applied to PIL images before they are fed to the model.
# Only deterministic steps are kept: random flips, rotations and colour
# jitter are training-time augmentations, and applying them at inference
# makes the prediction for the same image change from call to call.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Per-channel mean/std for a 3-channel image; inputs must therefore be
    # RGB by the time they reach this step.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
|
|
|
|
|
|
|
|
# The index -> class-name mapping fixes the size of the classifier head, so
# it has to be read before the model is constructed.
with open("label_to_class.json", "r", encoding="utf-8") as _label_file:
    label_to_class = json.load(_label_file)

model = FineGrainedClassifier(num_classes=len(label_to_class))

# NOTE(review): this unpickles a checkpoint fetched over the network; only
# keep this if the hosting repository is trusted.
checkpoint_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))

# The checkpoint may either wrap the weights under 'model_state_dict' or be
# the state dict itself.
state_dict = checkpoint.get('model_state_dict', checkpoint)

# Build the set of expected parameter names once instead of regenerating the
# model's state dict for every checkpoint entry.
_model_keys = set(model.state_dict())
# Presumably left by a DataParallel-style wrapper at training time -- the
# prefix is stripped so keys line up with the bare model.
_wrapper_prefix = "module."

new_state_dict = {}
for k, v in state_dict.items():
    new_key = k[len(_wrapper_prefix):] if k.startswith(_wrapper_prefix) else k
    # Entries the current model does not define are dropped.
    if new_key in _model_keys:
        new_state_dict[new_key] = v

model.load_state_dict(new_state_dict)

tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_label_to_class():
    """Read the class-index -> class-name mapping once and reuse it."""
    with open("label_to_class.json", "r", encoding="utf-8") as f:
        return json.load(f)


def _prepare_image(image):
    """Turn a validated PIL image or 3-D/4-D tensor into a batched tensor."""
    if isinstance(image, torch.Tensor):
        # NOTE(review): tensors are passed through without resizing or
        # normalisation -- assumes the caller already preprocessed them.
        return image.unsqueeze(0) if image.dim() == 3 else image
    # Normalisation in `transform` uses three channel statistics, so force
    # RGB before transforming, then add the batch dimension.
    return transform(image.convert("RGB")).unsqueeze(0)


def inference(inputs):
    """Classify an image/title pair and return the top class probabilities.

    Args:
        inputs: Mapping with an "image" entry (a PIL image or a torch tensor)
            and a "title" entry (a string of at least three words).

    Returns:
        ``{"error": message}`` when an input fails validation. Otherwise a
        dict mapping class name to probability for the three most likely
        classes (fewer if the model has fewer classes). When the best
        probability is below 0.4, an extra "Others" entry is placed first
        with value ``1 - best probability``.
    """
    title = inputs.get("title")
    image = inputs.get("image")

    if not isinstance(title, str):
        return {"error": "Title must be a string."}
    if len(title.split()) < 3:
        return {"error": "Title must be at least 3 words long. Please provide a valid title."}

    if not isinstance(image, (Image.Image, torch.Tensor)):
        return {"error": "Image must be a valid image file or a tensor."}
    if isinstance(image, torch.Tensor) and image.dim() not in (3, 4):
        return {"error": "Image tensor must have shape (C, H, W) or (N, C, H, W)."}

    threshold = 0.4

    image_tensor = _prepare_image(image)

    title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
    input_ids = title_encoding['input_ids']
    attention_mask = title_encoding['attention_mask']

    # Evaluation mode so dropout is disabled and batch-norm uses its running
    # statistics; gradients are not needed for prediction.
    model.eval()
    with torch.no_grad():
        output = model(image_tensor, input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        # topk raises if asked for more entries than there are classes.
        k = min(3, probabilities.shape[1])
        top_probabilities, top_indices = torch.topk(probabilities, k, dim=1)

    label_to_class = _load_label_to_class()

    # Only the first item of the batch is reported.
    top_classes = [label_to_class[str(idx)] for idx in top_indices[0].tolist()]
    top_scores = top_probabilities[0].tolist()

    if top_scores[0] < threshold:
        # Low confidence: surface a catch-all bucket ahead of the real classes.
        top_classes.insert(0, "Others")
        top_scores.insert(0, 1.0 - top_scores[0])

    return dict(zip(top_classes, top_scores))