File size: 644 Bytes
25e7dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def get_dataloader(inputs: np.ndarray, targets: np.ndarray, device: Optional[torch.device] = None,
                   batch_size: int = 64, *, shuffle: bool = False) -> DataLoader:
    """Wrap paired NumPy arrays in a ``DataLoader`` of float32 tensors.

    Each array is copied into a new ``torch.float32`` tensor, so later
    mutation of the NumPy arrays does not affect the dataset. When
    ``device`` is given, the tensor is allocated directly on that device.

    Args:
        inputs: Input samples; ``TensorDataset`` indexes along the first axis.
        targets: Targets aligned with ``inputs`` along the first axis.
        device: Device to place both tensors on. ``None`` leaves them on
            torch's default device.
        batch_size: Number of samples per batch.
        shuffle: Reshuffle sample order every epoch. Keyword-only; defaults
            to ``False`` (sequential order).

    Returns:
        A ``DataLoader`` yielding ``(input_batch, target_batch)`` pairs.
    """
    # Converting dtype and device in the constructor performs a single
    # allocation. The older tensor(...).float().to(device) chain copied up to
    # three times. torch.tensor always copies, so the dataset never aliases
    # the caller's arrays.
    input_tensor = torch.tensor(inputs, dtype=torch.float32, device=device)
    target_tensor = torch.tensor(targets, dtype=torch.float32, device=device)

    dataset = TensorDataset(input_tensor, target_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)