# File size: 2,050 Bytes
# 1622f6b
import os
from typing import List
import cv2
import numpy as np
import torch
from PIL import Image
import torch.nn.functional as F
from modeling import VideoCLIP_XL
from utils.text_encoder import text_encoder
def _frame_from_video(video):
    """Yield frames from ``video`` one at a time until a read fails.

    ``video`` is any object exposing ``isOpened()`` and ``read()``, where
    ``read()`` returns a ``(success, frame)`` pair. Iteration stops as soon
    as ``isOpened()`` is falsy or a read reports no success.
    """
    while video.isOpened():
        ok, image = video.read()
        if not ok:
            # A failed read ends the stream; nothing further is yielded.
            return
        yield image
# Per-channel mean and standard deviation used by ``normalize``. Both have
# shape (1, 1, 3) so they broadcast against the last axis of the input.
v_mean = np.array([[[0.485, 0.456, 0.406]]])
v_std = np.array([[[0.229, 0.224, 0.225]]])


def normalize(data):
    """Divide ``data`` by 255, subtract ``v_mean`` and divide by ``v_std``.

    ``data`` must broadcast against a (1, 1, 3) array, i.e. its last axis
    holds three channels. The result has the broadcast shape.
    """
    scaled = data / 255.0
    centered = scaled - v_mean
    return centered / v_std
def video_preprocessing(video_path, fnum=8):
    """Decode a video file into a normalised tensor of ``fnum`` frames.

    All frames are decoded, then sampled from the first frame onward at a
    fixed stride of ``len(frames) // fnum``. Each sampled frame has its
    channel axis reversed (BGR -> RGB for OpenCV-decoded frames), is resized
    to 224x224 and passed through ``normalize``.

    Clips with fewer than ``fnum`` frames are padded by repeating their last
    frame so the output always holds exactly ``fnum`` frames.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        fnum: Number of frames to sample; must be at least 1.

    Returns:
        A ``torch.Tensor`` laid out as ``(1, fnum, 3, 224, 224)`` for
        3-channel frames.

    Raises:
        ValueError: If ``fnum`` is less than 1, or if no frame could be
            decoded from ``video_path`` (missing file, unsupported codec,
            empty clip).
    """
    if fnum < 1:
        raise ValueError(f"fnum must be a positive integer, got {fnum!r}")

    video = cv2.VideoCapture(video_path)
    try:
        frames = list(_frame_from_video(video))
    finally:
        # Free the decoder handle even if reading fails part-way through.
        video.release()

    if not frames:
        raise ValueError(f"no frames could be read from {video_path!r}")

    # A clip shorter than fnum would give a stride of 0, which is not a valid
    # slice step; fall back to taking every frame.
    step = max(len(frames) // fnum, 1)
    frames = frames[::step][:fnum]
    # Pad short clips with their last frame so every clip has the same
    # number of frames and results can be concatenated into a batch.
    frames.extend([frames[-1]] * (fnum - len(frames)))

    vid_tube = []
    for fr in frames:
        fr = fr[:, :, ::-1]  # reverse the channel axis
        fr = cv2.resize(fr, (224, 224))
        # Add leading batch and time axes: (H, W, C) -> (1, 1, H, W, C).
        fr = np.expand_dims(normalize(fr), axis=(0, 1))
        vid_tube.append(fr)

    vid_tube = np.concatenate(vid_tube, axis=1)  # stack along the time axis
    # Move channels ahead of the spatial axes: (1, T, H, W, C) -> (1, T, C, H, W).
    vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))
    return torch.from_numpy(vid_tube)
# Demo: embed a few videos and texts with VideoCLIP-XL and print the scaled
# text-to-video similarity matrix.

# Use the GPU when one is available; otherwise run on the CPU instead of
# failing at the first ``.cuda()`` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

videoclip_xl = VideoCLIP_XL()
# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load checkpoint files from a trusted source.
state_dict = torch.load("./VideoCLIP-XL.bin", map_location="cpu")
videoclip_xl.load_state_dict(state_dict)
videoclip_xl.to(device).eval()

videos = [
    "/path/to/video-1.mp4",
    "/path/to/video-2.mp4",
]

texts = [
    "text-1",
    "text-2",
    "text-3",
]

with torch.no_grad():
    # One preprocessed clip per video, concatenated along the first axis.
    video_inputs = torch.cat(
        [video_preprocessing(video) for video in videos], 0
    ).float().to(device)
    video_features = videoclip_xl.vision_model.get_vid_features(video_inputs).float()
    # L2-normalise along the last axis so the dot product below is a cosine
    # similarity.
    video_features = video_features / video_features.norm(dim=-1, keepdim=True)

    text_inputs = text_encoder.tokenize(texts, truncate=True).to(device)
    text_features = videoclip_xl.text_model.encode_text(text_inputs).float()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Scale factor applied to the cosine similarities.
Tmp = 100.

# One row per text, one column per video.
sim_matrix = (text_features @ video_features.T) * Tmp

print(sim_matrix)