# Spaces:
# Sleeping
# Sleeping
import warnings
from typing import Union

import torch
def set_device(device: Union[str, torch.device]) -> torch.device:
    """
    Resolve the device to use for inference. Recommended to use GPU.

    A requested accelerator ("cuda" or "mps", with or without an index such
    as "cuda:1") that is not available at runtime falls back to the CPU and
    emits a RuntimeWarning. Without this, the function would return a device
    that only fails later, when tensors are moved onto it.

    Arguments:
        device Union[str, torch.device]
            The device to use for inference. Can be either a string or a
            torch.device object.

    Returns:
        torch.device
            The device to use for inference. Any other input type is returned
            unchanged, as before.
    """
    if isinstance(device, str):
        device = torch.device(device)
    if not isinstance(device, torch.device):
        # Keep the historical pass-through for unexpected input types.
        return device

    if device.type == 'cuda':
        available = torch.cuda.is_available()
    elif device.type == 'mps':
        # is_built() only reports compile-time support; is_available() also
        # checks that the running machine can actually use MPS.
        available = torch.backends.mps.is_available()
    else:
        # CPU and any other backend: nothing to verify here.
        return device

    if not available:
        warnings.warn(
            f"Requested device '{device}' is not available; falling back to CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.device('cpu')
    return device