File size: 3,507 Bytes
1df9f81
 
03eedfb
1df9f81
 
 
 
03eedfb
1df9f81
 
03eedfb
 
 
 
 
 
 
 
 
 
 
 
 
 
1df9f81
03eedfb
 
 
 
 
 
 
 
 
 
 
 
1df9f81
0fdfc65
 
 
 
03eedfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1df9f81
 
03eedfb
 
 
 
1df9f81
03eedfb
 
1df9f81
 
03eedfb
 
1df9f81
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, LogitsProcessor
import torch

# Hugging Face Hub checkpoint to load. trust_remote_code=True below executes the
# modeling/processing code shipped inside that repository, so only point this at
# a repository you trust.
model_name_or_path = "Salesforce/xgen-mm-vid-phi3-mini-r-v1.5-128tokens-16frames"
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, trust_remote_code=True)
# Slow (use_fast=False) tokenizer with legacy=False.
# NOTE(review): presumably the settings this checkpoint expects — confirm against its model card.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False, legacy=False)
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
# update_special_tokens is provided by the checkpoint's remote code and returns the
# tokenizer used from here on (presumably with the model's extra special tokens,
# e.g. the "<image>" placeholder, registered — TODO confirm). It runs before the
# model is moved to the GPU; keep that order.
tokenizer = model.update_special_tokens(tokenizer)

# Requires a CUDA device. eval() puts modules such as dropout into inference mode.
model = model.to('cuda')
model.eval()
# Left padding keeps the end of every prompt adjacent to the generated tokens.
tokenizer.padding_side = "left"
# generate() passes tokenizer.eos_token_id, so this makes decoding stop at the
# "<|end|>" turn terminator.
tokenizer.eos_token = "<|end|>"


# %%
import numpy as np
import torchvision

import torchvision.io

import math

def _sample_indices(total, num_frames):
    """Return ``num_frames`` evenly spaced frame indices for a clip of ``total`` frames.

    The indices normally span ``num_frames // 2`` .. ``total - num_frames // 2``
    (inclusive), skipping a margin at the start of the clip. When the clip is
    too short for that window (``total < 2 * (num_frames // 2)``) the window
    would run backwards, producing reversed, negative (wrapping) or
    out-of-range indices. In that case the indices are spread over the whole
    clip, ``0 .. total - 1``, instead.

    Raises:
        ValueError: if ``total`` is not positive.
    """
    if total <= 0:
        raise ValueError("cannot sample frames from an empty video")
    margin = num_frames // 2
    start, stop = margin, total - margin
    if start > stop:
        # Short clip: the margined window is inverted, so use the full clip.
        start, stop = 0, total - 1
    # dtype=int truncates the evenly spaced positions toward zero.
    return np.linspace(start, stop, num_frames, dtype=int)


def sample_frames(vframes, num_frames):
    """Pick ``num_frames`` frames from ``vframes`` and convert them to PIL images.

    Args:
        vframes: frames stacked along dim 0 (``predict`` passes the TCHW tensor
            returned by ``torchvision.io.read_video``).
        num_frames: number of frames to sample. Frames may repeat when the clip
            has fewer frames than requested.

    Returns:
        A list of ``num_frames`` PIL images in temporal order.

    Raises:
        ValueError: if ``vframes`` is empty.
    """
    indices = _sample_indices(len(vframes), num_frames)
    video = vframes[indices]
    return [torchvision.transforms.functional.to_pil_image(frame) for frame in video]


def generate(messages, images):
    """Run the global model on a chat transcript plus images and return its reply.

    Uses the module-level ``model``, ``tokenizer`` and ``image_processor``.

    Args:
        messages: sequence of dicts with ``"role"`` and ``"content"`` keys.
            After a fixed system turn, each one is rendered in order as
            ``<|role|>``, a newline, the content, ``<|end|>`` and a newline.
        images: non-empty sequence of images (PIL images from ``sample_frames``
            in this script). Each one's ``.size`` is forwarded to
            ``model.generate`` as ``image_size``.

    Returns:
        The decoded reply with special tokens skipped, truncated at the first
        ``"<|end|>"`` and stripped of surrounding whitespace.

    Raises:
        ValueError: if ``images`` is empty (``torch.stack`` would otherwise
            fail with an opaque error).
    """
    if not images:
        raise ValueError("generate() requires at least one image")

    image_sizes = [image.size for image in images]
    # Run the processor one image at a time, then stack the per-image tensors
    # along a new dim 1 and squeeze out dim 2.
    # NOTE(review): presumably this yields (1, num_images, C, H, W) — confirm
    # against the checkpoint's image processor.
    image_tensor = [
        image_processor([img])["pixel_values"].to(model.device, dtype=torch.float32)
        for img in images
    ]
    image_tensor = torch.stack(image_tensor, dim=1)
    image_tensor = image_tensor.squeeze(2)
    inputs = {"pixel_values": image_tensor}

    full_conv = "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
    for msg in messages:
        full_conv += "<|{role}|>\n{content}<|end|>\n".format(
            role=msg["role"], content=msg["content"]
        )
    # Open the assistant turn so the model continues from here.
    full_conv += "<|assistant|>\n"
    print(full_conv)  # echo the rendered prompt for inspection

    language_inputs = tokenizer([full_conv], return_tensors="pt")
    for name, value in language_inputs.items():
        language_inputs[name] = value.to(model.device)
    inputs.update(language_inputs)

    with torch.inference_mode():
        generated_text = model.generate(
            **inputs,
            # Wrapped in a list: one entry per prompt in the batch (here, one).
            image_size=[image_sizes],
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # Greedy decoding. Sampling-only knobs such as temperature are
            # deliberately not passed: with do_sample=False they are ignored
            # and only make transformers emit a warning.
            do_sample=False,
            max_new_tokens=1024,
            top_p=None,
            num_beams=1,
        )

    outputs = (
        tokenizer.decode(generated_text[0], skip_special_tokens=True)
        .split("<|end|>")[0]
        .strip()
    )
    return outputs


def predict(video_file, num_frames=8, *, question=None):
    """Describe a video file with the model.

    Args:
        video_file: path to the video, read with ``torchvision.io.read_video``
            (timestamps in seconds, frames returned as a TCHW tensor). The audio
            and metadata it returns are discarded.
        num_frames: number of frames passed to the model (see ``sample_frames``).
        question: text placed after the ``<image>`` placeholder in the user
            turn. Defaults to a request to describe the video's primary subject.

    Returns:
        The model's reply as returned by ``generate``.
    """
    if question is None:
        question = (
            "Please describe the primary object or subject in the video, "
            "capturing their attributes, actions, positions, and movements."
        )

    vframes, _, _ = torchvision.io.read_video(
        filename=video_file, pts_unit="sec", output_format="TCHW"
    )
    images = sample_frames(vframes, num_frames)

    messages = [{"role": "user", "content": "<image>\n" + question}]
    return generate(messages, images)

# %%
# Placeholder: set this to the path of a video file before running the cell.
video_path = ""
# Caption the video from 8 sampled frames and show the model's answer.
print(predict(video_file=video_path, num_frames=8))

# %%