marianna13's picture
Update app.py
3406caf
raw
history blame contribute delete
No virus
5.62 kB
import os
import random
import subprocess
import tempfile
import time

import clip
import cv2
import gradio as gr
import numpy as np
import torch
import transformers
import yt_dlp
from multilingual_clip import pt_multilingual_clip
from PIL import Image
class SearchVideo:
    """Locate the second of a video whose frame best matches a text query.

    The video is sampled at one frame per second, every frame is embedded
    with the image encoder, the query is embedded with the text encoder, and
    the frame with the highest cosine similarity to the query is selected.
    """

    def __init__(
        self,
        clip_model,
        text_model,
        tokenizer,
        compose,
    ) -> None:
        """
        clip_model: loaded image model; only its ``encode_image`` method is used
        text_model: callable invoked as ``text_model(text, tokenizer)``
        tokenizer: tokenizer object handed to ``text_model`` unchanged
        compose: callable converting one PIL image into a tensor for clip_model
        """
        self.text_model = text_model
        self.tokenizer = tokenizer
        self.clip_model = clip_model
        self.compose = compose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, video: str, text: str) -> tuple:
        """Search ``video`` for ``text``.

        Returns a tuple ``(ind, seg_path, frame)``: the index of the best
        frame (frames are sampled once per second, so this is also the offset
        in seconds), the path of the clip cut at that offset, and the frame
        itself as a PIL image.

        Raises RuntimeError when no frame could be extracted from the video.
        """
        torch.cuda.empty_cache()
        frames = self.video2frames_ffmpeg(video)
        if not frames:
            # Fail with a readable message instead of an opaque torch.stack error.
            raise RuntimeError(f'No frames could be extracted from {video!r}')
        img_embs = self.get_img_embs(frames)
        txt_emb = self.get_txt_embs(text)
        # The embeddings are normalised inside compare_embeddings, so no
        # rescaling of the text embedding is needed here.
        logits_per_image = self.compare_embeddings(img_embs, txt_emb)
        scores = [logit.numpy()[0] for logit in logits_per_image]
        ind = int(np.argmax(scores))
        seg_path = self.extract_seg(video, ind)
        return ind, seg_path, frames[ind]

    @staticmethod
    def _seek_timestamp(seconds) -> str:
        """Format an offset in seconds as HH:MM:SS, clamping negatives to 0."""
        return time.strftime('%H:%M:%S', time.gmtime(max(int(seconds), 0)))

    def extract_seg(self, video: str, start: int):
        """Cut a 2-second clip out of ``video`` beginning at ``start`` seconds.

        The clip is written to ``segment_<HH:MM:SS>.mp4`` in the current
        working directory (overwriting an existing file) and that path is
        returned. The ffmpeg exit status is not checked.
        """
        stamp = self._seek_timestamp(start)
        out_path = f'segment_{stamp}.mp4'
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed to ffmpeg verbatim.
        subprocess.run(
            ['ffmpeg', '-ss', stamp, '-i', video, '-t', '00:00:02',
             '-vcodec', 'copy', '-acodec', 'copy', '-y', out_path],
            check=False,
        )
        return out_path

    def video2frames_ffmpeg(self, video: str) -> list:
        """Extract one frame per second from ``video`` using ffmpeg.

        Returns the frames as in-memory PIL images in chronological order;
        the list is empty when ffmpeg produced no output.
        """
        # A private temporary directory keeps concurrent calls from mixing
        # their frames and is removed automatically on exit.
        with tempfile.TemporaryDirectory(prefix='frames_') as frames_dir:
            subprocess.run(
                # Eight digits keep lexicographic order equal to numeric order
                # even for very long videos.
                ['ffmpeg', '-i', video, '-r', '1',
                 os.path.join(frames_dir, 'output-%08d.jpg')],
                check=False,
            )
            images = []
            for name in sorted(os.listdir(frames_dir)):
                # Image.open is lazy; copy() forces the pixel data into memory
                # so the file can be closed before the directory disappears.
                with Image.open(os.path.join(frames_dir, name)) as img:
                    images.append(img.copy())
        return images

    def video2frames(self, video: str) -> list:
        """Extract roughly one frame per second from ``video`` using OpenCV.

        Alternative to ``video2frames_ffmpeg``. The sampling step is the
        frame rate reported by the container (24 when it is unavailable),
        starting with the very first frame. Returns a list of PIL images.
        """
        cap = cv2.VideoCapture(video)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            step = max(int(round(fps)), 1) if fps and fps > 0 else 24
            images = []
            frame_idx = 0
            while True:
                has_frame, frame = cap.read()
                if not has_frame:
                    break
                if frame_idx % step == 0:
                    # OpenCV decodes to BGR; PIL expects RGB channel order.
                    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    images.append(Image.fromarray(rgb))
                frame_idx += 1
        finally:
            cap.release()
        return images

    def get_img_embs(self, img_list: list, batch_size: int = 64) -> torch.Tensor:
        """
        Embed a list of PIL images with ``clip_model.encode_image``.

        img_list: images to embed; must not be empty
        batch_size: number of images encoded per forward pass, bounding peak
            memory for long videos

        Returns one float tensor on the CPU with one row per image.
        """
        chunks = []
        with torch.no_grad():
            for lo in range(0, len(img_list), batch_size):
                batch = torch.stack([
                    self.compose(img).to(self.device)
                    for img in img_list[lo:lo + batch_size]
                ])
                chunks.append(self.clip_model.encode_image(batch).float().cpu())
        return torch.cat(chunks)

    def get_txt_embs(self, text: str) -> torch.Tensor:
        """Embed ``text`` by calling ``text_model(text, tokenizer)`` without gradients."""
        with torch.no_grad():
            return self.text_model(text, self.tokenizer)

    def compare_embeddings(self, img_embs, txt_embs):
        """Return the cosine similarity of every image embedding to the text.

        Both inputs are L2-normalised along the last dimension first. The
        result is a list with one tensor per row of ``img_embs`` holding that
        image's similarity to each row of ``txt_embs``.
        """
        image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
        text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
        # One matrix product replaces the per-image Python loop.
        similarity = image_features @ text_features.t()
        return list(similarity)
def download_yt_video(url):
    """Download the video at ``url`` with yt-dlp and return the saved file path.

    The file is named ``<video id>.<ext>`` via the output template, and the
    format selector caps the height at 360 where such a format exists. The
    returned path is the one yt-dlp reports, rather than a guess derived
    from the URL, so URLs with extra query parameters, short links, and
    downloads that end up in a non-mp4 container all resolve correctly.
    """
    ydl_opts = {
        'quiet': True,
        "outtmpl": "%(id)s.%(ext)s",
        'format': 'bv*[height<=360][ext=mp4]+ba/b[height<=360] / wv*+ba/w'
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)
        # Preferred: the final path recorded after download and merging.
        downloads = info.get('requested_downloads') or []
        if downloads and downloads[0].get('filepath'):
            return downloads[0]['filepath']
        # Fallback: rebuild the name from the output template.
        return ydl.prepare_filename(info)
# Checkpoint identifiers for the image encoder and the multilingual text encoder.
_CLIP_CHECKPOINT = 'ViT-B/32'
_TEXT_CHECKPOINT = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'

# Everything is loaded once at import time and shared by all requests.
# clip.load yields two objects: the model and the callable applied to each
# image before it is encoded.
clip_model, compose = clip.load(_CLIP_CHECKPOINT)
tokenizer = transformers.AutoTokenizer.from_pretrained(_TEXT_CHECKPOINT)
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(_TEXT_CHECKPOINT)
def search_video(video_url, text, video=None):
    """Gradio handler: find ``text`` inside a video and return the result.

    video_url: URL to download when no file was uploaded
    text: the search query
    video: path of an uploaded video; takes precedence over ``video_url``

    Returns ``(message, seg_path)``: a message naming the timestamp of the
    best match and the path of the clip cut at that position.

    Raises ValueError when neither an upload nor a URL is supplied.
    """
    # Validate first so a missing input produces a clear message instead of
    # an ffmpeg failure further down.
    if video is None and not video_url:
        raise ValueError('Please provide a video URL or upload a video.')

    search = SearchVideo(
        clip_model=clip_model,
        text_model=text_model,
        tokenizer=tokenizer,
        compose=compose
    )
    if video is None:
        video = download_yt_video(video_url)
    ind, seg_path, _ = search(video, text)
    start = time.strftime('%H:%M:%S', time.gmtime(ind))
    return f'"{text}" found at {start}', seg_path
# Static text shown on the page. The title's emoji were stored as mis-decoded
# UTF-8 bytes and are restored here; the description had a typo ("you query").
title = '🔎🎞️🚀 Search inside a video'
description = '''Just enter a search query, a video URL or upload your video and get a 2-sec fragment from the video which is visually closest to your query.'''
examples = [["https://www.youtube.com/watch?v=M93w3TjzVUE", "A dog"]]

# The three inputs map positionally onto search_video(video_url, text, video);
# the two outputs receive its (message, segment path) return value.
iface = gr.Interface(
    search_video,
    inputs=[gr.Textbox(value="https://www.youtube.com/watch?v=M93w3TjzVUE", label='Video URL'), gr.Textbox(value="a dog", label='Text query'), gr.Video()],
    outputs=[gr.Textbox(label="Output"), gr.Video(label="Video segment")],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples
)

if __name__ == "__main__":
    # show_error surfaces exceptions raised by the handler in the browser UI.
    iface.launch(show_error=True)