import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List

# NOTE(review): `pkg_resources` is deprecated by setuptools and recent releases
# reportedly no longer expose `pkg_resources.packaging` -- confirm against the
# pinned setuptools version before upgrading.
from pkg_resources import packaging
from torch import nn
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from .model_text_encoder import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

# Use torchvision's InterpolationMode enum when it can be imported; otherwise
# fall back to the equivalent PIL constant.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

# Single module-level tokenizer instance shared by every `tokenize` call.
_tokenizer = _Tokenizer()


def _convert_image_to_rgb(image):
    """Return ``image.convert("RGB")`` for the given image object."""
    return image.convert("RGB")


def load():
    """Build and return the model via ``build_model(load_from_clip=False)``."""
    return build_model(load_from_clip=False)


def tokenize(texts: Union[str, List[str]], context_length: int = 77 * 4 - 60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use, i.e. the number of columns in the result.
        Defaults to ``77 * 4 - 60`` (= 248) in this module; note that this
        differs from the 77 used by the original CLIP models.

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the
        context length. When truncating, the last kept position is overwritten
        with the end-of-text token.

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens,
    shape = [number of input strings, context_length]. Positions after the
    end of each encoded sequence are left as zeros.
    We return LongTensor when torch version is <1.8.0, since older index_select
    requires indices to be long.

    Raises
    ------
    RuntimeError
        If an encoded input (including the start/end tokens) is longer than
        ``context_length`` and ``truncate`` is False, or if ``context_length``
        is too small to hold even the end-of-text token.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]

    # Older torch (<1.8.0) needs long indices; newer versions can use int.
    legacy_torch = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    dtype = torch.long if legacy_torch else torch.int
    result = torch.zeros(len(all_tokens), context_length, dtype=dtype)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            # A context length below 1 cannot hold the end-of-text token, so
            # truncation is impossible; report it as "too long" instead of
            # failing with an IndexError on the empty slice.
            if not truncate or context_length < 1:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            # Keep the first context_length - 1 tokens and terminate with EOT.
            tokens = tokens[:context_length - 1] + [eot_token]
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result