train model
Browse files
- scripts/train_model.py +1 -1
scripts/train_model.py
CHANGED
@@ -252,7 +252,7 @@ print(data_collator)
|
|
252 |
|
253 |
def collate_fn(examples):
|
254 |
texts = [ex['text'] for ex in examples]
|
255 |
-
batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=32 * 1024)
|
256 |
batch['labels'] = batch['input_ids'].clone()
|
257 |
return batch
|
258 |
|
|
|
252 |
|
253 |
def collate_fn(examples):
|
254 |
texts = [ex['text'] for ex in examples]
|
255 |
+
batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=32 * 1024, return_token_type_ids=False)
|
256 |
batch['labels'] = batch['input_ids'].clone()
|
257 |
return batch
|
258 |
|