yiyixuxu commited on
Commit
5f4ce2c
·
1 Parent(s): 15b3749

added batch processing for image encoding

Browse files
Files changed (1) hide show
  1. app.py +48 -43
app.py CHANGED
@@ -30,19 +30,24 @@ def select_video_format(url, format_note='480p', ext='mp4'):
30
  format_id = format.get('format_id', None)
31
  fps = format.get('fps', None)
32
  print(f'format selected: {format}')
33
- return(format_id, fps)
34
 
35
- def download_video(url,format_id):
36
- # testing
37
- print(f"testing...all the files in local directory: {os.listdir('.')}")
38
  ydl_opts = {
39
  'format':format_id,
40
- 'outtmpl': "%(id)s.%(ext)s"}
 
 
 
 
 
 
41
  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
42
  try:
43
  ydl.cache.remove()
44
  meta = ydl.extract_info(url)
45
- save_location = meta['id'] + '.' + meta['ext']
46
  except youtube_dl.DownloadError as error:
47
  print(f'error with download_video function: {error}')
48
  return(save_location)
@@ -51,17 +56,17 @@ def process_video_parallel(video, skip_frames, dest_path, num_processes, process
51
  cap = cv2.VideoCapture(video)
52
  frames_per_process = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) // (num_processes)
53
  count = frames_per_process * process_number
 
54
  print(f"worker: {process_number}, process frames {count} ~ {frames_per_process * (process_number + 1)} \n total number of frames: {cap.get(cv2.CAP_PROP_FRAME_COUNT)} \n video: {video}; isOpen? : {cap.isOpened()}")
55
  while count < frames_per_process * (process_number + 1) :
56
  ret, frame = cap.read()
57
  if not ret:
58
  break
59
- count += 1
60
- if (count - frames_per_process * process_number) % skip_frames ==0:
61
  filename =f"{dest_path}/{count}.jpg"
62
  cv2.imwrite(filename, frame)
63
  #print(f"saved {filename}")
64
-
65
  cap.release()
66
 
67
 
@@ -74,9 +79,8 @@ def vid2frames(url, sampling_interval=1, ext='mp4'):
74
  shutil.rmtree(dest_path)
75
  dest_path.mkdir(parents=True)
76
  # figure out the format for download,
77
- # by default select 480p, if not available, choose the best format available
78
- # mp4
79
- format_id, fps = select_video_format(url, format_note='480p', ext='mp4')
80
  # download the video
81
  video = download_video(url,format_id)
82
  # calculate skip_frames
@@ -85,27 +89,16 @@ def vid2frames(url, sampling_interval=1, ext='mp4'):
85
  except:
86
  skip_frames = int(30 * sampling_interval)
87
 
88
-
89
  print(f'video saved at: {video}, fps:{fps}, skip_frames: {skip_frames}')
90
  # extract video frames at given sampling interval with multiprocessing -
91
- print('extracting frames...')
92
- n_workers = min(os.cpu_count(), 1)
93
- # testing..
94
- cap = cv2.VideoCapture(video)
95
- print(f'video: {video}; isOpen? : {cap.isOpened()}')
96
- print(f'n_workers: {n_workers}')
97
  with Pool(n_workers) as pool:
98
  pool.map(partial(process_video_parallel, video, skip_frames, dest_path, n_workers), range(n_workers))
99
- # read frames
100
- original_images = []
101
- images = []
102
- filenames = sorted(dest_path.glob('*.jpg'),key=lambda p: int(p.stem))
103
- print(f"extracted {len(filenames)} frames")
104
- for filename in filenames:
105
- image = Image.open(filename).convert("RGB")
106
- original_images.append(image)
107
- images.append(preprocess(image))
108
- return original_images, images
109
 
110
 
111
  def captioned_strip(images, caption=None, times=None, rows=1):
@@ -116,8 +109,6 @@ def captioned_strip(images, caption=None, times=None, rows=1):
116
  img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
117
  if caption is not None:
118
  draw = ImageDraw.Draw(img)
119
- #font = ImageFont.load_default()
120
- #font_small = ImageFont.truetype("arial.pil", 12)
121
  font = ImageFont.truetype(
122
  "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 16
123
  )
@@ -131,26 +122,40 @@ def captioned_strip(images, caption=None, times=None, rows=1):
131
  (255, 255, 255), font=font_small)
132
  return img
133
 
134
- def run_inference(url, sampling_interval, search_query):
135
- original_images, images = vid2frames(url,sampling_interval)
136
- image_input = torch.tensor(np.stack(images)).to(device)
137
- print("testing.. created image_input")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  with torch.no_grad():
139
- image_features = model.encode_image(image_input)
140
  text_features = model.encode_text(clip.tokenize(search_query).to(device))
141
-
142
- image_features /= image_features.norm(dim=-1, keepdim=True)
143
- text_features /= text_features.norm(dim=-1, keepdim=True)
144
 
145
  similarity = (100.0 * image_features @ text_features.T)
146
  values, indices = similarity.topk(4, dim=0)
147
- print("testing.. selected best frames")
148
- best_frames = [original_images[ind] for ind in indices]
149
  times = [f'{datetime.timedelta(seconds = ind[0].item() * sampling_interval)}' for ind in indices]
150
- print("testing... before captioned_strip func")
151
  image_output = captioned_strip(best_frames,search_query, times,2)
152
  title = search_query
153
- print("testing... after captioned_strip func")
154
  return(title, image_output)
155
 
156
  inputs = [gr.inputs.Textbox(label="Give us the link to your youtube video!"),
 
30
  format_id = format.get('format_id', None)
31
  fps = format.get('fps', None)
32
  print(f'format selected: {format}')
33
+ return(format, format_id, fps)
34
 
35
+ # to-do: delete saved videos
36
+ def download_video(url,format_id, n_keep=10):
 
37
  ydl_opts = {
38
  'format':format_id,
39
+ 'outtmpl': "videos/%(id)s.%(ext)s"}
40
+ # create a directory for saved videos
41
+ video_path = Path('videos')
42
+ try:
43
+ video_path.mkdir(parents=True)
44
+ except FileExistsError:
45
+ pass
46
  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
47
  try:
48
  ydl.cache.remove()
49
  meta = ydl.extract_info(url)
50
+ save_location = 'videos/' + meta['id'] + '.' + meta['ext']
51
  except youtube_dl.DownloadError as error:
52
  print(f'error with download_video function: {error}')
53
  return(save_location)
 
56
  cap = cv2.VideoCapture(video)
57
  frames_per_process = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) // (num_processes)
58
  count = frames_per_process * process_number
59
+ cap.set(cv2.CAP_PROP_POS_FRAMES, count)
60
  print(f"worker: {process_number}, process frames {count} ~ {frames_per_process * (process_number + 1)} \n total number of frames: {cap.get(cv2.CAP_PROP_FRAME_COUNT)} \n video: {video}; isOpen? : {cap.isOpened()}")
61
  while count < frames_per_process * (process_number + 1) :
62
  ret, frame = cap.read()
63
  if not ret:
64
  break
65
+ if count % skip_frames ==0:
 
66
  filename =f"{dest_path}/{count}.jpg"
67
  cv2.imwrite(filename, frame)
68
  #print(f"saved {filename}")
69
+ count += 1
70
  cap.release()
71
 
72
 
 
79
  shutil.rmtree(dest_path)
80
  dest_path.mkdir(parents=True)
81
  # figure out the format for download,
82
+ # by default select 480p and .mp4
83
+ format, format_id, fps = select_video_format(url, format_note='480p', ext='mp4')
 
84
  # download the video
85
  video = download_video(url,format_id)
86
  # calculate skip_frames
 
89
  except:
90
  skip_frames = int(30 * sampling_interval)
91
 
 
92
  print(f'video saved at: {video}, fps:{fps}, skip_frames: {skip_frames}')
93
  # extract video frames at given sampling interval with multiprocessing -
94
+ n_workers = min(os.cpu_count(), 12)
95
+
96
+ print(f'now extracting frames with {n_workers} process...')
97
+
 
 
98
  with Pool(n_workers) as pool:
99
  pool.map(partial(process_video_parallel, video, skip_frames, dest_path, n_workers), range(n_workers))
100
+ return(skip_frames, dest_path)
101
+
 
 
 
 
 
 
 
 
102
 
103
 
104
  def captioned_strip(images, caption=None, times=None, rows=1):
 
109
  img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
110
  if caption is not None:
111
  draw = ImageDraw.Draw(img)
 
 
112
  font = ImageFont.truetype(
113
  "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 16
114
  )
 
122
  (255, 255, 255), font=font_small)
123
  return img
124
 
125
+ def run_inference(url, sampling_interval, search_query, bs=256):
126
+ skip_frames, path_frames= vid2frames(url,sampling_interval)
127
+ filenames = sorted(path_frames.glob('*.jpg'),key=lambda p: int(p.stem))
128
+ n_frames = len(filenames)
129
+ bs = min(n_frames,bs)
130
+ print(f"extracted {n_frames} frames, now encoding images")
131
+ # encoding images one batch at a time, combine all batch outputs -> image_features, size n_frames x 512
132
+ image_features = torch.empty(size=(n_frames, 512), dtype=torch.float16).to(device)
133
+ print(f"batch size :{bs} ; number of batches: {len(range(0, n_frames,bs))}")
134
+ for b in range(0, n_frames,bs):
135
+ images = []
136
+ # loop through all frames in the batch -> create batch_image_input, size bs x 3 x 224 x 224
137
+ for filename in filenames[b:b+bs]:
138
+ image = Image.open(filename).convert("RGB")
139
+ images.append(preprocess(image))
140
+ batch_image_input = torch.tensor(np.stack(images)).to(device)
141
+ # encoding batch_image_input -> batch_image_features
142
+ with torch.no_grad():
143
+ batch_image_features = model.encode_image(batch_image_input)
144
+ batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True)
145
+ # add encoded image embedding to image_features
146
+ image_features[b:b+bs] = batch_image_features
147
+ # encoding search query
148
  with torch.no_grad():
 
149
  text_features = model.encode_text(clip.tokenize(search_query).to(device))
150
+ text_features /= text_features.norm(dim=-1, keepdim=True)
 
 
151
 
152
  similarity = (100.0 * image_features @ text_features.T)
153
  values, indices = similarity.topk(4, dim=0)
154
+
155
+ best_frames = [Image.open(filenames[ind]).convert("RGB") for ind in indices]
156
  times = [f'{datetime.timedelta(seconds = ind[0].item() * sampling_interval)}' for ind in indices]
 
157
  image_output = captioned_strip(best_frames,search_query, times,2)
158
  title = search_query
 
159
  return(title, image_output)
160
 
161
  inputs = [gr.inputs.Textbox(label="Give us the link to your youtube video!"),