File size: 644 Bytes
25e7dcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from typing import Optional
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
def get_dataloader(inputs: np.ndarray, targets: np.ndarray, device: Optional[torch.device] = None,
                   batch_size: int = 64, *, shuffle: bool = False) -> DataLoader:
    """Wrap paired input/target arrays in a float32 ``DataLoader``.

    Both arrays are copied into new ``torch.float32`` tensors, so mutating the
    source arrays afterwards does not affect the loader. When *device* is
    given, the full tensors are placed on that device up front, so every batch
    the loader yields already lives there.

    Args:
        inputs: Input samples; the first dimension indexes samples.
        targets: Targets aligned with *inputs* along the first dimension.
        device: Device to place the tensors on; ``None`` uses torch's
            default device.
        batch_size: Number of samples per batch.
        shuffle: Reshuffle sample order every epoch (keyword-only; the
            default ``False`` keeps sequential order).

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of ``(input, target)`` pairs.

    Raises:
        ValueError: If *inputs* and *targets* differ in their first dimension.
    """
    # Convert dtype and device in one step: ``torch.tensor(x).float()`` first
    # materialises a float64 copy of float64 arrays and then a second float32
    # copy (plus a third on ``.to(device)``), inflating peak memory.
    input_tensor = torch.tensor(inputs, dtype=torch.float32, device=device)
    target_tensor = torch.tensor(targets, dtype=torch.float32, device=device)
    # TensorDataset only checks this with ``assert``, which ``python -O`` strips;
    # a mismatch would then silently drop trailing targets or fail mid-epoch.
    if input_tensor.size(0) != target_tensor.size(0):
        raise ValueError(
            f"inputs and targets must have the same number of samples, "
            f"got {input_tensor.size(0)} and {target_tensor.size(0)}"
        )
    dataset = TensorDataset(input_tensor, target_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|