PFEemp2024's picture
solving GPU error for previous version
4a1df2e
raw
history blame contribute delete
612 Bytes
"""
Utility functions for model wrappers
---------------------------------------------------------------------
"""
import glob
import os
import torch
def load_cached_state_dict(model_folder_path):
    """Load a state dict from a ``*model.bin`` file inside a model folder.

    The folder is searched non-recursively for files whose names end in
    ``model.bin``. If several files match, the one whose path sorts first
    lexicographically is loaded, so the choice is stable across runs.

    Args:
        model_folder_path: Path of the directory to search. It is treated
            as a literal path; glob metacharacters in it are escaped.

    Returns:
        The object returned by ``torch.load`` for the selected file, with
        storages mapped to the CPU.

    Raises:
        FileNotFoundError: If no file matching ``*model.bin`` is found in
            ``model_folder_path``.
    """
    # Escape the folder part so characters such as "[" or "*" in the
    # directory name are matched literally instead of being interpreted
    # as glob syntax (which would make an existing folder look empty).
    pattern = os.path.join(
        glob.escape(os.fspath(model_folder_path)), "*model.bin"
    )
    # glob.glob() returns matches in arbitrary order; sort so that "the
    # first match" is the same file on every run and every platform.
    model_path_list = sorted(glob.glob(pattern))
    if not model_path_list:
        raise FileNotFoundError(
            f"model.bin not found in model folder {model_folder_path}."
        )
    model_path = model_path_list[0]
    # NOTE(security): torch.load unpickles the file contents; only call
    # this on checkpoints from a trusted source.
    # map_location forces all storages onto the CPU so the checkpoint
    # loads regardless of the device it was saved from.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    return state_dict