suchkow's picture
Upload 48 files
25e7dcb verified
raw
history blame
No virus
644 Bytes
from typing import Optional
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
def get_dataloader(inputs: np.ndarray, targets: np.ndarray, device: Optional[torch.device] = None,
                   batch_size: int = 64, *, shuffle: bool = False) -> DataLoader:
    """Wrap paired input/target arrays in a float32 ``DataLoader``.

    Both arrays are copied into new ``torch.float32`` tensors (placed on
    *device* when one is given), zipped into a ``TensorDataset`` and batched.

    Args:
        inputs: Array-like of samples; the first axis indexes samples.
        targets: Array-like of targets; its first axis must have the same
            length as that of ``inputs``.
        device: Optional device for both tensors. ``None`` leaves them on
            torch's current default device.
        batch_size: Number of samples per batch.
        shuffle: If ``True``, reshuffle sample order every epoch. Defaults to
            ``False`` (sequential order), matching the original behaviour.

    Returns:
        A ``DataLoader`` yielding ``(input_batch, target_batch)`` pairs of
        float32 tensors.

    Raises:
        ValueError: If ``inputs`` and ``targets`` differ in length along the
            first axis.
    """
    # One torch.tensor call copies, casts and places the data in a single
    # step, instead of copy -> .float() copy -> .to(device) copy. torch.tensor
    # (unlike torch.as_tensor) always copies, so later in-place edits to the
    # numpy arrays cannot leak into the dataset.
    input_tensor = torch.tensor(inputs, dtype=torch.float32, device=device)
    target_tensor = torch.tensor(targets, dtype=torch.float32, device=device)

    # TensorDataset only verifies this with an `assert`, which is stripped when
    # Python runs with -O; check explicitly so a mismatch always fails loudly.
    if input_tensor.size(0) != target_tensor.size(0):
        raise ValueError(
            f"inputs and targets must have the same length along the first axis, "
            f"got {input_tensor.size(0)} and {target_tensor.size(0)}"
        )

    dataset = TensorDataset(input_tensor, target_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)