import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List

from pkg_resources import packaging
import torch
from torch import nn
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from .model_text_encoder import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


_tokenizer = _Tokenizer()


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def load():
    """Build and return the text-encoder model."""
    model = build_model(load_from_clip=False)
    return model


def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s).

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; the original CLIP models use 77, while this
        text encoder defaults to 248 (77*4 - 60)

    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]

    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                # Truncate and make sure the last token is still the end-of-text marker.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
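

# Minimal usage sketch (added for illustration; not part of this module's original API).
# tokenize() only needs the bundled tokenizer, so it can be exercised without any model
# weights. Because this file uses relative imports, run it as a module, e.g.
# `python -m <package>.<this_module>`, so those imports resolve.
if __name__ == "__main__":
    toks = tokenize(["a photo of a cat", "a photo of a dog"], truncate=True)
    print(toks.shape)  # torch.Size([2, 248]) with the default context length of 77*4 - 60
    print(toks.dtype)  # torch.int32 on torch >= 1.8.0, torch.int64 on older versions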