|
|
|
""" |
|
Processor class for EvaByte. |
|
""" |
|
import base64 |
|
from io import BytesIO |
|
|
|
import requests |
|
import os |
|
import PIL |
|
from PIL import Image |
|
|
|
from typing import List, Optional, Union |
|
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
from transformers.image_utils import ImageInput, is_valid_image |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
|
from transformers.utils import TensorType, to_py_obj |
|
|
|
def fetch_image(image: Union[str, "PIL.Image.Image"]) -> Image.Image:
    """Resolve an image reference to a PIL image.

    Args:
        image: One of
            * a ``PIL.Image.Image`` (returned unchanged),
            * a URL starting with ``http://`` or ``https://`` (downloaded),
            * a path to an existing local file,
            * a ``data:image/...`` URI whose payload after the first comma
              is base64 encoded,
            * any other string, which is handed to ``Image.open`` as-is.

    Returns:
        The loaded ``PIL.Image.Image``.

    Raises:
        ValueError: If ``image`` is neither a string nor a PIL image, or if
            a ``data:image/`` URI cannot be decoded into an image.
    """
    if isinstance(image, Image.Image):
        return image

    if not isinstance(image, str):
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")

    if image.startswith(("http://", "https://")):
        # NOTE(review): timeout=None lets this call block indefinitely on an
        # unresponsive host; kept as in the original, consider a finite timeout.
        return Image.open(BytesIO(requests.get(image, timeout=None).content))

    if os.path.isfile(image):
        return Image.open(image)

    if image.startswith("data:image/"):
        try:
            # Everything after the first comma is the base64 payload.
            payload = image.split(",", 1)[1]
            return Image.open(BytesIO(base64.decodebytes(payload.encode())))
        except Exception as e:
            raise ValueError(
                f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
            ) from e

    # Last resort: let PIL try to open whatever string was given.
    return Image.open(image)
|
|
|
def is_url(val) -> bool:
    """Return True if ``val`` is a string starting with ``http://`` or ``https://``.

    Only the scheme prefix is inspected; the rest of the string is not
    validated. Non-string values yield False.
    """
    return isinstance(val, str) and val.startswith(("http://", "https://"))
|
|
|
def is_file(val) -> bool:
    """Tell whether ``val`` is a string naming an existing regular file."""
    if not isinstance(val, str):
        return False
    return os.path.isfile(val)
|
|
|
def is_image_or_image_url(elem):
    """Check whether ``elem`` is a URL string, a valid image object, or a local file path."""
    if is_url(elem):
        return True
    return is_valid_image(elem) or is_file(elem)
|
|
|
# Jinja chat template for vision-language conversations.
# Rendering order:
#   1. `bos_token`.
#   2. A system header block; its body is the first message's content when
#      that message has role 'system', otherwise an empty string.
#   3. One header block per remaining message. Roles other than 'user' and
#      'assistant' trigger `raise_exception`. String content is emitted
#      verbatim; list content is walked part by part, where 'image' parts
#      become '<image_placeholder>\n' and 'text' parts emit their 'text'.
#   4. An assistant header when `add_generation_prompt` is set.
# Every tag uses the left-stripping `-` form, so the line breaks and
# indentation between tags are layout only and are not rendered.
vl_chat_template = """
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>'}}

{%- for message in messages %}
    {%- if (message['role'] != 'user') and (message['role'] != 'assistant') %}
        {{- raise_exception('Conversation roles must be user or assistant') }}
    {%- endif %}

    {%- if message['content'] is string %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
    {%- else %}
        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
        {%- for content in message['content'] %}
            {%- if content['type'] == 'image' %}
                {{- '<image_placeholder>\n' }}
            {%- elif content['type'] == 'text' %}
                {{- content['text'] }}
            {%- endif %}
        {%- endfor %}
        {{- '<|eot_id|>' }}
    {%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
    {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""
|
|
|
class EvaByteProcessor(ProcessorMixin):
    r"""
    Constructs an EvaByte processor which wraps an EvaByte image processor and an EvaByte tokenizer into a single
    processor.

    [`EvaByteProcessor`] offers all the functionalities of [`EvaByteImageProcessor`] and [`EvaByteTokenizer`]. See the
    [`~EvaByteProcessor.__call__`] and [`~EvaByteProcessor.decode`] for more information.

    Args:
        image_processor ([`EvaByteImageProcessor`]):
            The image processor is a required input; a `ValueError` is raised when it is `None`.
        tokenizer ([`EvaByteTokenizer`]):
            The tokenizer is a required input; a `ValueError` is raised when it is `None`.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Store both components and look up the image sentinel token ids.

        Extra keyword arguments are accepted for signature compatibility and
        are not used.

        Raises:
            ValueError: If ``image_processor`` or ``tokenizer`` is ``None``.
        """
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        # Sentinel ids that open ("<t2v_token>") and close ("<v2t_token>")
        # an image byte span inside a token sequence.
        self.t2v_token_id = self.tokenizer.convert_tokens_to_ids("<t2v_token>")
        self.v2t_token_id = self.tokenizer.convert_tokens_to_ids("<v2t_token>")
        self.image_placeholder = "<image_placeholder>"
        self.vl_chat_template = vl_chat_template

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        strip_ending_sentinel: bool = False,
        encode_only: bool = False,
        **kwargs
    ) -> Union[BatchFeature, List[List[int]]]:
        """Encode text with interleaved images into token ids.

        Each occurrence of ``<image_placeholder>`` in the text is replaced by
        ``<t2v_token>``, the image bytes shifted by ``tokenizer.offset``, and
        ``<v2t_token>``. Every sequence starts with ``tokenizer.bos_token_id``.

        Args:
            text: A single string or a list of strings. Only a batch of
                exactly one sample is supported.
            images: ``None`` for text-only input; ``bytes``, a list of
                ``bytes`` or a list of lists of ``bytes`` (used as-is); or a
                single image / list / list of lists of images, URLs or file
                paths, which are loaded and passed to the image processor.
            return_tensors: Forwarded to ``BatchFeature`` as ``tensor_type``.
            strip_ending_sentinel: If True, drop a trailing ``<t2v_token>``
                or ``<v2t_token>`` from each sequence.
            encode_only: If True, return the raw list of id lists instead of
                a ``BatchFeature`` with an attention mask.
            **kwargs: Forwarded to the image processor when it is invoked.

        Returns:
            A ``BatchFeature`` with ``input_ids`` and ``attention_mask``, or a
            list of id lists when ``encode_only`` is True.

        Raises:
            ValueError: If ``text`` is missing, the batch size is not 1, the
                images are malformed, or the number of placeholders does not
                match the number of images.
        """
        if text is None:
            raise ValueError("You need to specify `text`.")
        if not isinstance(text, list):
            text = [text]
        # Explicit raises instead of asserts so the checks survive `python -O`.
        if len(text) != 1:
            raise ValueError("Only support batch size 1 for now")

        if images is None:
            # Text-only input: every sample simply has no images.
            image_bytes_list = [[] for _ in text]
        else:
            image_bytes_list = self._images_to_bytes(images, **kwargs)
        if len(text) != len(image_bytes_list):
            raise ValueError("text and image_bytes_list must have the same length")

        batch_input_ids = [
            self._encode_sample(sample_text, image_bytes, strip_ending_sentinel)
            for sample_text, image_bytes in zip(text, image_bytes_list)
        ]
        if encode_only:
            return batch_input_ids

        batch_attention_mask = [[1] * len(ids) for ids in batch_input_ids]
        return BatchFeature(
            {"input_ids": batch_input_ids, "attention_mask": batch_attention_mask},
            tensor_type=return_tensors,
        )

    def _images_to_bytes(self, images, **kwargs):
        """Normalize ``images`` to a per-sample list of image byte sequences."""
        # Already-encoded bytes are only wrapped to the nested batch shape.
        if isinstance(images, bytes):
            return [[images]]
        if isinstance(images, list) and images:
            first = images[0]
            if isinstance(first, bytes):
                return [images]
            if isinstance(first, list) and first and isinstance(first[0], bytes):
                return images

        if is_image_or_image_url(images):
            images = [[images]]
        elif isinstance(images, list) and images and is_image_or_image_url(images[0]):
            images = [images]
        elif not (
            isinstance(images, list)
            and images
            and isinstance(images[0], list)
            and images[0]
            and is_image_or_image_url(images[0][0])
        ):
            raise ValueError(
                "Invalid input images. Please provide a single image or a list of images or a list of list of images."
            )

        images = [[fetch_image(im) if is_url(im) or is_file(im) else im for im in sample] for sample in images]
        # NOTE(review): the image processor is expected to return, per sample,
        # a list of byte sequences (one per image) - confirm against its API.
        return self.image_processor(images=images, **kwargs)

    def _encode_sample(self, sample_text, image_bytes, strip_ending_sentinel):
        """Build the id sequence for one sample, splicing image bytes at each placeholder."""
        text_splits = sample_text.split(self.image_placeholder)
        if len(text_splits) != len(image_bytes) + 1:
            raise ValueError(
                f"The number of image placeholders should be equal to the number of images, "
                f"but got {len(text_splits) - 1} placeholders and {len(image_bytes)} images"
            )

        offset = self.tokenizer.offset
        input_ids = [self.tokenizer.bos_token_id]
        for i, text_part in enumerate(text_splits):
            input_ids.extend(self.tokenizer.encode(text_part, add_special_tokens=False))
            # An image follows every text part except the last one.
            if i < len(image_bytes):
                input_ids.append(self.t2v_token_id)
                input_ids.extend(b + offset for b in image_bytes[i])
                input_ids.append(self.v2t_token_id)

        if strip_ending_sentinel and input_ids[-1] in (self.t2v_token_id, self.v2t_token_id):
            input_ids = input_ids[:-1]
        return input_ids

    def image_tokens_to_bytes(self, image_token_ids, jpeg_quality=None):
        """Convert image token ids back to image bytes.

        Subtracts ``tokenizer.offset`` from every id to recover byte values,
        then passes the bytes and ``jpeg_quality`` through
        ``image_processor.jpeg_merge_qtables`` and returns its result.
        """
        offset = self.tokenizer.offset
        image_bytes = bytes([token_id - offset for token_id in image_token_ids])
        return self.image_processor.jpeg_merge_qtables(image_bytes, jpeg_quality)

    def batch_decode(self, sequences, **kwargs):
        """
        Decodes every sequence with [`~EvaByteProcessor.decode`], forwarding all keyword arguments.

        Returns a tuple `(texts, images)` of two lists aligned with `sequences`: `texts[i]` is the decoded text and
        `images[i]` the list of image bytes of the i-th sequence. Both lists are empty when `sequences` is empty.
        """
        texts = []
        images = []
        for seq in sequences:
            decoded_text, seq_images = self.decode(seq, **kwargs)
            texts.append(decoded_text)
            images.append(seq_images)
        return texts, images

    def decode(self, token_ids, **kwargs):
        """
        Decodes a sequence of input_ids, handling image tokens separately.

        Text between image spans is decoded with the tokenizer (remaining keyword arguments are forwarded to it);
        each `<t2v_token> ... <v2t_token>` span is replaced by `<image_placeholder>` in the text and converted to
        bytes via `image_tokens_to_bytes`. The optional `jpeg_quality` keyword is consumed here and passed to
        `image_tokens_to_bytes`.

        Returns a tuple of (decoded_text, images), where images is a list of bytes.

        Raises a `ValueError` when the sentinel tokens are unbalanced, out of order, or overlapping.
        """
        # `kwargs` is this call's own dict, so popping does not affect the caller.
        jpeg_quality = kwargs.pop("jpeg_quality", None)

        token_ids = to_py_obj(token_ids)

        t2v_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.t2v_token_id]
        v2t_indices = [i for i, token_id in enumerate(token_ids) if token_id == self.v2t_token_id]

        if len(t2v_indices) != len(v2t_indices):
            raise ValueError("Mismatched number of t2v and v2t tokens in token_ids: {} and {}".format(t2v_indices, v2t_indices))

        images = []
        text_parts = []
        start = 0
        for t2v_idx, v2t_idx in zip(t2v_indices, v2t_indices):
            if t2v_idx >= v2t_idx:
                raise ValueError("Found t2v_token_id after v2t_token_id in token_ids")
            if t2v_idx < start:
                # A new span opened before the previous one was closed.
                raise ValueError("Found overlapping t2v/v2t image spans in token_ids")

            text_token_ids = token_ids[start:t2v_idx]
            if len(text_token_ids) > 0:
                text_parts.append(self.tokenizer.decode(text_token_ids, **kwargs))
            text_parts.append(self.image_placeholder)

            image_token_ids = token_ids[t2v_idx + 1 : v2t_idx]
            images.append(self.image_tokens_to_bytes(image_token_ids, jpeg_quality))

            start = v2t_idx + 1

        # Trailing text after the last image span (or the whole sequence when
        # there are no images).
        if start < len(token_ids):
            text_parts.append(self.tokenizer.decode(token_ids[start:], **kwargs))

        return "".join(text_parts), images

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))