kbora's picture
Upload 51 files
6af7294
raw
history blame contribute delete
768 Bytes
from typing import Union
import torch
def set_device(device: Union[str, torch.device]) -> torch.device:
    """
    Resolve the device to use for inference. Recommended to use GPU.

    String names are parsed with ``torch.device``. If a CUDA or MPS device is
    requested by name but that backend is not available at runtime, a
    ``RuntimeWarning`` is emitted and the CPU device is returned instead.
    Callers therefore do not fail later when moving models or tensors to an
    unusable device.

    Arguments:
        device Union[str, torch.device]
            The device to use for inference. Can be either a string
            (e.g. 'cpu', 'cuda', 'cuda:1', 'mps') or a torch.device object.
            torch.device objects are returned unchanged.

    Returns:
        torch.device
            The device to use for inference.

    Raises:
        Whatever ``torch.device`` raises for a string it cannot parse
        (propagated unchanged).
    """
    if not isinstance(device, str):
        # Already a torch.device: pass through untouched, as before.
        return device

    import warnings

    resolved = torch.device(device)

    # Check *runtime* availability. For MPS, `is_built()` only says PyTorch
    # was compiled with MPS support, not that the current machine can use it.
    if resolved.type == 'cuda':
        available = torch.cuda.is_available()
    elif resolved.type == 'mps':
        available = torch.backends.mps.is_available()
    else:
        available = True

    if not available:
        warnings.warn(
            f"Requested device '{device}' is not available; falling back to CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.device('cpu')
    return resolved