# coding=utf-8
"""Processor class for EvaByte."""

import base64
import os
from io import BytesIO
from typing import List, Optional, Union

import PIL
import requests
from PIL import Image

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, is_valid_image
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import TensorType, to_py_obj


def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
    """Load an image from a `PIL.Image`, an http(s) URL, a local file path, or a base64 data URI."""
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(BytesIO(requests.get(image, timeout=None).content))
    elif os.path.isfile(image):
        image_obj = Image.open(image)
    elif image.startswith("data:image/"):
        image = image.split(",")[1]
        # Try to load as base64
        try:
            b64 = base64.decodebytes(image.encode())
            image_obj = Image.open(BytesIO(b64))
        except Exception as e:
            raise ValueError(
                f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path "
                f"to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
            )
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(
            f"Unrecognized image input; supported inputs are a local path, an http url, a base64 string, "
            f"or a PIL.Image, got {image}"
        )
    return image_obj


def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")


def is_file(val) -> bool:
    return isinstance(val, str) and os.path.isfile(val)


def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem) or is_file(elem)


vl_chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}
{%- for message in messages %}
    {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
        {{- raise_exception('Conversation roles must be user or assistant') }}
    {%- endif %}
    {%- if message['content'] is string %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
    {%- else %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
        {%- for content in message['content'] %}
            {%- if content['type'] == 'image' %}
                {{- '<image_placeholder>\n' }}
            {%- elif content['type'] == 'text' %}
                {{- content['text'] }}
            {%- endif %}
        {%- endfor %}
        {{- '<|eot_id|>' }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""


class EvaByteProcessor(ProcessorMixin):
    r"""
    Constructs an EvaByte processor which wraps an EvaByte image processor and an EvaByte tokenizer into a single
    processor.

    [`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See
    [`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.

    Args:
        image_processor ([`EvaByteImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`EvaByteTokenizer`], *optional*):
            The tokenizer is a required input.
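
    Example (illustrative sketch; the checkpoint id `"evabyte/EvaByte"` and the image path `"cat.jpg"` below are
    placeholder names used only for demonstration):

    >>> from transformers import AutoImageProcessor, AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
    >>> image_processor = AutoImageProcessor.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
    >>> processor = EvaByteProcessor(image_processor=image_processor, tokenizer=tokenizer)
    >>> # One image placeholder per image; the JPEG byte stream is spliced in between sentinel tokens.
    >>> inputs = processor(
    ...     text="Describe this image: <image_placeholder>",
    ...     images="cat.jpg",
    ...     return_tensors="pt",
    ... )
    >>> list(inputs.keys())
    ['input_ids', 'attention_mask']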
""" attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" def __init__(self, image_processor=None, tokenizer=None, **kwargs): if image_processor is None: raise ValueError("You need to specify an `image_processor`.") if tokenizer is None: raise ValueError("You need to specify a `tokenizer`.") super().__init__(image_processor, tokenizer) self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("") self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("") self.image_placeholder = "" self.vl_chat_template = vl_chat_template def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, images: ImageInput = None, return_tensors: Optional[Union[str, TensorType]] = None, strip_ending_sentinel: bool = False, encode_only: bool = False, **kwargs ) -> Union[BatchFeature, List[List[int]]]: # processing pipeline: # 1. read images or videos from paths # 2. use image_processor to convert images / videos to byte streams if images is not None: if isinstance(images, bytes): image_bytes_list = [[images]] elif isinstance(images, list) and isinstance(images[0], bytes): image_bytes_list = [images] elif isinstance(images, list) and isinstance(images[0], list) and isinstance(images[0][0], bytes): image_bytes_list = images else: if is_image_or_image_url(images): images = [[images]] elif isinstance(images, list) and is_image_or_image_url(images[0]): images = [images] elif ( not isinstance(images, list) and not isinstance(images[0], list) and not is_image_or_image_url(images[0][0]) ): raise ValueError( "Invalid input images. Please provide a single image or a list of images or a list of list of images." ) # Load images if they are URLs images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images] image_bytes_list = self.image_processor(images=images, **kwargs) if not isinstance(text, list): text = [text] assert len(text) == 1, "Only support batch size 1 for now" assert len(text) == len(image_bytes_list), "text and image_bytes_list must have the same length" # TODO: invoke SequenceFeatureExtractor to get batched inputs # 3. tokenize the text and put images / videos byte streams into the placeholders # surrounded by special tokens like "" and "" batch_input_ids = [] if not encode_only: batch_attention_mask = [] else: batch_attention_mask = None for t, image_bytes in zip(text, image_bytes_list): text_splits = t.split(self.image_placeholder) if len(text_splits) != len(image_bytes) + 1: raise ValueError( f"The number of image tokens should be equal to the number of images, " f"but got {len(text_splits)} and {len(image_bytes) + 1}" ) input_ids = [self.tokenizer.bos_token_id] for i, text_part in enumerate(text_splits): # each text part must be non-empty because we added markers around placeholders split_tokens = self.tokenizer.encode(text_part, add_special_tokens=False) input_ids.extend(split_tokens) # Add image bytes after each text part except the last one if i < len(image_bytes): input_ids.append(self.t2v_token_id) input_ids.extend([b + self.tokenizer.offset for b in image_bytes[i]]) input_ids.append(self.v2t_token_id) if strip_ending_sentinel and (input_ids[-1] in [self.t2v_token_id, self.v2t_token_id]): input_ids = input_ids[:-1] batch_input_ids.append(input_ids) if not encode_only: batch_attention_mask.append([1] * len(input_ids)) if not encode_only: # 4. 
            inputs = BatchFeature(
                {
                    "input_ids": batch_input_ids,
                    "attention_mask": batch_attention_mask,
                },
                tensor_type=return_tensors,
            )
            return inputs
            # # Pad sequences
            # padded_inputs = self.tokenizer.pad(
            #     {"input_ids": batch_input_ids},
            #     padding=True,
            #     return_attention_mask=True,
            #     return_tensors=return_tensors,
            # )
            # return BatchFeature(data=padded_inputs)
        else:
            return batch_input_ids

    def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
        # Map byte-level token ids back to raw bytes and re-attach the JPEG quantization tables.
        image_bytes = bytes([token_id - self.tokenizer.offset for token_id in image_token_ids])
        image_bytes = self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)
        return image_bytes

    def batch_decode(self, sequences, **kwargs):
        """
        Decodes a batch of token id sequences by calling [`~EvaByteProcessor.decode`] on each sequence.
        Returns a tuple of (decoded_texts, images), where each element is a list with one entry per sequence.
        """
        rets = [self.decode(seq, **kwargs) for seq in sequences]
        return tuple(map(list, zip(*rets)))

    def decode(self, token_ids, **kwargs):
        """
        Decodes a sequence of input_ids, handling image tokens separately.
        Returns a tuple of (decoded_text, images), where images is a list of `bytes` objects.
        """
        if kwargs and "jpeg_quality" in kwargs:
            kwargs = kwargs.copy()
            jpeg_quality = kwargs.pop("jpeg_quality")
        else:
            jpeg_quality = None
        token_ids = to_py_obj(token_ids)

        # Find indices of t2v_token_id and v2t_token_id
        t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
        v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]

        # Check for correct pairing of t2v and v2t tokens
        if len(t2v_indices) != len(v2t_indices):
            raise ValueError(
                "Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(
                    len(t2v_indices), len(v2t_indices)
                )
            )

        # Ensure each t2v token comes before its paired v2t token
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            if t2v_idx >= v2t_idx:
                raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")

        images = []
        decoded_text = ""
        start = 0
        # Iterate over pairs of t2v and v2t indices
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            # Decode text tokens before the image
            text_token_ids = token_ids[start:t2v_idx]
            if len(text_token_ids) > 0:
                decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)
            # Insert image placeholder
            decoded_text += self.image_placeholder
            # Extract image tokens and convert them to bytes
            image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
            image_bytes = self.image_tokens_to_bytes(image_token_ids, jpeg_quality)
            images.append(image_bytes)
            # Update the start index to the token after v2t_token_id
            start = v2t_idx + 1

        # Decode any remaining text tokens after the last image
        if start < len(token_ids):
            text_token_ids = token_ids[start:]
            decoded_text += self.tokenizer.decode(text_token_ids, **kwargs)

        return decoded_text, images

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
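

# The block below is an illustrative usage sketch, not part of the processor API. It assumes an EvaByte
# checkpoint that ships this processor via `trust_remote_code`; the checkpoint id "evabyte/EvaByte" and
# the image path "cat.jpg" are placeholder names.
if __name__ == "__main__":
    from transformers import AutoImageProcessor, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
    image_processor = AutoImageProcessor.from_pretrained("evabyte/EvaByte", trust_remote_code=True)
    processor = EvaByteProcessor(image_processor=image_processor, tokenizer=tokenizer)

    # Encode a prompt containing one image placeholder into raw byte-level ids, then decode the same
    # ids back into text plus the recovered JPEG byte stream.
    token_ids = processor(
        text="Describe this image: <image_placeholder>",
        images="cat.jpg",
        encode_only=True,
    )[0]
    decoded_text, decoded_images = processor.decode(token_ids)
    print(decoded_text)
    print(f"recovered {len(decoded_images)} image byte stream(s)")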