File size: 5,836 Bytes
43bc4d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from transformers import AutoModel, AutoTokenizer
import torch
import json
import requests
from PIL import Image
from torchvision import transforms
import urllib.request
from torchvision import models
import torch.nn as nn
def _io_field(name, kind, description):
    """Build one field descriptor (name / type / description) for ``schema``."""
    return {"name": name, "type": kind, "description": description}


# Describes the input and output fields of this module's interface.
schema = {
    "inputs": [
        _io_field("image", "image", "The image file to classify."),
        _io_field("title", "string", "The text title associated with the image."),
    ],
    "outputs": [
        _io_field("label", "string", "Predicted class label."),
        _io_field("probability", "float", "Prediction confidence score."),
    ],
}
# --- Define the Model ---
class FineGrainedClassifier(nn.Module):
    """Image + title classifier that fuses a ResNet-50 and a text encoder.

    The image branch is a torchvision ResNet-50 whose final ``fc`` layer is
    replaced by an identity, the text branch is the
    ``jinaai/jina-embeddings-v2-base-en`` model, and both feature vectors are
    concatenated and passed through a small MLP head.
    """

    def __init__(self, num_classes=434):
        """Build both encoders and the classification head.

        Args:
            num_classes: Width of the final linear layer (number of labels).
        """
        super().__init__()

        # Image branch: drop the ImageNet head so the backbone emits features.
        backbone = models.resnet50(pretrained=True)
        backbone.fc = nn.Identity()
        self.image_encoder = backbone

        # Text branch.
        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')

        # The head expects 2048 image features followed by 768 text features.
        fused_dim = 2048 + 768
        head_layers = [
            nn.Linear(fused_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, image, input_ids, attention_mask):
        """Return class logits for a batch of (image, tokenized title) pairs.

        Args:
            image: Batched image tensor fed to the ResNet backbone.
            input_ids: Token ids for the title, passed to the text encoder.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            The output of the classification head for the fused features.
        """
        visual = self.image_encoder(image)
        encoded = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's hidden state as the title representation.
        textual = encoded.last_hidden_state[:, 0, :]
        fused = torch.cat((visual, textual), dim=1)
        return self.classifier(fused)
# --- Inference Preprocessing ---
# Deterministic pipeline: resize, convert to a tensor, then normalize with the
# per-channel mean/std below. Random flip / rotation / colour jitter are
# training-time augmentations; applying them while serving predictions would
# make the same image produce different outputs from one call to the next.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# # Load the label-to-class mapping from your Hugging Face repository
# label_map_url = "https://huggingface.co./Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
# label_to_class = requests.get(label_map_url).json()
# Load the label-to-class mapping. Its size fixes the classifier's output
# width, so it must be defined before the model is constructed. It is read
# from the local "label_to_class.json", the same file `inference` opens.
with open("label_to_class.json", "r", encoding="utf-8") as label_file:
    label_to_class = json.load(label_file)

# Load your custom model from Hugging Face
model = FineGrainedClassifier(num_classes=len(label_to_class))
# NOTE(review): the host is written "huggingface.co." (with a trailing dot);
# kept as-is here -- confirm this is intended.
checkpoint_url = "https://huggingface.co./Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))

# Clean up the state dictionary: the weights may be nested under
# 'model_state_dict', and keys may carry a "module." prefix that has to be
# stripped before they match this model's parameter names.
state_dict = checkpoint.get('model_state_dict', checkpoint)
module_prefix = "module."
# Snapshot the expected keys once instead of rebuilding the model's
# state_dict for every checkpoint entry.
expected_keys = set(model.state_dict())
new_state_dict = {}
for k, v in state_dict.items():
    new_key = k[len(module_prefix):] if k.startswith(module_prefix) else k
    # Only keep entries the model actually has a slot for.
    if new_key in expected_keys:
        new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)

# Load the tokenizer from Jina
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
# def load_image(image_path_or_url):
# if isinstance(image_path_or_url, str) and image_path_or_url.startswith("http"):
# with urllib.request.urlopen(image_path_or_url) as url:
# image = Image.open(url).convert('RGB')
# else:
# image = Image.open(image_path_or_url).convert('RGB')
# image = transform(image)
# image = image.unsqueeze(0) # Add batch dimension
# return image
# def predict(image_path_or_file, title, threshold=0.4):
# Parsed label files keyed by path, so the JSON is read once per process.
_LABEL_MAP_CACHE = {}


def _load_label_map(path="label_to_class.json"):
    """Return the index -> class-name mapping stored at *path* (cached)."""
    if path not in _LABEL_MAP_CACHE:
        with open(path, "r", encoding="utf-8") as label_file:
            _LABEL_MAP_CACHE[path] = json.load(label_file)
    return _LABEL_MAP_CACHE[path]


def _prepare_image(image):
    """Convert a PIL image or tensor into a batched tensor for the model."""
    if isinstance(image, torch.Tensor):
        # NOTE(review): tensors are assumed to be already resized and
        # normalized, shaped (C, H, W) or (N, C, H, W) -- confirm with callers.
        return image.unsqueeze(0) if image.dim() == 3 else image
    # PIL path: force three channels, preprocess, then add the batch dimension.
    return transform(image.convert('RGB')).unsqueeze(0)


def inference(inputs, threshold=0.4):
    """Classify an image/title pair and return the top-3 classes.

    Args:
        inputs: Mapping with an ``"image"`` entry (PIL image or torch tensor)
            and a ``"title"`` entry (string of at least three words).
        threshold: If the best class probability is below this value, an
            ``"Others"`` entry is placed ahead of the top-3 classes.

    Returns:
        On success, a dict mapping class names to probabilities, in ranking
        order. On invalid input, ``{"error": <message>}``.
    """
    image = inputs.get("image")
    title = inputs.get("title")

    # Validate the title first; every validation failure is reported the same
    # way, as an error dict.
    if not isinstance(title, str):
        return {"error": "Title must be a string."}
    if len(title.split()) < 3:
        return {"error": "Title must be at least 3 words long. Please provide a valid title."}
    if not isinstance(image, (Image.Image, torch.Tensor)):
        return {"error": "Image must be a valid image file or a tensor."}

    # Preprocess the image
    image_batch = _prepare_image(image)

    # Tokenize title
    title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
    input_ids = title_encoding['input_ids']
    attention_mask = title_encoding['attention_mask']

    # Predict
    model.eval()
    with torch.no_grad():
        output = model(image_batch, input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        top3_probabilities, top3_indices = torch.topk(probabilities, 3, dim=1)

    # Only the first item of the batch is reported.
    scores = top3_probabilities[0].tolist()
    label_to_class = _load_label_map()
    # The JSON mapping is keyed by the stringified class index.
    class_names = [label_to_class[str(idx)] for idx in top3_indices[0].tolist()]

    # Low-confidence prediction: lead with "Others", scored as the complement
    # of the best class probability.
    if scores[0] < threshold:
        class_names.insert(0, "Others")
        scores.insert(0, 1.0 - scores[0])

    return dict(zip(class_names, scores))