import json

from PIL import Image
import gradio as gr
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
import spaces
import huggingface_hub
import timm
from timm.models import VisionTransformer
import safetensors.torch

# HACK: turn TorchScript compilation into a no-op so any function decorated
# with torch.jit.script (presumably inside a dependency) runs eagerly.
torch.jit.script = lambda f: f

# Inference-only app. Grad mode is thread-local, so this covers the importing
# thread only; create_tags() additionally uses torch.no_grad() for the thread
# that actually serves the request.
torch.set_grad_enabled(False)


class Fit(torch.nn.Module):
    """Resize a PIL image to fit inside ``bounds`` while keeping its aspect ratio.

    Args:
        bounds: Target ``(height, width)``, or a single int for a square box.
        interpolation: Resampling mode passed to ``TF.resize``.
        grow: If False, an image already inside ``bounds`` is never upscaled.
        pad: If not None, the result is padded (centred) to exactly ``bounds``
            using this value as the fill passed to ``TF.pad``.
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None,
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        wimg, himg = img.size  # PIL reports (width, height)
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)

        scale = min(hscale, wscale)
        if scale == 1.0:
            # Size is already acceptable: skip resampling, but still fall
            # through to padding when it was requested.
            hnew, wnew = himg, wimg
        else:
            # min() guards against rounding pushing a side past its bound.
            hnew = min(round(himg * scale), hbound)
            wnew = min(round(wimg * scale), wbound)
            img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        hpad = hbound - hnew
        wpad = wbound - wnew
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + f"bounds={self.bounds}, "
            + f"interpolation={self.interpolation.value}, "
            + f"grow={self.grow}, "
            + f"pad={self.pad})"
        )


class CompositeAlpha(torch.nn.Module):
    """Flatten an RGBA image tensor onto a solid background, returning RGB.

    A tensor whose channel dimension (``shape[-3]``) is already 3 is returned
    unchanged. Otherwise channel 3 is used as alpha and channels 0-2 are
    composited over ``background``.

    Args:
        background: Per-channel background colour, or one number used for all
            three channels, on the same scale as the input (0.5 is mid-grey
            for the [0, 1] output of ``ToTensor``).
    """

    def __init__(
        self,
        background: tuple[float, float, float] | float,
    ):
        super().__init__()
        if isinstance(background, (int, float)):
            background = (background, background, background)
        # Shape (C, 1, 1) so it broadcasts over (..., C, H, W).
        self.background = torch.tensor(background, dtype=torch.float32).view(-1, 1, 1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img

        alpha = img[..., 3:4, :, :]
        background = self.background.to(device=img.device, dtype=img.dtype)
        # "Over" compositing; builds a new tensor instead of mutating the input.
        return img[..., :3, :, :] * alpha + (1.0 - alpha) * background

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + f"background={self.background})"
        )


transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])

model_file = huggingface_hub.hf_hub_download(
    repo_id="RedRocket/JointTaggerProject",
    filename="JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors",
    subfolder="JTP_PILOT",
)
model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
)  # type: VisionTransformer
safetensors.torch.load_model(model, model_file)
model.eval()
# NOTE(review): neither the model nor the input tensor is moved to CUDA, so
# despite @spaces.GPU inference appears to run on CPU — confirm intent.

tags_file = huggingface_hub.hf_hub_download(
    repo_id="RedRocket/JointTaggerProject",
    filename="tags.json",
    subfolder="JTP_PILOT",
)
with open(tags_file, "r", encoding="utf-8") as file:
    tags = json.load(file)  # type: dict
# Output index i maps to the i-th key of tags.json (assumes the JSON key order
# matches the classifier head — TODO confirm). A list is required because
# dict_keys views are not subscriptable.
allowed_tags = list(tags.keys())


@spaces.GPU(duration=5)
def create_tags(image, threshold):
    """Tag ``image`` with every label whose probability exceeds ``threshold``.

    Args:
        image: PIL image from the Gradio input, or None if nothing was provided.
        threshold: Exclusive lower bound on the sigmoid probability of a tag.

    Returns:
        ``(tag_string, tag_scores)``: the matching tags joined with ", " (in
        model output order), and a ``{tag: probability}`` dict for ``gr.Label``.
    """
    if image is None:
        return "", {}

    # RGBA (not RGB) so CompositeAlpha in `transform` can flatten transparency
    # onto its grey background; opaque images get alpha == 1 and are unaffected.
    img = image.convert('RGBA')
    tensor = transform(img).unsqueeze(0)

    # Grad mode is thread-local; the module-level setting may not apply here.
    with torch.no_grad():
        logits = model(tensor)
    probabilities = torch.sigmoid(logits[0])

    indices = torch.where(probabilities > threshold)[0]
    tag_score = {
        allowed_tags[idx]: prob
        for idx, prob in zip(indices.tolist(), probabilities[indices].tolist())
    }
    text_no_impl = ", ".join(tag_score)
    return text_no_impl, tag_score


with gr.Blocks() as demo:
    with gr.Tab("Single Image"):
        gr.Interface(
            create_tags,
            inputs=[
                gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
                gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold"),
            ],
            outputs=[
                gr.Textbox(label="Tag String"),
                gr.Label(label="Tag Predictions", num_top_classes=200),
            ],
            allow_flagging="never",
        )

if __name__ == "__main__":
    demo.launch()