File size: 574 Bytes
2cc42e2
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14

import torch

class BiEncoderCollator:
    """Collate per-example text-pair features into a single batch dict.

    Each feature is expected to be a mapping holding tensors under
    ``input_ids_text1``, ``attention_mask_text1``, ``input_ids_text2`` and
    ``attention_mask_text2``, plus an optional scalar ``labels`` entry.
    """

    # Keys whose per-example tensors are stacked along a new leading batch
    # dimension. Order matters: it fixes the key order of the returned dict.
    _TENSOR_KEYS = (
        'input_ids_text1',
        'attention_mask_text1',
        'input_ids_text2',
        'attention_mask_text2',
    )

    def __call__(self, features):
        """Build a batch from a sequence of features.

        Args:
            features: Sequence of per-example mappings (see class docstring).
                It is iterated once per key, so it must be re-iterable
                (e.g. a list).

        Returns:
            dict mapping each key in ``_TENSOR_KEYS`` to
            ``torch.stack([f[key] for f in features])``. If the first feature
            has a ``labels`` entry, the dict also contains ``'labels'``: a
            ``torch.float`` tensor built from every feature's label.

        Raises:
            RuntimeError: from ``torch.stack`` if ``features`` is empty or the
                tensors under one key differ in shape.
            KeyError: if a required tensor key is missing, or if the first
                feature has ``labels`` but a later one does not.
        """
        # Stacking happens first, so an empty batch fails inside torch.stack
        # (as before) rather than on the ``features[0]`` lookup below.
        batch = {
            key: torch.stack([f[key] for f in features])
            for key in self._TENSOR_KEYS
        }
        # Labels are optional so the same collator also serves unlabeled
        # (e.g. inference/encoding) data instead of raising KeyError.
        if 'labels' in features[0]:
            batch['labels'] = torch.tensor(
                [f['labels'] for f in features], dtype=torch.float
            )
        return batch