import warnings
from typing import Union

import torch


def _mps_available() -> bool:
    """Return True if this PyTorch build exposes an MPS backend that is usable at runtime."""
    # Older PyTorch releases have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and mps_backend.is_available()


def set_device(device: Union[str, torch.device]) -> torch.device:
    """Resolve the device to use for inference. Using a GPU is recommended.

    String specifiers are parsed with ``torch.device``, so both bare types
    ('cpu', 'cuda', 'mps') and indexed forms ('cuda:1') are accepted. If an
    accelerator (CUDA or MPS) is requested but not usable at runtime, a
    warning is issued and the CPU device is returned instead, rather than a
    device handle that would fail on first use.

    Arguments:
        device: Union[str, torch.device]
            The device to use for inference, either as a string or as a
            ``torch.device`` object. Non-string values are returned unchanged.

    Returns:
        torch.device
            The resolved device.

    Raises:
        RuntimeError: if ``device`` is a string that ``torch.device`` cannot parse.
    """
    if not isinstance(device, str):
        # Already resolved by the caller; pass it through untouched.
        return device

    resolved = torch.device(device)

    if resolved.type == "cuda" and not torch.cuda.is_available():
        warnings.warn(
            f"CUDA device {device!r} requested but CUDA is not available; falling back to CPU.",
            stacklevel=2,
        )
        return torch.device("cpu")

    # is_available() (not is_built()) checks that MPS can actually be used on this machine.
    if resolved.type == "mps" and not _mps_available():
        warnings.warn(
            f"MPS device {device!r} requested but MPS is not available; falling back to CPU.",
            stacklevel=2,
        )
        return torch.device("cpu")

    return resolved