File size: 5,836 Bytes
43bc4d3 |
|
from transformers import AutoModel, AutoTokenizer
import torch
import json
import requests
from PIL import Image
from torchvision import transforms
import urllib.request
from torchvision import models
import torch.nn as nn
# Input/output schema for this classifier. The input names ("image", "title")
# are the same keys that ``inference`` reads from its ``inputs`` mapping.
schema = {
    "inputs": [
        {"name": "image", "type": "image", "description": "The image file to classify."},
        {"name": "title", "type": "string", "description": "The text title associated with the image."},
    ],
    "outputs": [
        {"name": "label", "type": "string", "description": "Predicted class label."},
        {"name": "probability", "type": "float", "description": "Prediction confidence score."},
    ],
}
# --- Define the Model ---
class FineGrainedClassifier(nn.Module):
    """Classify a product from an image plus its text title.

    A ResNet-50 backbone (final ``fc`` replaced by an identity) produces the
    image features. The first-token hidden state of the Jina text encoder
    produces the title features. The two are concatenated (the head expects
    2048 + 768 inputs) and fed through an MLP head with BatchNorm, ReLU and
    dropout (p=0.3).

    Args:
        num_classes: Number of output logits (defaults to 434).
    """

    def __init__(self, num_classes=434):
        super().__init__()

        backbone = models.resnet50(pretrained=True)
        # Remove the classification layer so the backbone returns pooled features.
        backbone.fc = nn.Identity()
        self.image_encoder = backbone

        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')

        # Build Linear -> BatchNorm -> ReLU -> Dropout twice, then the output
        # Linear. The Sequential indices (0..8) match the checkpoint's keys.
        head = []
        for in_dim, out_dim in ((2048 + 768, 1024), (1024, 512)):
            head.extend([
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
        head.append(nn.Linear(512, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, image, input_ids, attention_mask):
        """Return class logits for a batch of images and tokenized titles.

        Args:
            image: Image batch passed straight to the ResNet-50 encoder.
            input_ids: Token ids for the text encoder.
            attention_mask: Attention mask for the text encoder.
        """
        visual = self.image_encoder(image)
        encoded = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # The title embedding is the final hidden state of the first token.
        textual = encoded.last_hidden_state[:, 0, :]
        return self.classifier(torch.cat((visual, textual), dim=1))
# --- Data Augmentation Setup ---
# Per-channel normalization statistics (the standard ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training-style pipeline. RandomHorizontalFlip, RandomRotation and
# ColorJitter are random, so the same image yields a different tensor on
# each call.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# --- Load the label mapping and the fine-tuned model weights ---
label_map_path = "label_to_class.json"
label_map_url = "https://huggingface.co./Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
checkpoint_url = "https://huggingface.co./Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"


def _load_label_to_class(path=label_map_path, url=label_map_url):
    """Return the class-index -> class-name mapping (JSON object, string keys).

    The local JSON file is read first (the same file ``inference`` reads).
    If it does not exist, the mapping is downloaded from the Hugging Face repo.

    Raises:
        requests.HTTPError: if the fallback download returns an error status.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.json()


def _clean_state_dict(raw_state_dict, model_keys):
    """Strip a leading ``module.`` (DataParallel) prefix from each key.

    Keep only keys the model knows about; for duplicate keys, later entries win.
    """
    prefix = "module."
    cleaned = {}
    for key, value in raw_state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        if key in model_keys:
            cleaned[key] = value
    return cleaned


# Previously this name was used here while its definition was commented out,
# which raised NameError at import time.
label_to_class = _load_label_to_class()

model = FineGrainedClassifier(num_classes=len(label_to_class))
# NOTE(review): torch.hub.load_state_dict_from_url unpickles the download via
# torch.load; only point this at a trusted repository.
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
# The checkpoint may be a full training dict or a bare state_dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)
# Build the model's key set once instead of calling model.state_dict() per key.
new_state_dict = _clean_state_dict(state_dict, set(model.state_dict()))
model.load_state_dict(new_state_dict)
# The module is only used for inference; freeze BatchNorm/Dropout behaviour.
model.eval()
# Tokenizer for the same Jina checkpoint that FineGrainedClassifier loads
# as its text encoder.
_TEXT_ENCODER_NAME = "jinaai/jina-embeddings-v2-base-en"
tokenizer = AutoTokenizer.from_pretrained(_TEXT_ENCODER_NAME)
# Deterministic preprocessing for inference. The module-level ``transform``
# includes random flip/rotation/colour-jitter steps, which would make the
# same request return different predictions. Only the resize, to-tensor and
# normalize steps are applied here.
_INFERENCE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_image(image_path_or_url):
    """Return a batched image tensor ready for ``model``.

    Args:
        image_path_or_url: A ``PIL.Image.Image``, an ``http...`` URL, a local
            path / file object, or a ``torch.Tensor``.

    Returns:
        A tensor with a leading batch dimension.
    """
    if isinstance(image_path_or_url, torch.Tensor):
        # NOTE(review): assumes tensors are already resized and normalized,
        # shaped (C, H, W) or (N, C, H, W) — confirm with callers.
        if image_path_or_url.dim() == 3:
            return image_path_or_url.unsqueeze(0)
        return image_path_or_url
    if isinstance(image_path_or_url, Image.Image):
        image = image_path_or_url.convert('RGB')
    elif isinstance(image_path_or_url, str) and image_path_or_url.startswith("http"):
        with urllib.request.urlopen(image_path_or_url, timeout=30) as response:
            image = Image.open(response).convert('RGB')
    else:
        image = Image.open(image_path_or_url).convert('RGB')
    return _INFERENCE_TRANSFORM(image).unsqueeze(0)  # add batch dimension


def inference(inputs, threshold=0.4):
    """Classify an image/title pair and return the top predictions.

    Args:
        inputs: Mapping with ``"image"`` (``PIL.Image.Image`` or
            ``torch.Tensor``) and ``"title"`` (a string of at least 3 words).
        threshold: If the best class probability is below this value, an
            ``"Others"`` entry with probability ``1 - best`` is put first.

    Returns:
        A dict mapping class name -> probability for the top (up to) 3
        classes, optionally preceded by ``"Others"``. If validation fails,
        returns ``{"error": message}`` instead.
        NOTE(review): this shape differs from the ``label``/``probability``
        outputs declared in ``schema`` — confirm what consumers expect.
    """
    image = inputs.get("image")
    title = inputs.get("title")
    if not isinstance(title, str):
        return {"error": "Title must be a string."}
    if not isinstance(image, (Image.Image, torch.Tensor)):
        return {"error": "Image must be a valid image file or a tensor."}
    # The original raised ``gr.Error`` here, but gradio is never imported, so
    # short titles crashed with NameError. Report this like the checks above.
    if len(title.split()) < 3:
        return {"error": "Title must be at least 3 words long. Please provide a valid title."}

    image_tensor = load_image(image)
    title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')

    model.eval()
    with torch.no_grad():
        logits = model(
            image_tensor,
            input_ids=title_encoding['input_ids'],
            attention_mask=title_encoding['attention_mask'],
        )
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    # Guard against models with fewer than 3 classes.
    k = min(3, probabilities.size(1))
    top_probs, top_indices = torch.topk(probabilities, k, dim=1)

    with open("label_to_class.json", "r", encoding="utf-8") as f:
        class_names = json.load(f)

    # JSON object keys are strings, hence str(index).
    labels = [class_names[str(index)] for index in top_indices[0].tolist()]
    scores = top_probs[0].tolist()
    if scores[0] < threshold:
        labels.insert(0, "Others")
        scores.insert(0, 1.0 - scores[0])
    return dict(zip(labels, scores))