michaelryoo committed on
Commit 03eedfb · verified · 1 Parent(s): a1b0c6c

Upload xgen-mm-vid-inference-script.py

Files changed (1)
  1. xgen-mm-vid-inference-script.py +141 -0
xgen-mm-vid-inference-script.py ADDED
@@ -0,0 +1,141 @@
+ # %%
+ from modeling_xgenmm import *
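+ # NOTE: this assumes modeling_xgenmm.py from this repository is importable,
+ # e.g. the script is run from the repository root.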
+
+
+ # %%
+ cfg = XGenMMConfig()
+ model = XGenMMModelForConditionalGeneration(cfg)
+ model = model.cuda()
+ model = model.half()
+
+
+ # %%
+ from transformers import AutoTokenizer, AutoImageProcessor
+
+ xgenmm_path = "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"
+ tokenizer = AutoTokenizer.from_pretrained(
+     xgenmm_path, trust_remote_code=True, use_fast=False, legacy=False
+ )
+ image_processor = AutoImageProcessor.from_pretrained(
+     xgenmm_path, trust_remote_code=True
+ )
+ tokenizer = model.update_special_tokens(tokenizer)
+ # model = model.to("cuda")
+ model.eval()
+ tokenizer.padding_side = "left"
+ tokenizer.eos_token = "<|end|>"
+
+
+ # %%
+ import math
+
+ import numpy as np
+ import torch
+ import torchvision
+ import torchvision.io
+
+
+ def sample_frames(vframes, num_frames):
+     # uniformly sample `num_frames` frames and convert them to PIL images
+     frame_indices = np.linspace(0, len(vframes) - 1, num_frames, dtype=int)
+     video = vframes[frame_indices]
+     video_list = []
+     for i in range(len(video)):
+         video_list.append(torchvision.transforms.functional.to_pil_image(video[i]))
+     return video_list
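+ # e.g., for a 300-frame clip and num_frames=8, the sampled indices would be
+ # [0, 42, 85, 128, 170, 213, 256, 299] (illustrative values).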
+
+
+ def generate(messages, images):
+     # img_bytes_list = [base64.b64decode(image.encode("utf-8")) for image in images]
+     # images = [Image.open(BytesIO(img_bytes)) for img_bytes in img_bytes_list]
+     image_sizes = [image.size for image in images]
+     # Similar operation in model_worker.py
+
+     if cfg.vision_encoder_config.image_aspect_ratio == "anyres":
+         image_list = [
+             image_processor([img], image_aspect_ratio="anyres")["pixel_values"].to(
+                 model.device, dtype=torch.float16
+             )
+             for img in images
+         ]
+
+         inputs = {"pixel_values": [image_list]}
+     else:
+         image_tensor = [
+             image_processor([img])["pixel_values"].to(model.device, dtype=torch.float16)
+             for img in images
+         ]
+
+         for i in range(0, 8):
+             image_tensor[i] = torch.zeros(
+                 [1, 1, 1, 3, 384, 384], device=model.device, dtype=torch.float16
+             )
+         image_tensor = torch.stack(image_tensor, dim=1)
+         image_tensor = image_tensor.squeeze(2)
+         inputs = {"pixel_values": image_tensor}
+
+     full_conv = "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
+     for msg in messages:
+         msg_str = "<|{role}|>\n{content}<|end|>\n".format(
+             role=msg["role"], content=msg["content"]
+         )
+         full_conv += msg_str
+
+     full_conv += "<|assistant|>\n"
+     print(full_conv)
+     language_inputs = tokenizer([full_conv], return_tensors="pt")
+     for name, value in language_inputs.items():
+         language_inputs[name] = value.to(model.device)
+     inputs.update(language_inputs)
+     # print(inputs)
+
+     with torch.inference_mode():
+         generated_text = model.generate(
+             **inputs,
+             image_size=[image_sizes],
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+             temperature=0.05,
+             do_sample=False,
+             max_new_tokens=1024,
+             top_p=None,
+             num_beams=1,
+         )
+
+     outputs = (
+         tokenizer.decode(generated_text[0], skip_special_tokens=True)
+         .split("<|end|>")[0]
+         .strip()
+     )
+     return outputs
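+ # generate() expects chat-style `messages` ({"role": ..., "content": ...}) and a list of
+ # PIL images; the prompt is wrapped in Phi-3-style <|system|>/<|user|>/<|assistant|>
+ # markup and decoding stops at the <|end|> token.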
+
+
+ def predict(video_file, num_frames=8):
+     vframes, _, _ = torchvision.io.read_video(
+         filename=video_file, pts_unit="sec", output_format="TCHW"
+     )
+     total_frames = len(vframes)
+     images = sample_frames(vframes, num_frames)
+
+     prompt = ""
+     prompt = prompt + "<image>\n"
+     prompt = prompt + "Describe this video."
+     messages = [{"role": "user", "content": prompt}]
+     return generate(messages, images)
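+ # To ask a different question about the clip, the `prompt` above can be edited, e.g.
+ # "<image>\nHow many people appear in this video?" (illustrative prompt).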
+
+
+ # %%
+ import torch
+
+ your_checkpoint_path = ""
+ sd = torch.load(your_checkpoint_path)
+
+ sd = sd["model_state_dict"]
+ for k, v in list(sd.items()):
+     sd["vlm." + k] = v
+     del sd[k]
+
+ model.load_state_dict(sd)
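+ # The checkpoint keeps its weights under "model_state_dict" without the "vlm." prefix
+ # used by this wrapper, so the keys are renamed before loading. If the checkpoint was
+ # saved on a different device, torch.load(your_checkpoint_path, map_location="cpu")
+ # is a reasonable alternative.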
+
+ # %%
+ your_video_path = ""
+ print(
+     predict(
+         your_video_path,
+         num_frames=16,
+     )
+ )
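+ # Note: `your_checkpoint_path` and `your_video_path` above are placeholders and must be
+ # set before running; num_frames=16 here overrides predict()'s default of 8 frames.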