|
from typing import Union |
|
import torch |
|
|
|
def set_device(device: Union[str, torch.device]) -> torch.device:
    """
    Resolve the device to use for inference. Using a GPU is recommended.

    If an accelerator (CUDA or MPS) is requested but is not usable on this
    machine, a warning is emitted and the CPU device is returned instead, so
    callers never receive a device they cannot allocate tensors on.

    Arguments:
        device (Union[str, torch.device]):
            The requested device, either as a string (e.g. ``"cpu"``,
            ``"cuda"``, ``"cuda:1"``, ``"mps"``) or as a ``torch.device``.

    Returns:
        torch.device:
            The device to use for inference. Values that are neither a string
            nor a ``torch.device`` are returned unchanged.

    Raises:
        Whatever ``torch.device`` raises when given an unparseable string.
    """
    import warnings

    if isinstance(device, str):
        device = torch.device(device)

    if not isinstance(device, torch.device):
        # Preserve the original pass-through for any other input type.
        return device

    if device.type == "cuda" and not torch.cuda.is_available():
        warnings.warn(
            f"Requested device '{device}' but CUDA is not available; "
            "falling back to CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.device("cpu")

    if device.type == "mps":
        # torch.backends.mps is missing on old PyTorch releases. is_built() only
        # reports compile-time support; is_available() says MPS is usable now.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is None or not mps_backend.is_available():
            warnings.warn(
                f"Requested device '{device}' but MPS is not available; "
                "falling back to CPU.",
                RuntimeWarning,
                stacklevel=2,
            )
            return torch.device("cpu")

    return device