train model
Browse files- scripts/train_model.py +4 -4
scripts/train_model.py
CHANGED
@@ -8,7 +8,7 @@ from transformers import DataCollatorForLanguageModeling
|
|
8 |
|
9 |
import torch
|
10 |
from torch.utils.data import DataLoader
|
11 |
-
import torch.multiprocessing as mp
|
12 |
|
13 |
|
14 |
# x = input('Are you sure? [y/N] ')
|
@@ -18,7 +18,7 @@ import torch.multiprocessing as mp
|
|
18 |
|
19 |
|
20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
-
mp.set_start_method('spawn', force=True)
|
22 |
|
23 |
|
24 |
def _batch_iterator():
|
@@ -266,7 +266,7 @@ train_dataloader = DataLoader(
|
|
266 |
collate_fn=collate_fn,
|
267 |
batch_size=training_args.per_device_train_batch_size,
|
268 |
pin_memory=True,
|
269 |
-
num_workers=4
|
270 |
)
|
271 |
|
272 |
eval_dataloader = DataLoader(
|
@@ -274,7 +274,7 @@ eval_dataloader = DataLoader(
|
|
274 |
collate_fn=collate_fn,
|
275 |
batch_size=training_args.per_device_eval_batch_size,
|
276 |
pin_memory=True,
|
277 |
-
num_workers=4
|
278 |
)
|
279 |
|
280 |
trainer = Trainer(
|
|
|
8 |
|
9 |
import torch
|
10 |
from torch.utils.data import DataLoader
|
11 |
+
# import torch.multiprocessing as mp
|
12 |
|
13 |
|
14 |
# x = input('Are you sure? [y/N] ')
|
|
|
18 |
|
19 |
|
20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
+
# mp.set_start_method('spawn', force=True)
|
22 |
|
23 |
|
24 |
def _batch_iterator():
|
|
|
266 |
collate_fn=collate_fn,
|
267 |
batch_size=training_args.per_device_train_batch_size,
|
268 |
pin_memory=True,
|
269 |
+
# num_workers=4
|
270 |
)
|
271 |
|
272 |
eval_dataloader = DataLoader(
|
|
|
274 |
collate_fn=collate_fn,
|
275 |
batch_size=training_args.per_device_eval_batch_size,
|
276 |
pin_memory=True,
|
277 |
+
# num_workers=4
|
278 |
)
|
279 |
|
280 |
trainer = Trainer(
|