""" | |
Model Factory for loading a checkpoint from gcloud, then initializing the model and the configuration | |
""" | |
import os

import requests
import torch
import tqdm

from cwm.model import model_pretrain
# Public gcloud bucket that hosts the pretrained checkpoints.
GCLOUD_BUCKET_NAME = "stanford_neuroai_models"
# HTTPS endpoint for direct (unauthenticated) downloads from that bucket.
GCLOUD_URL_NAME = "https://storage.googleapis.com/stanford_neuroai_models"
# Local checkpoint cache; honors the CACHE env var, falling back to ./.cache.
# os.getenv's default argument replaces the redundant is-not-None conditional.
CACHE_PATH = f"{os.getenv('CACHE', '.cache')}/stanford_neuroai_models"
# Registry of downloadable checkpoints: maps each public model name to the
# checkpoint's path inside the bucket and the constructor that builds the
# matching architecture before the weights are loaded.
_model_catalogue = {
    name: {"path": f"cwm/{ckpt_file}", "init_fn": init_fn}
    for name, ckpt_file, init_fn in [
        ("vitb_8x8patch_3frames", "3frame_cwm_8x8.pth", model_pretrain.vitb_8x8patch_3frames),
        ("vitb_4x4patch_2frames", "2frame_cwm_4x4.pth", model_pretrain.vitb_4x4patch_2frames),
        ("vitb_8x8patch_2frames", "2frame_cwm_8x8.pth", model_pretrain.vitb_8x8patch_2frames),
        ("vitb_8x8patch_2frames_encoder_mask_token", "2frame_cwm_mask_token.pth", model_pretrain.vitb_8x8patch_2frames_encoder_mask_token),
    ]
}
class ModelFactory:
    """Downloads model checkpoints from a public gcloud bucket (with a local
    cache) and initializes the corresponding model architecture."""

    def __init__(self, bucket_name: str = GCLOUD_BUCKET_NAME):
        """
        Args:
            bucket_name: str
                Name of the gcloud bucket holding the checkpoints.
        """
        self.bucket_name = bucket_name

    def get_catalog(self):
        """
        Get the list of available model names.
        """
        return _model_catalogue.keys()

    def load_model(self, model_name: str, force_download=False):
        """
        Load the model given the name

        Args:
            model_name: str
                Name of the model to load (one of get_catalog())
            force_download: bool (optional)
                Whether to force the download of the freshest weights from gcloud
        Returns:
            model: torch.nn.Module
                Model initialized from the checkpoint
        Raises:
            ValueError: if model_name is not in the catalogue
            requests.HTTPError: if the checkpoint download fails
        """
        # Fail early with a helpful message instead of a bare KeyError.
        if model_name not in _model_catalogue:
            raise ValueError(
                f"Unknown model '{model_name}'. Available models: {sorted(_model_catalogue)}"
            )
        entry = _model_catalogue[model_name]
        # Find cache dir, use a directory inside of it as the checkpoint path
        checkpoint_path = os.path.join(CACHE_PATH, entry["path"])
        # Make checkpoint directory if it does not exist
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        # Only hit the network when we actually need fresh weights; the
        # original issued the GET unconditionally, which was wasteful and
        # broke cached use while offline.
        if force_download or not os.path.exists(checkpoint_path):
            # URLs always use '/', so build the URL with an f-string rather
            # than os.path.join (which would insert '\\' on Windows).
            gcloud_url = f"{GCLOUD_URL_NAME}/{entry['path']}"
            # Download the model from google cloud using requests (with tqdm timer)
            response = requests.get(gcloud_url, stream=True)
            # Without this, a 404/403 error page would be saved as a "checkpoint".
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            block_size = 1024  # 1 Kibibyte
            print(f"Saving model to cache: {CACHE_PATH}")
            progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
            try:
                with open(checkpoint_path, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
            finally:
                # Ensure the bar is released even if the download aborts.
                progress_bar.close()
        # Initialize the model and the configuration from the checkpoint
        print("checkpoint_path", checkpoint_path)
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        # Initialize the model from the specified initialization function
        model = entry["init_fn"]()
        # Load the model from the checkpoint
        model.load_state_dict(ckpt['model'], strict=True)
        print('Model loaded successfully')
        return model
# Shared module-level factory instance so callers can load models without
# constructing their own ModelFactory.
model_factory = ModelFactory()