|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
|
|
def get_dataloader(inputs: np.ndarray, targets: np.ndarray, device: Optional[torch.device] = None,
                   batch_size: int = 64, *, shuffle: bool = False,
                   target_dtype: torch.dtype = torch.float32) -> DataLoader:
    """Wrap paired input/target arrays in a batched ``DataLoader``.

    Both arrays are copied into new tensors, so later mutation of ``inputs`` or
    ``targets`` does not affect the loader. Inputs are always stored as
    ``torch.float32``; targets use ``target_dtype``.

    Args:
        inputs: Array whose first axis indexes samples.
        targets: Array whose first axis indexes samples; must have the same
            first-axis length as ``inputs``.
        device: If given, both tensors are allocated on this device up front,
            so every batch the loader yields already lives there. ``None``
            uses torch's default device.
        batch_size: Number of samples per batch (the last batch may be smaller).
        shuffle: Reshuffle sample order every epoch. Defaults to ``False``
            (sequential order).
        target_dtype: dtype of the target tensor. Defaults to
            ``torch.float32``; pass ``torch.long`` for integer class labels.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` yielding
        ``(input_batch, target_batch)`` pairs.

    Raises:
        ValueError: If either array is 0-dimensional or their first-axis
            lengths differ.
    """
    input_shape = np.shape(inputs)
    target_shape = np.shape(targets)
    # Validate explicitly: TensorDataset only checks lengths with an `assert`,
    # which is stripped under `python -O`, letting mismatched data through.
    if not input_shape or not target_shape:
        raise ValueError("inputs and targets must be at least 1-dimensional")
    if input_shape[0] != target_shape[0]:
        raise ValueError(
            f"inputs and targets must have the same length along the first axis, "
            f"got {input_shape[0]} and {target_shape[0]}"
        )

    # Passing dtype/device directly converts in a single copy instead of
    # copying once in `torch.tensor` and again in `.float()` / `.to(device)`.
    input_tensor = torch.tensor(inputs, dtype=torch.float32, device=device)
    target_tensor = torch.tensor(targets, dtype=target_dtype, device=device)

    dataset = TensorDataset(input_tensor, target_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
|