import numpy as np
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

from .abstract_embedder import AbstractImageEmbedder


class DinoEmbedder(AbstractImageEmbedder):
    """Image embedder backed by a pretrained ViT checkpoint (DINO by default).

    The feature extractor and the model are both loaded from ``model_name``
    via ``from_pretrained``; the model is moved to ``self.device``.
    """

    def __init__(self, device: str = "cpu", model_name: str = "facebook/dino-vitb8"):
        """Load the preprocessing pipeline and the model weights.

        Args:
            device: Device identifier forwarded to the base class; the model
                is moved to ``self.device`` afterwards (presumably set by
                ``AbstractImageEmbedder.__init__`` -- confirm against the
                base class).
            model_name: Checkpoint name or path understood by
                ``from_pretrained``.
        """
        super().__init__(device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.model = ViTModel.from_pretrained(model_name).to(self.device)

    def embed(self, image: Image.Image) -> np.ndarray:
        """Return the model's ``last_hidden_state`` for ``image`` as a NumPy array.

        Args:
            image: Input image handed to the feature extractor.

        Returns:
            The ``last_hidden_state`` tensor of the forward pass, copied to
            the CPU and converted to ``np.ndarray``.
        """
        encoded = self.feature_extractor(images=image, return_tensors="pt")
        # Every tensor produced by the extractor must live on the model's device.
        model_inputs = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        # Inference only: without no_grad the output tensor requires grad and
        # Tensor.numpy() raises RuntimeError; it also avoids building an
        # autograd graph on every call.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # detach() is a cheap safeguard in case a caller re-enables grad mode.
        return outputs.last_hidden_state.detach().to("cpu").numpy()