Merge branch 'main' of https://huggingface.co./spaces/NKU-AMT/AMT into main
Browse files- demo_img.py +1 -1
demo_img.py
CHANGED
@@ -9,7 +9,7 @@ from networks.amtl import Model as AMTL
|
|
9 |
from networks.amtg import Model as AMTG
|
10 |
from utils import img2tensor, tensor2img, InputPadder
|
11 |
|
12 |
-
device = torch.device('
|
13 |
model_dict = {
|
14 |
'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
|
15 |
}
|
|
|
9 |
from networks.amtg import Model as AMTG
|
10 |
from utils import img2tensor, tensor2img, InputPadder
|
11 |
|
12 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
model_dict = {
|
14 |
'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
|
15 |
}
|