from typing import Optional, Union

import torch
from transformers import AutoTokenizer, T5Tokenizer


class CustomCLIPTokenizer(T5Tokenizer):
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.do_lower_case = True

    def __call__(
        self,
        texts: Union[str, list[str]],
        tokenizer: Optional[T5Tokenizer] = None,
        max_seq_len: int = 77,
        device: Union[str, torch.device] = (
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
        **kwargs,
    ):
        if isinstance(texts, str):
            texts = [texts]
        if tokenizer is None:
            # Tokenize with this instance; call the parent implementation
            # directly so this override is not re-entered.
            tokenizer = self
            tokenizer_call = super().__call__
        else:
            tokenizer_call = tokenizer
        # Reserve one position for the CLS token prepended below.
        inputs = tokenizer_call(
            texts,
            max_length=max_seq_len - 1,
            padding="max_length",
            truncation=True,
            add_special_tokens=False,
        )

        # Prepend the CLS token id (this assumes the loaded tokenizer defines a
        # cls_token; T5Tokenizer does not set one by default) and extend the
        # attention mask to cover it.
        input_ids = [[tokenizer.cls_token_id] + ids for ids in inputs["input_ids"]]
        attention_mask = [[1] + am for am in inputs["attention_mask"]]
        # Every sequence is padded to the same length, so one position-id list
        # is shared across the batch.
        position_ids = [list(range(len(input_ids[0])))] * len(texts)

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        position_ids = torch.tensor(position_ids, dtype=torch.long)
        return {
            "input_ids": input_ids.to(device),
            "attention_mask": attention_mask.to(device),
            "position_ids": position_ids.to(device),
        }


# Make the custom class discoverable through AutoTokenizer (note that
# AutoTokenizer.register normally expects a config class as its first argument).
AutoTokenizer.register("CustomCLIPTokenizer", CustomCLIPTokenizer)
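
# Minimal usage sketch, assuming a local tokenizer directory whose files define
# a cls_token and a matching vocabulary; "path/to/custom_clip_tokenizer" is a
# hypothetical path, not part of this repository.
if __name__ == "__main__":
    tok = CustomCLIPTokenizer.from_pretrained("path/to/custom_clip_tokenizer")
    batch = tok(["a photo of a cat", "a photo of a dog"], device="cpu")
    # Each tensor has shape (2, 77): 76 tokens after padding/truncation plus the CLS token.
    print(
        batch["input_ids"].shape,
        batch["attention_mask"].shape,
        batch["position_ids"].shape,
    )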