import os
import random
import subprocess
import tempfile
import time
from pathlib import Path

import gradio as gr
import yt_dlp
import torch
import transformers
import clip
import numpy as np
import cv2
from PIL import Image
from multilingual_clip import pt_multilingual_clip


class SearchVideo:
    """Find the moment of a video whose frame best matches a text query.

    Frames are sampled at one per second (via ffmpeg), embedded with the CLIP
    image encoder, and scored by cosine similarity against the multilingual
    CLIP embedding of the query text.
    """

    def __init__(
        self,
        clip_model,
        text_model,
        tokenizer,
        compose,
    ) -> None:
        """
        clip_model: loaded CLIP model; its ``encode_image`` produces image embeddings
        text_model: text encoder, called as ``text_model(text, tokenizer)``
        tokenizer: tokenizer handed through to ``text_model``
        compose: preprocessing transform applied to each PIL frame before encoding
        """
        self.text_model = text_model
        self.tokenizer = tokenizer
        self.clip_model = clip_model
        self.compose = compose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, video: str, text: str) -> tuple:
        """Search ``video`` for ``text``.

        Returns ``(ind, seg_path, frame)``: the index of the best-matching
        sampled frame (frames are taken at 1 fps, so this is roughly the offset
        in seconds), the path of a short clip cut at that offset, and the
        matching PIL frame.

        Raises ValueError if no frames could be extracted from the video.
        """
        torch.cuda.empty_cache()
        frames = self.video2frames_ffmpeg(video)
        if not frames:
            raise ValueError(f"No frames could be extracted from {video!r}")
        img_embs = self.get_img_embs(frames)
        txt_emb = self.get_txt_embs(text)
        # compare_embeddings normalizes both sides, so the single text embedding
        # is scored against every frame without any replication or scaling.
        logits_per_image = self.compare_embeddings(img_embs, txt_emb)
        scores = [logit.numpy()[0] for logit in logits_per_image]
        ind = int(np.argmax(scores))
        seg_path = self.extract_seg(video, ind)
        return ind, seg_path, frames[ind]

    def extract_seg(self, video: str, start: int, duration: int = 2) -> str:
        """Cut a ``duration``-second clip of ``video`` beginning at second ``start``.

        The clip is stream-copied (no re-encode) into ``segment_HH:MM:SS.mp4`` in
        the current working directory, and that path is returned. Negative
        offsets are clamped to the start of the video.
        """
        # The original clamp was inverted and produced negative offsets (which
        # gmtime renders as 23:59:5x) for matches in the first few seconds.
        start = max(int(start), 0)
        timestamp = time.strftime('%H:%M:%S', time.gmtime(start))
        length = time.strftime('%H:%M:%S', time.gmtime(duration))
        out_path = f'segment_{timestamp}.mp4'
        # Argument list instead of a shell string: file names with spaces or
        # shell metacharacters are passed through verbatim.
        subprocess.run(
            ['ffmpeg', '-ss', timestamp, '-i', video, '-t', length,
             '-vcodec', 'copy', '-acodec', 'copy', '-y', out_path],
            check=True,
            stdin=subprocess.DEVNULL,
        )
        return out_path

    def video2frames_ffmpeg(self, video: str) -> list:
        """Extract one frame per second of ``video`` as RGB PIL images, in order."""
        images = []
        # A private temporary directory avoids picking up stale frames from an
        # earlier run and avoids clashes between concurrent requests.
        with tempfile.TemporaryDirectory(prefix='frames_') as frames_dir:
            subprocess.run(
                ['ffmpeg', '-i', video, '-r', '1',
                 os.path.join(frames_dir, 'output-%04d.jpg')],
                check=True,
                stdin=subprocess.DEVNULL,
            )
            # Sort numerically: a plain lexical sort misorders output-10000.jpg.
            paths = sorted(
                Path(frames_dir).glob('output-*.jpg'),
                key=lambda p: int(p.stem.rsplit('-', 1)[1]),
            )
            for path in paths:
                # Decode eagerly: Image.open is lazy and keeps the file handle
                # open until load(), which would leak one fd per frame and read
                # from a directory that is deleted when this block exits.
                with Image.open(path) as img:
                    images.append(img.convert('RGB'))
        return images

    def video2frames(self, video: str) -> list:
        """OpenCV alternative to ``video2frames_ffmpeg``: about one RGB frame per second.

        Not used by ``__call__``.
        """
        cap = cv2.VideoCapture(video)
        images = []
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            # Fall back to the previously hard-coded 24 fps when the container
            # does not report a usable frame rate.
            step = max(int(round(fps)), 1) if fps and fps > 0 else 24
            frame_idx = 0
            while True:
                has_frame, frame = cap.read()
                if not has_frame:
                    break
                if frame_idx % step == 0:
                    # OpenCV decodes to BGR; PIL/CLIP expect RGB.
                    images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                frame_idx += 1
        finally:
            cap.release()
        return images

    def get_img_embs(self, img_list: list, batch_size: int = 64) -> torch.Tensor:
        """
        Calculate CLIP image embeddings with ``self.clip_model``.

        Images are preprocessed with ``self.compose`` and encoded in chunks of
        ``batch_size`` so long videos do not have to fit on the device at once.
        Returns a float32 CPU tensor with one embedding per input image.
        """
        chunks = []
        with torch.no_grad():
            for i in range(0, len(img_list), batch_size):
                batch = img_list[i:i + batch_size]
                img_input = torch.stack([self.compose(img).to(self.device) for img in batch])
                chunks.append(self.clip_model.encode_image(img_input).float().cpu())
        return torch.cat(chunks)

    def get_txt_embs(self, text: str) -> torch.Tensor:
        """Calculate the multilingual CLIP embedding of ``text``."""
        with torch.no_grad():
            return self.text_model(text, self.tokenizer)

    def compare_embeddings(self, img_embs, txt_embs):
        """Return a list with one cosine-similarity tensor per image embedding.

        Each element is ``image_feature @ text_features.T`` after L2
        normalization of both sides.
        """
        image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
        text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
        # One matrix product instead of a Python loop; iterating the result
        # yields the same per-image rows the loop produced.
        return list(image_features @ text_features.t())


def download_yt_video(url):
    """Download ``url`` with yt-dlp (at most 360p) into the cwd; return the local file path."""
    ydl_opts = {
        'quiet': True,
        "outtmpl": "%(id)s.%(ext)s",
        'format': 'bv*[height<=360][ext=mp4]+ba/b[height<=360] / wv*+ba/w',
        # Without this, merging mp4 video with non-m4a audio produces .mkv.
        'merge_output_format': 'mp4',
        # watch?v=...&list=... URLs should fetch just the one video.
        'noplaylist': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)
        # Ask yt-dlp for the real output path instead of guessing it from the
        # URL, which broke on extra query parameters such as &t= or &list=.
        downloads = info.get('requested_downloads') or []
        if downloads and downloads[0].get('filepath'):
            return downloads[0]['filepath']
        return ydl.prepare_filename(info)


CLIP_MODEL_NAME = 'ViT-B/32'
TEXT_MODEL_NAME = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
clip_model, compose = clip.load(CLIP_MODEL_NAME)
tokenizer = transformers.AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(TEXT_MODEL_NAME)


def search_video(video_url, text, video=None):
    """Gradio handler: find ``text`` in an uploaded ``video`` or the video at ``video_url``.

    An uploaded video takes precedence over the URL. Returns a message with
    the matching timestamp and the path of the extracted clip.
    """
    if not video:
        if not video_url:
            raise gr.Error('Please enter a video URL or upload a video.')
        video = download_yt_video(video_url)
    search = SearchVideo(
        clip_model=clip_model,
        text_model=text_model,
        tokenizer=tokenizer,
        compose=compose,
    )
    ind, seg_path, _ = search(video, text)
    start = time.strftime('%H:%M:%S', time.gmtime(ind))
    return f'"{text}" found at {start}', seg_path


title = '🔎🎞️🚀 Search inside a video'
description = '''Just enter a search query, a video URL or upload your video and get a 2-sec fragment from the video which is visually closest to your query.'''
examples = [["https://www.youtube.com/watch?v=M93w3TjzVUE", "A dog"]]

iface = gr.Interface(
    search_video,
    inputs=[
        gr.Textbox(value="https://www.youtube.com/watch?v=M93w3TjzVUE", label='Video URL'),
        gr.Textbox(value="a dog", label='Text query'),
        gr.Video(),
    ],
    outputs=[
        gr.Textbox(label="Output"),
        gr.Video(label="Video segment"),
    ],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    iface.launch(show_error=True)