Update tasks/text.py
Browse files- tasks/text.py +1 -1
tasks/text.py
CHANGED
@@ -66,7 +66,7 @@ def bert_classifier(test_dataset: dict, model: str):
|
|
66 |
raise(ValueError)
|
67 |
|
68 |
# Use CUDA if available
|
69 |
-
device
|
70 |
|
71 |
model = model.to(device)
|
72 |
|
|
|
66 |
raise(ValueError)
|
67 |
|
68 |
# Use CUDA if available
|
69 |
+
device = "cuda"
|
70 |
|
71 |
model = model.to(device)
|
72 |
|