# Spaces: Running
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torch import nn
from transformers import SiglipImageProcessor, SiglipModel

import dbimutils as utils
class ScoreClassifier(nn.Module):
    """MLP head that maps an image embedding to a single score between 0 and 1.

    ``extractor`` reduces the embedding to ``feature_dim`` features through
    three linear layers (the first two followed by batch normalisation), and
    ``classifier`` applies one linear unit followed by a sigmoid.

    Submodule names and layer positions are fixed so that ``state_dict`` keys
    (``extractor.0.weight``, ``classifier.0.weight``, ...) stay stable across
    checkpoints.

    Args:
        embed_dim: Size of the input embedding (last dimension of the input).
        hidden_dim: Width of the first hidden layer.
        feature_dim: Width of the remaining hidden layers and of the
            classifier input.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        hidden_dim: int = 512,
        feature_dim: int = 256,
    ) -> None:
        super().__init__()
        # The classifier is registered before the extractor, matching the
        # original layout, so parameter initialisation order is unchanged.
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 1),
            nn.Sigmoid(),
        )
        self.extractor = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return one score per row of ``img``.

        Args:
            img: Embeddings of shape ``(batch, embed_dim)``. Despite the
                name, this is a feature tensor rather than raw pixels.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid outputs.

        Note:
            The extractor contains ``BatchNorm1d`` layers, so the module must
            be in ``eval()`` mode to score a batch of size one.
        """
        return self.classifier(self.extractor(img))
# Download (or reuse the locally cached copy of) the trained scoring head.
model_file = hf_hub_download(repo_id="Muinez/Image-scorer", filename="scorer.pth")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ScoreClassifier().to(DEVICE)
# The checkpoint is a pickle fetched over the network. weights_only=True
# limits unpickling to tensors and plain containers, which is all a state
# dict needs, so a tampered file cannot execute arbitrary code on load.
# Tensors are read onto the CPU; load_state_dict copies them into the
# parameters already placed on DEVICE.
model.load_state_dict(
    torch.load(model_file, map_location=torch.device('cpu'), weights_only=True)
)
# Inference only: puts the BatchNorm1d layers into evaluation behaviour.
model.eval()

# Both the image processor and the vision/text backbone come from the same
# SigLIP checkpoint, so the id is defined once.
SIGLIP_CHECKPOINT = 'google/siglip-base-patch16-512'
processor = SiglipImageProcessor.from_pretrained(SIGLIP_CHECKPOINT)
siglip = SiglipModel.from_pretrained(SIGLIP_CHECKPOINT).to(DEVICE)
def predict(img):
    """Score one image with the SigLIP image embedding and the trained head.

    Args:
        img: Image handed over by the Gradio image input, or ``None`` when
            the form was submitted without an image.

    Returns:
        The score model's output for the image as a Python float.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Fail with a message the UI can display instead of letting the
        # missing input propagate into preprocessing.
        raise gr.Error("Please upload an image before requesting a score.")

    # Project helper from dbimutils; NOTE(review): its exact transform is not
    # visible from this file.
    prepared = utils.preprocess_image(img)
    pixel_values = processor(prepared, return_tensors="pt").pixel_values.to(DEVICE)

    # Gradients are never needed when serving predictions.
    with torch.no_grad():
        features = siglip.get_image_features(pixel_values)
        score = model(features)

    # .item() requires a single-element tensor, i.e. exactly one image.
    return score.item()
# Build the single-input, single-output web UI around predict() and start
# serving it as soon as this module is executed.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Number(label="Score")],
    title="Image scorer",
    description="Predicts score (0-1) for image.\nCould be wrong",
    allow_flagging="never",
)
demo.launch()