import logging
from typing import Dict, List

from haystack import component
from haystack.lazy_imports import LazyImport
from PIL import Image

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    import torch
    from transformers import (
        AutoTokenizer,
        BlipForConditionalGeneration,
        BlipProcessor,
        VisionEncoderDecoderModel,
        ViTImageProcessor,
    )


@component
class ImageCaptioner:
    """Generates captions for images using a Hugging Face image-to-text model (BLIP by default, or ViT-GPT2)."""

    def __init__(
        self,
        model_name: str = "Salesforce/blip-image-captioning-base",
    ):
        """
        :param model_name: Hugging Face model to use for captioning. Supports
            "Salesforce/blip-image-captioning-base" (default) and "nlpconnect/vit-gpt2-image-captioning".
        """
        torch_and_transformers_import.check()
        self.model_name = model_name

        if model_name == "nlpconnect/vit-gpt2-image-captioning":
            # ViT encoder + GPT-2 decoder: needs a separate image processor and tokenizer
            self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
            self.feature_extractor = ViTImageProcessor.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.gen_kwargs = {"max_length": 16, "num_beams": 4}
        else:
            # BLIP bundles image preprocessing and tokenization in a single processor
            self.processor = BlipProcessor.from_pretrained(model_name)
            self.model = BlipForConditionalGeneration.from_pretrained(model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        
    @component.output_types(captions=List[str])
    def run(self, image_file_paths: List[str]) -> Dict[str, List[str]]:
        """Generate one caption per image and return them under the ``captions`` key."""
        # Load the images and convert them to RGB, as the captioning models expect three channels
        images = []
        for image_path in image_file_paths:
            i_image = Image.open(image_path)
            if i_image.mode != "RGB":
                i_image = i_image.convert(mode="RGB")

            images.append(i_image)
        
        if self.model_name == "nlpconnect/vit-gpt2-image-captioning":
            # ViT-GPT2: preprocess with the image processor, generate, then decode with the tokenizer
            pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)
            output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
            preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        else:
            # BLIP: the processor handles both preprocessing and decoding;
            # the inputs must be moved to the same device as the model
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)
            output_ids = self.model.generate(**inputs)
            preds = self.processor.batch_decode(output_ids, skip_special_tokens=True)

        preds = [pred.strip() for pred in preds]
        
        return {"captions": preds}
    
# captioner = ImageCaptioner(model_name="Salesforce/blip-image-captioning-base")
# result = captioner.run(image_file_paths=["selfie.png"])
# print(result)
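
# A minimal sketch of wiring the component into a Haystack 2.x pipeline
# (assumes the standard Pipeline.add_component / Pipeline.run API; adjust names to your setup):
#
# from haystack import Pipeline
#
# pipeline = Pipeline()
# pipeline.add_component("captioner", ImageCaptioner())
# result = pipeline.run({"captioner": {"image_file_paths": ["selfie.png"]}})
# print(result["captioner"]["captions"])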