import torch.hub
from torch import nn
from transformers import (
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    CLIPTextConfig,
    CLIPTextModelWithProjection,
    CLIPVisionConfig,
    CLIPVisionModel,
    CLIPVisionModelWithProjection,
    ResNetConfig,
    ResNetModel,
)


class CLIP(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder.

        Args:
            path (str): checkpoint to load; an empty string builds a
                randomly initialized model from the default config.
        """
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch containing an "img" torch.Tensor.
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features


class CLIPJZ(nn.Module):
    # Identical to CLIP above; kept under a distinct class name.
    def __init__(self, path):
        """Initializes the CLIP vision encoder.

        Args:
            path (str): checkpoint to load; an empty string builds a
                randomly initialized model from the default config.
        """
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch containing an "img" torch.Tensor.
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features


class StreetCLIP(nn.Module):
    def __init__(self, path):
        """Initializes the StreetCLIP model and its processor."""
        super().__init__()
        self.clip = CLIPModel.from_pretrained(path)
        self.transform = CLIPProcessor.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP image features.

        Args:
            x (dict): Input batch containing "img" (images accepted by
                CLIPProcessor) and "gps" (torch.Tensor, used only to
                infer the target device).
        """
        features = self.clip.get_image_features(
            **self.transform(images=x["img"], return_tensors="pt").to(
                x["gps"].device
            )
        ).unsqueeze(1)
        return features


class CLIPText(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder with a projection head."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            # The projection variant is required: forward() reads image_embeds.
            self.clip = CLIPVisionModelWithProjection(config_vision)
        else:
            self.clip = CLIPVisionModelWithProjection.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch containing an "img" torch.Tensor.

        Returns:
            tuple: (image_embeds, last_hidden_state)
        """
        features = self.clip(pixel_values=x["img"])
        return features.image_embeds, features.last_hidden_state


class TextEncoder(nn.Module):
    def __init__(self, path):
        """Initializes the frozen CLIP text encoder."""
        super().__init__()
        if path == "":
            config_text = CLIPTextConfig()
            self.clip = CLIPTextModelWithProjection(config_text)
            # AutoTokenizer is a factory with no zero-argument constructor;
            # a checkpoint is needed for the vocabulary, so the base CLIP
            # tokenizer is used here as an assumed default.
            self.transform = AutoTokenizer.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
        else:
            self.clip = CLIPTextModelWithProjection.from_pretrained(path)
            self.transform = AutoTokenizer.from_pretrained(path)
        # The text tower is frozen and used as a fixed feature extractor.
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

    def forward(self, x):
        """Predicts CLIP features from text.

        Args:
            x (dict): Input batch containing "text" (list of str) and
                "gps" (torch.Tensor, used only to infer the target device).
        """
        features = self.clip(
            **self.transform(x["text"], padding=True, return_tensors="pt").to(
                x["gps"].device
            )
        ).text_embeds
        return features

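
# --- Usage sketch (illustrative) --------------------------------------------
# A minimal sketch of how the wrappers above are driven, assuming the batch
# layout implied by the forward() signatures ("img" pixel values, "text" as a
# list of strings, "gps" used only for device placement). The checkpoint name
# is an illustrative default, not one mandated by this module.
def _example_clip_and_text(checkpoint="openai/clip-vit-base-patch32"):
    """Runs one image batch and one text batch through the wrappers above."""
    vision = CLIP(checkpoint)
    img_batch = {"img": torch.randn(2, 3, 224, 224)}
    tokens = vision(img_batch)  # (2, 50, 768): CLS + 7x7 patch tokens

    text = TextEncoder(checkpoint)
    txt_batch = {"text": ["a street in Paris"], "gps": torch.zeros(1, 2)}
    embeds = text(txt_batch)  # (1, 512): projected text embeddings
    return tokens, embeds
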
class DINOv2(nn.Module):
    def __init__(self, tag) -> None:
        """Initializes the DINOv2 model.

        Args:
            tag (str): torch.hub model tag, e.g. "dinov2_vitb14".
        """
        super().__init__()
        self.dino = torch.hub.load("facebookresearch/dinov2", tag)
        self.stride = 14  # DINOv2 uses a fixed patch size (stride) of 14

    def forward(self, x):
        """Predicts DINOv2 features from an image."""
        x = x["img"]
        # Crop so height and width are multiples of the patch stride.
        _, _, H, W = x.shape
        H_new = H - H % self.stride
        W_new = W - W % self.stride
        x = x[:, :, :H_new, :W_new]
        # Forward pass; keep the pre-norm token sequence.
        x = self.dino.forward_features(x)
        x = x["x_prenorm"]
        return x


class ResNet(nn.Module):
    def __init__(self, path):
        """Initializes the ResNet model.

        Args:
            path (str): checkpoint to load; an empty string builds a
                randomly initialized ResNet-50 from the default config.
        """
        super().__init__()
        if path == "":
            config_vision = ResNetConfig()
            self.resnet = ResNetModel(config_vision)
        else:
            self.resnet = ResNetModel.from_pretrained(path)

    def forward(self, x):
        """Predicts ResNet features from an image.

        Args:
            x (dict): Input batch containing an "img" torch.Tensor.
        """
        features = self.resnet(x["img"])["pooler_output"]
        # Drop only the spatial singleton dims; a bare .squeeze() would
        # also remove the batch dimension when the batch size is 1.
        return features.squeeze(-1).squeeze(-1)
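
# --- Smoke test (illustrative) -----------------------------------------------
# A minimal sketch exercising the config-only code paths, assuming no network
# access: models are randomly initialized, so the printed shapes reflect the
# default CLIPVisionConfig / ResNetConfig rather than a pretrained checkpoint.
if __name__ == "__main__":
    batch = {"img": torch.randn(2, 3, 224, 224)}

    clip = CLIP("")  # random-weight CLIP vision tower
    print(clip(batch).shape)  # torch.Size([2, 50, 768]) with default config

    resnet = ResNet("")  # random-weight ResNet-50
    print(resnet(batch).shape)  # torch.Size([2, 2048]) after pooling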