bert-reg-biencoder-mae / data_collator.py
minoosh's picture
Upload folder using huggingface_hub
c65ed72 verified
raw
history blame contribute delete
574 Bytes
import torch
class BiEncoderCollator:
    """Collate bi-encoder examples into a single batch of tensors.

    Each feature is a mapping with the keys ``input_ids_text1``,
    ``attention_mask_text1``, ``input_ids_text2``, ``attention_mask_text2``
    and ``labels``. The four token fields are stacked along a new leading
    batch dimension. The labels are gathered into a float tensor.
    """

    # Per-example token fields stacked into batch tensors. They appear in the
    # returned dict in this order, followed by 'labels'.
    _TOKEN_KEYS = (
        'input_ids_text1',
        'attention_mask_text1',
        'input_ids_text2',
        'attention_mask_text2',
    )

    def __call__(self, features):
        """Build a batch from a sequence of per-example feature dicts.

        Args:
            features: Sequence of mappings, one per example. Token fields may
                be tensors or (nested) Python lists of numbers. Within a field,
                every example must have the same shape, because
                ``torch.stack`` does not pad. Padding presumably happens
                upstream; confirm this against the dataset/tokenizer setup.

        Returns:
            A dict mapping each token key to a tensor of shape
            ``(len(features), *field_shape)``. The ``'labels'`` key maps to a
            ``torch.float`` tensor built from the per-example labels, with
            shape ``(len(features),)`` when each label is a scalar.

        Raises:
            RuntimeError: From ``torch.stack`` if ``features`` is empty or a
                token field's shape differs across examples.
        """
        # torch.as_tensor returns existing tensors unchanged (no copy) and
        # converts list inputs, so both tensor and list features are accepted.
        batch = {
            key: torch.stack([torch.as_tensor(f[key]) for f in features])
            for key in self._TOKEN_KEYS
        }
        # Labels are collected into a new tensor and forced to float dtype.
        batch['labels'] = torch.tensor(
            [f['labels'] for f in features], dtype=torch.float
        )
        return batch