# VideoCLIP-XL / utils / text_encoder / text_encoder.py
# Initial commit 1622f6b by jpWang (2.43 kB).
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging
from torch import nn
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from .model_text_encoder import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
# Resolve the bicubic interpolation constant. torchvision builds that ship
# `InterpolationMode` provide it there; otherwise fall back to the PIL constant.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC

# Single tokenizer instance shared by every `tokenize` call in this module.
_tokenizer = _Tokenizer()
def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode.

    ``image`` must expose a ``convert(mode)`` method (presumably a PIL image;
    confirm against callers). The result of ``image.convert("RGB")`` is
    returned unchanged.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
def load():
    """Build and return the text-encoder model.

    Delegates to ``build_model`` with ``load_from_clip=False`` and returns
    whatever that factory produces.
    """
    return build_model(load_from_clip=False)
def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Tokenize one or more strings into a zero-padded tensor of token ids.

    Each input is encoded with the module-level tokenizer and wrapped in the
    start-of-text / end-of-text tokens before being written into its row.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string or a list of strings to tokenize. A bare string is
        treated as a one-element list.

    context_length : int
        Number of token slots per row. Defaults to ``77*4-60`` (= 248).

    truncate : bool
        If True, sequences longer than ``context_length`` are cut to fit and
        their last kept token is replaced by the end-of-text token. If False,
        an over-long sequence raises ``RuntimeError``.

    Returns
    -------
    A two-dimensional tensor of shape [number of input strings, context_length].
    Unused trailing positions stay zero. The dtype is ``torch.long`` when the
    installed torch version is below 1.8.0 and ``torch.int`` otherwise.

    Raises
    ------
    RuntimeError
        If an encoded input exceeds ``context_length`` and ``truncate`` is False.
    """
    batch = [texts] if isinstance(texts, str) else texts

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]

    # Encode every input up front so encoding errors surface before any
    # tensor is allocated.
    encoded = []
    for text in batch:
        encoded.append([start_id] + _tokenizer.encode(text) + [end_id])

    # torch releases older than 1.8.0 get long indices; newer ones get int.
    installed = packaging.version.parse(torch.__version__)
    threshold = packaging.version.parse("1.8.0")
    dtype = torch.long if installed < threshold else torch.int
    result = torch.zeros(len(encoded), context_length, dtype=dtype)

    for row, ids in enumerate(encoded):
        overflow = len(ids) > context_length
        if overflow and not truncate:
            raise RuntimeError(f"Input {batch[row]} is too long for context length {context_length}")
        if overflow:
            # Keep the sequence terminated after cutting it down to size.
            ids = ids[:context_length]
            ids[-1] = end_id
        result[row, :len(ids)] = torch.tensor(ids)

    return result