michaelryoo committed · Commit 1df9f81 · verified · 1 Parent(s): 2aed0a6

Rename xgen-mm-vid-inference-script.py to xgen-mm-vid-inference-script_hf.py

xgen-mm-vid-inference-script.py → xgen-mm-vid-inference-script_hf.py RENAMED
@@ -1,26 +1,13 @@
-# %%
-from modeling_xgenmm import *
-
-
-# %%
-cfg = XGenMMConfig()
-model = XGenMMModelForConditionalGeneration(cfg)
-model = model.cuda()
-model = model.half()
-
-
-# %%
-from transformers import AutoTokenizer, AutoImageProcessor
+from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, LogitsProcessor
+import torch
 
-xgenmm_path = "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"
-tokenizer = AutoTokenizer.from_pretrained(
-    xgenmm_path, trust_remote_code=True, use_fast=False, legacy=False
-)
-image_processor = AutoImageProcessor.from_pretrained(
-    xgenmm_path, trust_remote_code=True
-)
+model_name_or_path = "Salesforce/xgen-mm-vid-phi3-mini-r-v1.5-128tokens-16frames"
+model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False, legacy=False)
+image_processor = AutoImageProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
 tokenizer = model.update_special_tokens(tokenizer)
-# model = model.to("cuda")
+
+model = model.to('cuda')
 model.eval()
 tokenizer.padding_side = "left"
 tokenizer.eos_token = "<|end|>"
@@ -34,9 +21,8 @@ import torchvision.io
 
 import math
 
-
 def sample_frames(vframes, num_frames):
-    frame_indice = np.linspace(0, len(vframes) - 1, num_frames, dtype=int)
+    frame_indice = np.linspace(int(num_frames/2), len(vframes) - int(num_frames/2), num_frames, dtype=int)
     video = vframes[frame_indice]
     video_list = []
     for i in range(len(video)):
@@ -49,8 +35,7 @@ def generate(messages, images):
     # images = [Image.open(BytesIO(img_bytes)) for img_bytes in img_bytes_list]
     image_sizes = [image.size for image in images]
     # Similar operation in model_worker.py
-
-    image_tensor = [image_processor([img])["pixel_values"].to(model.device, dtype=torch.float16) for img in images]
+    image_tensor = [image_processor([img])["pixel_values"].to(model.device, dtype=torch.float32) for img in images]
 
     image_tensor = torch.stack(image_tensor, dim=1)
     image_tensor = image_tensor.squeeze(2)
@@ -101,23 +86,18 @@ def predict(video_file, num_frames=8):
 
     prompt = ""
     prompt = prompt + "<image>\n"
-    prompt = prompt + "Describe this video."
+    # prompt = prompt + "What's the main gist of the video ?"
+    prompt = prompt + "Please describe the primary object or subject in the video, capturing their attributes, actions, positions, and movements."
     messages = [{"role": "user", "content": prompt}]
     return generate(messages, images)
 
-
 # %%
-import torch
-
-your_checkpoint_path = ""
-sd = torch.load(your_checkpoint_path)
-model.load_state_dict(sd)
-
-# %%
-your_video_path = ""
+video_path = ""
 print(
     predict(
-        your_video_path,
-        num_frames = 16
+        video_path,
+        num_frames = 8
     )
 )
+
+# %%
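
The new sample_frames keeps the uniform spacing of the old version but moves both endpoints inward by num_frames/2, so sampling skips the very start and end of the clip (the old code sampled from frame 0 through the last frame). The diff hunk cuts the function off mid-loop; below is a minimal sketch of the complete helper, assuming vframes is the (T, H, W, C) uint8 tensor returned by torchvision.io.read_video and that each sampled frame is converted to a PIL image for the image processor (the to_pil_image conversion is an assumption, not shown in the commit):

import numpy as np
import torchvision.io
from torchvision.transforms.functional import to_pil_image

def sample_frames(vframes, num_frames):
    # num_frames indices, evenly spaced, shifted num_frames/2 inward from
    # both ends of the clip (the behavior introduced by this commit)
    frame_indice = np.linspace(
        int(num_frames / 2), len(vframes) - int(num_frames / 2), num_frames, dtype=int
    )
    video = vframes[frame_indice]
    video_list = []
    for i in range(len(video)):
        # Assumed: convert each (H, W, C) uint8 frame to a PIL image;
        # the rest of the loop body lies outside the diff hunk.
        video_list.append(to_pil_image(video[i].permute(2, 0, 1)))
    return video_list

# Example usage:
# vframes, _, _ = torchvision.io.read_video(video_path, pts_unit="sec")
# images = sample_frames(vframes, num_frames=8)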
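One behavioral change worth noting: the pre-rename script ran in half precision (model.half() with float16 pixel values), while the new HF loading path keeps the default float32 weights, moves the model to CUDA with model.to('cuda'), and casts pixel values to torch.float32. If half precision is still wanted, a hypothetical variant (not part of this commit; whether the model's remote code supports fp16 end to end is an assumption) would be:

import torch
from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained(
    "Salesforce/xgen-mm-vid-phi3-mini-r-v1.5-128tokens-16frames",
    trust_remote_code=True,
    torch_dtype=torch.float16,  # load the weights directly in fp16
).to("cuda")
# The pixel-value cast in generate() would then need dtype=torch.float16,
# mirroring the pre-rename script.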