Spaces:
Runtime error
Runtime error
YaTharThShaRma999
commited on
Commit
•
d5b92ae
1
Parent(s):
40a9ab8
Update app.py
Browse files
app.py
CHANGED
@@ -67,7 +67,30 @@ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
|
|
67 |
[0.700, 0.300, 0.600],[0.000, 0.447, 0.741], [0.850, 0.325, 0.098]]
|
68 |
|
69 |
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
coco_class_name = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
|
72 |
OBJ365_class_names = [cat['name'] for cat in OBJ365_CATEGORIESV2]
|
73 |
class_agnostic_name = ['object']
|
@@ -79,27 +102,6 @@ else:
|
|
79 |
print('use cpu')
|
80 |
device='cpu'
|
81 |
|
82 |
-
cfg_r50 = get_cfg()
|
83 |
-
add_deeplab_config(cfg_r50)
|
84 |
-
add_glee_config(cfg_r50)
|
85 |
-
conf_files_r50 = 'GLEE/configs/R50.yaml'
|
86 |
-
checkpoints_r50 = torch.load('GLEE_R50_Scaleup10m.pth')
|
87 |
-
cfg_r50.merge_from_file(conf_files_r50)
|
88 |
-
GLEEmodel_r50 = GLEE_Model(cfg_r50, None, device, None, True).to(device)
|
89 |
-
GLEEmodel_r50.load_state_dict(checkpoints_r50, strict=False)
|
90 |
-
GLEEmodel_r50.eval()
|
91 |
-
|
92 |
-
|
93 |
-
cfg_swin = get_cfg()
|
94 |
-
add_deeplab_config(cfg_swin)
|
95 |
-
add_glee_config(cfg_swin)
|
96 |
-
conf_files_swin = 'GLEE/configs/SwinL.yaml'
|
97 |
-
checkpoints_swin = torch.load('GLEE_SwinL_Scaleup10m.pth')
|
98 |
-
cfg_swin.merge_from_file(conf_files_swin)
|
99 |
-
GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
|
100 |
-
GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
|
101 |
-
GLEEmodel_swin.eval()
|
102 |
-
|
103 |
pixel_mean = torch.Tensor( [123.675, 116.28, 103.53]).to(device).view(3, 1, 1)
|
104 |
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).to(device).view(3, 1, 1)
|
105 |
normalizer = lambda x: (x - pixel_mean) / pixel_std
|
|
|
67 |
[0.700, 0.300, 0.600],[0.000, 0.447, 0.741], [0.850, 0.325, 0.098]]
|
68 |
|
69 |
|
70 |
+
def load_smallmodel():
|
71 |
+
cfg_r50 = get_cfg()
|
72 |
+
add_deeplab_config(cfg_r50)
|
73 |
+
add_glee_config(cfg_r50)
|
74 |
+
conf_files_r50 = 'GLEE/configs/R50.yaml'
|
75 |
+
checkpoints_r50 = torch.load('GLEE_R50_Scaleup10m.pth')
|
76 |
+
cfg_r50.merge_from_file(conf_files_r50)
|
77 |
+
GLEEmodel_r50 = GLEE_Model(cfg_r50, None, device, None, True).to(device)
|
78 |
+
GLEEmodel_r50.load_state_dict(checkpoints_r50, strict=False)
|
79 |
+
GLEEmodel_r50.eval()
|
80 |
+
return GLEEmodel_r50
|
81 |
+
def load_bigmodel():
|
82 |
+
cfg_swin = get_cfg()
|
83 |
+
add_deeplab_config(cfg_swin)
|
84 |
+
add_glee_config(cfg_swin)
|
85 |
+
conf_files_swin = 'GLEE/configs/SwinL.yaml'
|
86 |
+
checkpoints_swin = torch.load('GLEE_SwinL_Scaleup10m.pth')
|
87 |
+
cfg_swin.merge_from_file(conf_files_swin)
|
88 |
+
GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
|
89 |
+
GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
|
90 |
+
GLEEmodel_swin.eval()
|
91 |
+
return GLEEmodel_swin
|
92 |
+
GLEEmodel_swin = load_bigmodel()
|
93 |
+
GLEEmodel_r50 = load_smallmodel()
|
94 |
coco_class_name = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
|
95 |
OBJ365_class_names = [cat['name'] for cat in OBJ365_CATEGORIESV2]
|
96 |
class_agnostic_name = ['object']
|
|
|
102 |
print('use cpu')
|
103 |
device='cpu'
|
104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
pixel_mean = torch.Tensor( [123.675, 116.28, 103.53]).to(device).view(3, 1, 1)
|
106 |
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).to(device).view(3, 1, 1)
|
107 |
normalizer = lambda x: (x - pixel_mean) / pixel_std
|