ikuinen99 commited on
Commit
2cc0bbc
·
1 Parent(s): 64ee624
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -67,6 +67,8 @@ args = parse_args()
67
 
68
  assert args.dummy or (args.cfg_path is not None), "Invalid Config! Set --dummy or configurate the cfg_path!"
69
 
 
 
70
  if not args.dummy:
71
  cfg = Config(args)
72
 
@@ -79,17 +81,17 @@ if not args.dummy:
79
 
80
  # Create model
81
  model_config = cfg.model_cfg
82
- model_config.device_8bit = args.gpu_id
83
  model_cls = registry.get_model_class(model_config.arch)
84
- model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
85
- chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
86
  else:
87
  model = None
88
  chat = DummyChat()
89
 
90
  match = MatchModule(model='gpt-4')
91
- tagging_module = TaggingModule(device='cuda:{}'.format(args.gpu_id))
92
- grounding_dino = GroundingModule(device='cuda:{}'.format(args.gpu_id))
93
  print('Initialization Finished')
94
 
95
 
 
67
 
68
  assert args.dummy or (args.cfg_path is not None), "Invalid Config! Set --dummy or configurate the cfg_path!"
69
 
70
+ device = 'cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu'
71
+
72
  if not args.dummy:
73
  cfg = Config(args)
74
 
 
81
 
82
  # Create model
83
  model_config = cfg.model_cfg
84
+ model_config.device_8bit = device
85
  model_cls = registry.get_model_class(model_config.arch)
86
+ model = model_cls.from_config(model_config).to(device)
87
+ chat = Chat(model, processors, device=device)
88
  else:
89
  model = None
90
  chat = DummyChat()
91
 
92
  match = MatchModule(model='gpt-4')
93
+ tagging_module = TaggingModule(device=device)
94
+ grounding_dino = GroundingModule(device=device)
95
  print('Initialization Finished')
96
 
97