import logging
from typing import Dict, List

from haystack import component
from haystack.lazy_imports import LazyImport
from PIL import Image

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    import torch
    from transformers import (
        AutoTokenizer,
        BlipForConditionalGeneration,
        BlipProcessor,
        ViTImageProcessor,
        VisionEncoderDecoderModel,
    )


@component
class ImageCaptioner:
    """
    Generates captions for a list of image files, using either the
    nlpconnect/vit-gpt2-image-captioning model or a BLIP model.
    """

    def __init__(
        self,
        model_name: str = "Salesforce/blip-image-captioning-base",
    ):
        torch_and_transformers_import.check()
        self.model_name = model_name
        if model_name == "nlpconnect/vit-gpt2-image-captioning":
            self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
            self.feature_extractor = ViTImageProcessor.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.gen_kwargs = {"max_length": 16, "num_beams": 4}
        else:
            self.processor = BlipProcessor.from_pretrained(model_name)
            self.model = BlipForConditionalGeneration.from_pretrained(model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    @component.output_types(captions=List[str])
    def run(self, image_file_paths: List[str]) -> Dict[str, List[str]]:
        # Load each image and normalize it to RGB, since both models expect 3-channel input.
        images = []
        for image_path in image_file_paths:
            i_image = Image.open(image_path)
            if i_image.mode != "RGB":
                i_image = i_image.convert(mode="RGB")
            images.append(i_image)

        if self.model_name == "nlpconnect/vit-gpt2-image-captioning":
            pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)
            output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
            preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        else:
            # Move the processed inputs to the same device as the model,
            # otherwise generation fails when CUDA is available.
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)
            output_ids = self.model.generate(**inputs)
            preds = self.processor.batch_decode(output_ids, skip_special_tokens=True)

        preds = [pred.strip() for pred in preds]
        return {"captions": preds}
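

# Minimal usage sketch: "selfie.png" is a hypothetical local image path,
# substitute your own. The result is a dict of the shape {"captions": [...]},
# with one caption string per input image.
if __name__ == "__main__":
    captioner = ImageCaptioner(model_name="Salesforce/blip-image-captioning-base")
    result = captioner.run(image_file_paths=["selfie.png"])
    print(result)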