PolyAI-pheme / utils /__init__.py
taras-sereda's picture
minimal set of files to run inference; pheme-small checkpoint
96ee597
raw
history blame
1.57 kB
"""Copyright PolyAI Limited."""
import logging
import pdb
import sys
import traceback
from functools import wraps
from time import time
from typing import List
import torch
from .symbol_table import SymbolTable
def load_checkpoint(ckpt_path: str) -> dict:
"""
Loads checkpoint, while matching phone embedding size.
"""
state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
new_state_dict = dict()
for p_name in state_dict.keys():
if p_name.startswith("vocoder"):
continue
new_state_dict[p_name] = state_dict[p_name]
return new_state_dict
def breakpoint_on_error(fn):
"""Creates a breakpoint on error
Use as a wrapper
Args:
fn: the function
Returns:
inner function
"""
def inner(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception:
"""Standard python way of creating a breakpoint on error"""
extype, value, tb = sys.exc_info()
print(f"extype={extype},\nvalue={value}")
traceback.print_exc()
pdb.post_mortem(tb)
return inner
def measure_duration(f):
@wraps(f)
def wrap(*args, **kw):
ts = time()
result = f(*args, **kw)
te = time()
logging.debug("func:%r took: %2.4f sec" % (f.__name__, te - ts))
return result
return wrap
def split_metapath(in_paths: List[str]):
other_paths = []
for itm_path in in_paths:
other_paths.append(itm_path)
return other_paths