# coding=utf-8
"""
Processor class for EvaByte.
"""
import base64
from io import BytesIO

import requests
import os
import PIL
from PIL import Image

from typing import List, Optional, Union

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, is_valid_image
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import TensorType, to_py_obj

def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
    """Load an image from a `PIL.Image.Image`, an http(s) URL, a local file path, or a base64 data URI."""
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
    elif os.path.isfile(image):
        image_obj = Image.open(image)
    elif image.startswith("data:image/"):
        # Try to load as base64
        data = image.split(",")[1]
        try:
            b64 = base64.decodebytes(data.encode())
            image_obj = Image.open(BytesIO(b64))
        except Exception as e:
            raise ValueError(
                f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, "
                f"a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
            )
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input; supported formats are local path, http url, base64 and PIL.Image, got {image}")

    return image_obj

def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")

def is_file(val) -> bool:
    return isinstance(val, str) and os.path.isfile(val)

def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem) or is_file(elem)

vl_chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}

{%- for message in messages %}
    {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
        {{- raise_exception('Conversation roles must be user or assistant') }}
    {%- endif %}
    
    {%- if message['content'] is string %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
    {%- else %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
        {%- for content in message['content'] %}
            {%- if content['type'] == 'image' %}
                {{- '<image_placeholder>\n' }}
            {%- elif content['type'] == 'text' %}
                {{- content['text'] }}
            {%- endif %}
        {%- endfor %}
        {{- '<|eot_id|>' }}        
    {%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
    {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""

class EvaByteProcessor(ProcessorMixin):
    r"""
    Constructs an EvaByte processor which wraps an EvaByte image processor and an EvaByte tokenizer into a single processor.

    [`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
    [`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.

    Args:
        image_processor ([`EvaByteImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`EvaByteTokenizer`], *optional*):
            The tokenizer is a required input.
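
    Example (a minimal usage sketch; the checkpoint id and image path below are placeholders, and
    loading with `trust_remote_code=True` is assumed):

    ```python
    >>> from transformers import AutoProcessor

    >>> processor = AutoProcessor.from_pretrained("EvaByte/EvaByte-SFT", trust_remote_code=True)
    >>> text = (
    ...     "<|start_header_id|>user<|end_header_id|>\n\n"
    ...     "<image_placeholder>\nDescribe this image.<|eot_id|>"
    ... )
    >>> inputs = processor(text=text, images=["path/to/image.jpg"], return_tensors="pt")
    >>> sorted(inputs.keys())
    ['attention_mask', 'input_ids']
    ```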
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
        self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
        self.image_placeholder = "<image_placeholder>"
        self.vl_chat_template = vl_chat_template

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        strip_ending_sentinel: bool = False,
        encode_only: bool = False,
        **kwargs
    ) -> Union[BatchFeature, List[List[int]]]:
        # processing pipeline:
        # 1. read images or videos from paths
        # 2. use image_processor to convert images / videos to byte streams
        if images is not None:
            if isinstance(images, bytes):
                image_bytes_list = [[images]]
            elif isinstance(images, list) and isinstance(images[0], bytes):
                image_bytes_list = [images]
            elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes):
                image_bytes_list = images
            else:
                if is_image_or_image_url(images):
                    images = [[images]]
                elif isinstance(images, list) and is_image_or_image_url(images[0]):
                    images = [images]
                elif not (
                    isinstance(images, list)
                    and isinstance(images[0], list)
                    and is_image_or_image_url(images[0][0])
                ):
                    raise ValueError(
                        "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
                    )
                # Load images if they are URLs or local file paths
                images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
                image_bytes_list = self.image_processor(images=images, **kwargs)
        else:
            image_bytes_list = None

        if not isinstance(text, list):
            text = [text]
        assert len(text) == 1, "Only batch size 1 is supported for now"
        if image_bytes_list is None:
            # No images were provided: pair each text sample with an empty image list.
            image_bytes_list = [[] for _ in text]
        assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length"
        # TODO: invoke SequenceFeatureExtractor to get batched inputs

        # 3. tokenize the text and put images / videos byte streams into the placeholders
        #    surrounded by special tokens like "<image>" and "</image>"
        batch_input_ids = []
        if not encode_only:
            batch_attention_mask = []
        else:
            batch_attention_mask = None

        for t, image_bytes in zip(text, image_bytes_list):
            text_splits = t.split(self.image_placeholder)
            if len(text_splits) != len(image_bytes) + 1:
                raise ValueError(
                    f"The number of `{self.image_placeholder}` tokens should equal the number of images, "
                    f"but got {len(text_splits) - 1} placeholders and {len(image_bytes)} images"
                )

            input_ids = [self.tokenizer.bos_token_id]
            for i, text_part in enumerate(text_splits):
                # each text part must be non-empty because we added markers around placeholders
                split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False)
                input_ids.extend(split_tokens)
                # Add image bytes after each text part except the last one
                if i < len(image_bytes):
                    input_ids.append(self.t2v_token_id)
                    input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]])
                    input_ids.append(self.v2t_token_id)

            if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]):
                input_ids = input_ids[:-1]

            batch_input_ids.append(input_ids)
            if not encode_only:
                batch_attention_mask.append([1] * len(input_ids))

        if not encode_only:
            # 4. return batch of features
            inputs = BatchFeature({
                "input_ids": batch_input_ids,
                "attention_mask": batch_attention_mask
            }, tensor_type=return_tensors)
            return inputs
            # # Pad sequences
            # padded_inputs = self.tokenizer.pad(
            #     {"input_ids": batch_input_ids},
            #     padding=True,
            #     return_attention_mask=True,
            #     return_tensors=return_tensors,
            # )
            # return BatchFeature(data=padded_inputs)
        else:
            return batch_input_ids

    def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
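        # Undo the tokenizer's byte offset to recover the raw JPEG byte stream, then let the
        # image processor merge the quantization tables back in (optionally at `jpeg_quality`).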
        image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
        image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
        return image_bytes

    def batch_decode(self, sequences, **kwargs):
        """
        Decodes a batch of token sequences by calling [`~EvaByteProcessor.decode`] on each sequence.
        Returns a tuple of (decoded_texts, images_per_sequence), where each element is a list with one
        entry per input sequence.
        """
        rets = [self.decode(seq, **kwargs) for seq in sequences]
        return tuple(map(list, zip(*rets)))

    def decode(self, token_ids, **kwargs):
        """
        Decodes a sequence of input_ids, handling image tokens separately.
        Returns a tuple of (decoded_text, images), where images is a list of bytes.
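
        Illustrative sketch: decoding `[*text_tokens, t2v_token, *image_byte_tokens, v2t_token]`
        yields `(decoded_text + "<image_placeholder>", [jpeg_bytes])`, with the JPEG bytes
        reconstructed via [`~EvaByteProcessor.image_tokens_to_bytes`].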
        """
        if kwargs and "jpeg_quality" in kwargs:
            kwargs = kwargs.copy()
            jpeg_quality = kwargs.pop("jpeg_quality")
        else:
            jpeg_quality = None
        
        token_ids = to_py_obj(token_ids)
        # Find indices of t2v_token_id and v2t_token_id
        t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
        v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]
        
        # Check for correct pairing of t2v and v2t tokens
        if len(t2v_indices) != len(v2t_indices):
            raise ValueError(
                "Mismatched number of t2v and v2t tokens in token_ids: "
                f"{len(t2v_indices)} t2v vs. {len(v2t_indices)} v2t"
            )

        # Ensure t2v and v2t tokens are in the correct order
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            if t2v_idx >= v2t_idx:
                raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")

        # Initialize the start index
        images = []
        decoded_text = ""

        start = 0
        # Iterate over pairs of t2v and v2t indices
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            # Decode text tokens before the image
            text_token_ids = token_ids[start:t2v_idx]
            if len(text_token_ids) > 0:
                decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)

            # Insert image placeholder
            decoded_text += self.image_placeholder

            # Extract image tokens and convert them to bytes
            image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
            image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
            images.append(image_bytes)

            # Update the start index to the token after v2t_token_id
            start = v2t_idx + 1

        # Decode any remaining text tokens after the last image
        if start < len(token_ids):
            text_token_ids = token_ids[start:]
            decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)

        return decoded_text, images

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))