marianna13 committed
Commit 0edd243 • 1 Parent(s): 7968d81

Create app.py

Files changed (1)
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
import gradio as gr
import yt_dlp
import os
import time
import torch
from multilingual_clip import pt_multilingual_clip
import transformers
import clip
import numpy as np
import cv2
import random
from PIL import Image

# Fetch the pre-trained M-CLIP weights.
os.system('cd /Multilingual-CLIP && bash get-weights.sh')


class SearchVideo:

    def __init__(
        self,
        clip_model,
        text_model,
        tokenizer,
        compose,
    ) -> None:
        """
        clip_model: loaded CLIP model used for image embeddings
        text_model: loaded multilingual text encoder
        tokenizer: tokenizer matching text_model
        compose: CLIP image preprocessing transform
        """
        self.text_model = text_model
        self.tokenizer = tokenizer
        self.clip_model = clip_model
        self.compose = compose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, video: str, text: str) -> tuple:
        torch.cuda.empty_cache()
        # Sample roughly one frame per second and embed the frames and the query with CLIP.
        frames = self.video2frames_ffmpeg(video)
        img_embs = self.get_img_embs(frames)
        txt_emb = self.get_txt_embs(text)

        # Pick the frame whose embedding is most similar to the text embedding;
        # because frames are sampled at 1 fps, the frame index is also a timestamp in seconds.
        logits_per_image = self.compare_embeddings(img_embs, txt_emb)
        logits_per_image = [logit.numpy()[0] for logit in logits_per_image]
        ind = np.argmax(logits_per_image)
        seg_path = self.extract_seg(video, ind)
        return ind, seg_path, frames[ind]

    def extract_seg(self, video: str, start: int):
        """Cuts a 5-second segment at the matched timestamp with ffmpeg and returns its path."""
        start = start if start > 5 else 0  # avoid a negative start time for very early matches
        start = time.strftime('%H:%M:%S', time.gmtime(start))
        cmd = f'ffmpeg -ss {start} -i "{video}" -t 00:00:05 -vcodec copy -acodec copy -y segment_{start}.mp4'
        os.system(cmd)
        return f'segment_{start}.mp4'

    def video2frames_ffmpeg(self, video: str) -> list:
        """Extracts one frame per second with ffmpeg and returns the frames as PIL images."""
        frames_dir = 'frames'
        if not os.path.exists(frames_dir):
            os.makedirs(frames_dir)

        os.system(f'ffmpeg -i "{video}" -r 1 {frames_dir}/output-%04d.jpg')

        # convert() forces the image data to load before the frame files are removed
        images = [Image.open(f'{frames_dir}/{f}').convert('RGB') for f in sorted(os.listdir(frames_dir))]
        os.system(f'rm -rf {frames_dir}')
        return images

    def video2frames(self, video: str) -> list:
        """OpenCV alternative to video2frames_ffmpeg (assumes ~24 fps); not used by default."""
        cap = cv2.VideoCapture(video)
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        images = []
        frames_sec = set(range(0, num_frames, 24))  # roughly one frame per second
        has_frames, image = cap.read()
        frame_count = 0
        while has_frames:
            has_frames, image = cap.read()
            frame_count += 1
            if has_frames and frame_count in frames_sec:
                # OpenCV reads frames as BGR; convert to RGB before building a PIL image
                image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                images.append(image)
        cap.release()
        return images

    def get_img_embs(self, img_list: list) -> torch.Tensor:
        """Takes a list of PIL images and computes their CLIP image embeddings."""
        img_input = torch.stack([self.compose(img).to(self.device)
                                 for img in img_list])
        with torch.no_grad():
            image_embs = self.clip_model.encode_image(img_input).float().cpu()
        return image_embs

    def get_txt_embs(self, text: str) -> torch.Tensor:
        """Computes the multilingual CLIP embedding for the query text."""
        with torch.no_grad():
            return self.text_model(text, self.tokenizer)

    def compare_embeddings(self, img_embs, txt_embs):
        """Returns the cosine similarity of each image embedding with the text embedding."""
        # normalize features
        image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
        text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = []
        for image_feature in image_features:
            logits_per_image.append(image_feature @ text_features.t())

        return logits_per_image


def download_yt_video(url):
    ydl_opts = {
        'quiet': True,
        "outtmpl": "%(id)s.%(ext)s",
        'format': 'bv*[height<=360][ext=mp4]+ba/b[height<=360] / wv*+ba/w'
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])
    return url.split('/')[-1].replace('watch?v=', '') + '.mp4'


# Load the CLIP image encoder and the multilingual text encoder once at startup.
clip_model = 'ViT-B/32'
text_model = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
clip_model, compose = clip.load(clip_model)
tokenizer = transformers.AutoTokenizer.from_pretrained(text_model)
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(text_model)


def search_video(video_url, text, video=None):
    search = SearchVideo(
        clip_model=clip_model,
        text_model=text_model,
        tokenizer=tokenizer,
        compose=compose
    )
    # An uploaded video takes precedence over the URL field.
    if video is not None:
        video_url = None
    if video_url:
        video = download_yt_video(video_url)
    ind, seg_path, img = search(video, text)
    start = time.strftime('%H:%M:%S', time.gmtime(ind))
    return f'"{text}" found at {start}', seg_path


title = '🔎🎞️🚀 Search inside a video'
description = '''Enter a search query and a video URL, or upload your own video, and get the 5-second fragment of the video that is visually closest to your query.'''

# Each example needs one value per input component; the optional upload stays empty.
examples = [["https://www.youtube.com/watch?v=M93w3TjzVUE", "A dog", None]]

iface = gr.Interface(
    search_video,
    inputs=[
        gr.Textbox(value="https://www.youtube.com/watch?v=M93w3TjzVUE", label='Video URL'),
        gr.Textbox(value="a dog", label='Text query'),
        gr.Video(),
    ],
    outputs=[gr.Textbox(label="Output"), gr.Video(label="Video segment")],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples
)

if __name__ == "__main__":
    iface.launch(show_error=True)
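
For reference, here is a minimal sketch of exercising the same pipeline without launching the Gradio UI, assuming ffmpeg and the M-CLIP weights are set up as in app.py; 'sample.mp4' is a hypothetical local file used only for illustration.

    # Run the search directly on a local video instead of through the web interface.
    result_text, segment_path = search_video(
        video_url=None,        # skip the YouTube download
        text="a dog",          # multilingual text query
        video="sample.mp4",    # hypothetical local video path
    )
    print(result_text)   # e.g. '"a dog" found at 00:00:42'
    print(segment_path)  # path of the extracted 5-second clip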