THUdyh committed on
Commit 649a916
1 Parent(s): 83fe9bb

Update app.py

Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -25,7 +25,8 @@ overwrite_config["mm_resampler_type"] = "dynamic_compressor"
 overwrite_config["patchify_video_feature"] = False
 overwrite_config["attn_implementation"] = "sdpa" if torch.__version__ >= "2.1.2" else "eager"
 tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, device_map="cpu", overwrite_config=overwrite_config)
-model.to("cuda").eval()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model.to(device).eval()
 
 def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
     roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
@@ -100,7 +101,7 @@ def oryx_inference(video, text):
     conv.append_message(conv.roles[1], None)
     prompt = conv.get_prompt()
 
-    input_ids = preprocess_qwen([{'from': 'human','value': question},{'from': 'gpt','value': None}], tokenizer, has_image=True).to("cuda")
+    input_ids = preprocess_qwen([{'from': 'human','value': question},{'from': 'gpt','value': None}], tokenizer, has_image=True).to(device)
 
     video_processed = []
     for idx, frame in enumerate(video):
@@ -116,7 +117,7 @@ def oryx_inference(video, text):
     if frame_idx is None:
         frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
 
-    video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda")
+    video_processed = torch.cat(video_processed, dim=0).bfloat16().to(device)
     video_processed = (video_processed, video_processed)
 
     video_data = (video_processed, (384, 384), "video")
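
For reference, a minimal self-contained sketch of the device-fallback pattern this commit applies: pick the device once at startup, then route every placement through it instead of a hard-coded "cuda". The Linear model and random input below are placeholders for illustration, not the actual Oryx loader or video tensors.

import torch

# Choose the accelerator once at startup; fall back to CPU when no GPU is visible.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholder model and input, standing in for the Oryx model and video batch:
# every hard-coded .to("cuda") call is routed through the shared `device` instead.
model = torch.nn.Linear(16, 4).to(device).eval()
inputs = torch.randn(2, 16).to(device)

with torch.no_grad():
    outputs = model(inputs)
print(outputs.device)  # cuda:0 when a GPU is available, otherwise cpu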