|
import torch |
|
import platform |
|
|
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    Feed measurements by calling the instance::

        meter = AverageMeter()
        meter(value, n=weight)
        meter.avg  # weighted mean of all values since the last reset

    Attributes (all reset to 0 by ``reset()``):
        val: The most recently recorded value.
        sum: Running sum of ``val * n`` over all updates.
        count: Running total of the weights ``n``.
        avg: ``sum / count``; stays 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def __call__(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new measurement. It must support ``val * n`` and
                addition to the running sum (presumably a number or a
                scalar tensor — confirm against callers).
            n: Weight of this measurement, e.g. the number of samples that
                ``val`` was averaged over. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been
        # recorded (e.g. ``meter(x, n=0)`` right after a reset).
        self.avg = self.sum / self.count if self.count else 0
|
|
|
|
|
|
|
def getPlatform():
    """Return the OS name reported by ``platform.system()``.

    The only translation applied is ``'Darwin'`` -> ``'mac'``; every other
    name (e.g. ``'Linux'``, ``'Windows'``) is returned unchanged.
    """
    system_name = platform.system()
    return {'Darwin': 'mac'}.get(system_name, system_name)
|
|
|
|
|
|
|
def hasGPU(plt: str) -> bool:
    """Report whether a GPU backend is available for the given platform.

    Args:
        plt: Platform name. ``'mac'`` checks Apple's MPS backend; any other
            value checks CUDA.

    Returns:
        True if the selected backend reports itself as available.
    """
    if plt == 'mac':
        # ``torch.backends.mps`` only exists in PyTorch >= 1.12; on older
        # builds report "no GPU" instead of raising AttributeError.
        mps_backend = getattr(torch.backends, 'mps', None)
        return bool(mps_backend is not None and mps_backend.is_available())
    return torch.cuda.is_available()
|
|
|
|
|
|
|
def getDevice(plt: str):
    """Return the accelerator ``torch.device`` for a platform name.

    ``'mac'`` maps to the ``mps`` device; every other platform maps to
    ``cuda``. Availability is not checked here, so callers should check it
    first (for example with ``hasGPU``).
    """
    backend = 'mps' if plt == 'mac' else 'cuda'
    return torch.device(backend)
|
|
|
|
|
|
|
def disableWarnings():
    """Install ``ignore`` filters for a few known-noisy library warnings.

    The filters are added to the global ``warnings`` filter list in this
    order:
      * UserWarnings from the ``transformers.utils.generic`` module,
      * UserWarnings from the ``trl.trainer.ppo_config`` module,
      * the torch checkpoint ``use_reentrant`` advisory (matched by message).
    """
    import warnings

    noisy_sources = (
        {"category": UserWarning, "module": "transformers.utils.generic"},
        {"category": UserWarning, "module": "trl.trainer.ppo_config"},
        {"message": "torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly"},
    )
    for filter_kwargs in noisy_sources:
        warnings.filterwarnings("ignore", **filter_kwargs)
|
|
|
|