YaTharThShaRma999 commited on
Commit
d5b92ae
1 Parent(s): 40a9ab8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -22
app.py CHANGED
@@ -67,7 +67,30 @@ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
67
  [0.700, 0.300, 0.600],[0.000, 0.447, 0.741], [0.850, 0.325, 0.098]]
68
 
69
 
70
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  coco_class_name = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
72
  OBJ365_class_names = [cat['name'] for cat in OBJ365_CATEGORIESV2]
73
  class_agnostic_name = ['object']
@@ -79,27 +102,6 @@ else:
79
  print('use cpu')
80
  device='cpu'
81
 
82
- cfg_r50 = get_cfg()
83
- add_deeplab_config(cfg_r50)
84
- add_glee_config(cfg_r50)
85
- conf_files_r50 = 'GLEE/configs/R50.yaml'
86
- checkpoints_r50 = torch.load('GLEE_R50_Scaleup10m.pth')
87
- cfg_r50.merge_from_file(conf_files_r50)
88
- GLEEmodel_r50 = GLEE_Model(cfg_r50, None, device, None, True).to(device)
89
- GLEEmodel_r50.load_state_dict(checkpoints_r50, strict=False)
90
- GLEEmodel_r50.eval()
91
-
92
-
93
- cfg_swin = get_cfg()
94
- add_deeplab_config(cfg_swin)
95
- add_glee_config(cfg_swin)
96
- conf_files_swin = 'GLEE/configs/SwinL.yaml'
97
- checkpoints_swin = torch.load('GLEE_SwinL_Scaleup10m.pth')
98
- cfg_swin.merge_from_file(conf_files_swin)
99
- GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
100
- GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
101
- GLEEmodel_swin.eval()
102
-
103
  pixel_mean = torch.Tensor( [123.675, 116.28, 103.53]).to(device).view(3, 1, 1)
104
  pixel_std = torch.Tensor([58.395, 57.12, 57.375]).to(device).view(3, 1, 1)
105
  normalizer = lambda x: (x - pixel_mean) / pixel_std
 
67
  [0.700, 0.300, 0.600],[0.000, 0.447, 0.741], [0.850, 0.325, 0.098]]
68
 
69
 
70
+ def load_smallmodel():
71
+ cfg_r50 = get_cfg()
72
+ add_deeplab_config(cfg_r50)
73
+ add_glee_config(cfg_r50)
74
+ conf_files_r50 = 'GLEE/configs/R50.yaml'
75
+ checkpoints_r50 = torch.load('GLEE_R50_Scaleup10m.pth')
76
+ cfg_r50.merge_from_file(conf_files_r50)
77
+ GLEEmodel_r50 = GLEE_Model(cfg_r50, None, device, None, True).to(device)
78
+ GLEEmodel_r50.load_state_dict(checkpoints_r50, strict=False)
79
+ GLEEmodel_r50.eval()
80
+ return GLEEmodel_r50
81
+ def load_bigmodel():
82
+ cfg_swin = get_cfg()
83
+ add_deeplab_config(cfg_swin)
84
+ add_glee_config(cfg_swin)
85
+ conf_files_swin = 'GLEE/configs/SwinL.yaml'
86
+ checkpoints_swin = torch.load('GLEE_SwinL_Scaleup10m.pth')
87
+ cfg_swin.merge_from_file(conf_files_swin)
88
+ GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
89
+ GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
90
+ GLEEmodel_swin.eval()
91
+ return GLEEmodel_swin
92
+ GLEEmodel_swin = load_bigmodel()
93
+ GLEEmodel_r50 = load_smallmodel()
94
  coco_class_name = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
95
  OBJ365_class_names = [cat['name'] for cat in OBJ365_CATEGORIESV2]
96
  class_agnostic_name = ['object']
 
102
  print('use cpu')
103
  device='cpu'
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  pixel_mean = torch.Tensor( [123.675, 116.28, 103.53]).to(device).view(3, 1, 1)
106
  pixel_std = torch.Tensor([58.395, 57.12, 57.375]).to(device).view(3, 1, 1)
107
  normalizer = lambda x: (x - pixel_mean) / pixel_std