# coding=utf-8
# Copyright 2022 rinna Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Union

import torch
from huggingface_hub import cached_download, hf_hub_url
from torchvision import transforms as T

from .clip import CLIPModel
from .cloob import CLOOBModel

# Mapping from model name to its Hugging Face Hub repo and model class.
MODELS = {
    'rinna/japanese-clip-vit-b-16': {
        'repo_id': 'rinna/japanese-clip-vit-b-16',
        'model_class': CLIPModel,
    },
    'rinna/japanese-cloob-vit-b-16': {
        'repo_id': 'rinna/japanese-cloob-vit-b-16',
        'model_class': CLOOBModel,
    }
}
MODEL_CLASSES = {
    "cloob": CLOOBModel,
    "clip": CLIPModel,
}
MODEL_FILE = "pytorch_model.bin"
CONFIG_FILE = "config.json"


def available_models():
    """Return the names of all supported pretrained models."""
    return list(MODELS.keys())


def _convert_to_rgb(image):
    return image.convert('RGB')


def _transform(image_size):
    # Standard CLIP preprocessing: resize, center-crop, and normalize with the
    # mean/std used by the original OpenAI CLIP models.
    return T.Compose([
        T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
        T.CenterCrop(image_size),
        _convert_to_rgb,
        T.ToTensor(),
        T.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ])


def _download(repo_id: str, cache_dir: str):
    """Download the config and model weights for repo_id into cache_dir."""
    config_file_url = hf_hub_url(repo_id=repo_id, filename=CONFIG_FILE)
    cached_download(config_file_url, cache_dir=cache_dir)
    model_file_url = hf_hub_url(repo_id=repo_id, filename=MODEL_FILE)
    cached_download(model_file_url, cache_dir=cache_dir)


def load(
        model_name: str,
        device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
        **kwargs
):
    """
    Args:
        model_name: model unique name or path to pre-downloaded model
        device: device to put the loaded model
        kwargs: kwargs for huggingface pretrained model class
    Return:
        (torch.nn.Module, A torchvision transform)
    """
    if model_name in MODELS:
        ModelClass = MODELS[model_name]['model_class']
    elif os.path.exists(model_name):
        config_path = os.path.join(model_name, CONFIG_FILE)
        assert os.path.exists(config_path), f"{CONFIG_FILE} not found in {model_name}"
        with open(config_path, "r", encoding="utf-8") as f:
            j = json.load(f)
        ModelClass = MODEL_CLASSES[j["model_type"]]
    else:
        raise RuntimeError(f"Model {model_name} not found; available models = {available_models()}")

    model = ModelClass.from_pretrained(model_name, **kwargs)
    model = model.eval().requires_grad_(False).to(device)
    return model, _transform(model.config.vision_config.image_size)
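

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library: list the supported models,
    # load the CLIP variant on CPU, preprocess a local image, and compute its
    # embedding. The path "sample.jpg" is a placeholder, and get_image_features
    # is assumed to be available on the returned model, mirroring the
    # Hugging Face CLIP interface.
    from PIL import Image

    print(available_models())
    model, preprocess = load("rinna/japanese-clip-vit-b-16", device="cpu")
    image = preprocess(Image.open("sample.jpg")).unsqueeze(0)  # add a batch dim
    with torch.no_grad():
        image_features = model.get_image_features(image)
    print(image_features.shape)  # (1, embedding_dim)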