Spaces:
Build error
Build error
Iskaj
commited on
Commit
•
39557de
1
Parent(s):
8d6b883
added new plotting logic to a new gradio tab
Browse files
app.py
CHANGED
@@ -185,6 +185,73 @@ def plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = 3):
|
|
185 |
logging.basicConfig()
|
186 |
logging.getLogger().setLevel(logging.INFO)
|
187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
def get_comparison(url, target, MIN_DISTANCE = 4):
|
189 |
""" Function for Gradio to combine all helper functions"""
|
190 |
video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = MIN_DISTANCE)
|
@@ -197,7 +264,9 @@ def get_auto_comparison(url, target, MIN_DISTANCE = MIN_DISTANCE):
|
|
197 |
distance = get_decent_distance(url, target, MIN_DISTANCE, MAX_DISTANCE)
|
198 |
video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = distance)
|
199 |
lims, D, I, hash_vectors = compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = distance)
|
200 |
-
fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = distance)
|
|
|
|
|
201 |
return fig
|
202 |
|
203 |
video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
|
|
|
185 |
logging.basicConfig()
|
186 |
logging.getLogger().setLevel(logging.INFO)
|
187 |
|
188 |
+
def plot_multi_comparison(df):
|
189 |
+
fig, ax_arr = plt.subplots(3, 2, figsize=(12, 6), dpi=100, sharex=True) # , ax=axes[1]
|
190 |
+
# plt.scatter(x=df['TARGET_S'], y = df['SOURCE_S'], ax=ax_arr[0])
|
191 |
+
# plt.scatter(x=df['TARGET_S'], y = df['SOURCE_S'], ax=ax_arr[1])
|
192 |
+
sns.scatterplot(data = df, x='TARGET_S', y='SOURCE_S', ax=ax_arr[0,0])
|
193 |
+
sns.lineplot(data = df, x='TARGET_S', y='SOURCE_LIP_S', ax=ax_arr[0,1])
|
194 |
+
sns.scatterplot(data = df, x='TARGET_S', y='TIMESHIFT', ax=ax_arr[1,0])
|
195 |
+
sns.lineplot(data = df, x='TARGET_S', y='TIMESHIFT_LIP', ax=ax_arr[1,1])
|
196 |
+
sns.scatterplot(data = df, x='TARGET_S', y='OFFSET', ax=ax_arr[2,0])
|
197 |
+
sns.lineplot(data = df, x='TARGET_S', y='OFFSET_LIP', ax=ax_arr[2,1])
|
198 |
+
return fig
|
199 |
+
|
200 |
+
|
201 |
+
def get_videomatch_df(url, target, min_distance=MIN_DISTANCE, vanilla_df=False):
|
202 |
+
distance = get_decent_distance(url, target, MIN_DISTANCE, MAX_DISTANCE)
|
203 |
+
video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = distance)
|
204 |
+
lims, D, I, hash_vectors = compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = distance)
|
205 |
+
|
206 |
+
target = [(lims[i+1]-lims[i]) * [i] for i in range(hash_vectors.shape[0])]
|
207 |
+
target_s = [i/FPS for j in target for i in j]
|
208 |
+
source_s = [i/FPS for i in I]
|
209 |
+
|
210 |
+
# Make df
|
211 |
+
df = pd.DataFrame(zip(target_s, source_s, D, I), columns = ['TARGET_S', 'SOURCE_S', 'DISTANCE', 'INDICES'])
|
212 |
+
if vanilla_df:
|
213 |
+
return df
|
214 |
+
|
215 |
+
# Minimum distance dataframe ----
|
216 |
+
# Group by X so for every second/x there will be 1 value of Y in the end
|
217 |
+
# index_min_distance = df.groupby('TARGET_S')['DISTANCE'].idxmin()
|
218 |
+
# df_min = df.loc[index_min_distance]
|
219 |
+
# df_min
|
220 |
+
# -------------------------------
|
221 |
+
|
222 |
+
df['TARGET_WEIGHT'] = 1 - df['DISTANCE']/distance # Higher value means a better match
|
223 |
+
df['SOURCE_WEIGHTED_VALUE'] = df['SOURCE_S'] * df['TARGET_WEIGHT'] # Multiply the weight (which indicates a better match) with the value for Y and aggregate to get a less noisy estimate of Y
|
224 |
+
|
225 |
+
# Group by X so for every second/x there will be 1 value of Y in the end
|
226 |
+
grouped_X = df.groupby('TARGET_S').agg({'SOURCE_WEIGHTED_VALUE' : 'sum', 'TARGET_WEIGHT' : 'sum'})
|
227 |
+
grouped_X['FINAL_SOURCE_VALUE'] = grouped_X['SOURCE_WEIGHTED_VALUE'] / grouped_X['TARGET_WEIGHT']
|
228 |
+
|
229 |
+
# Remake the dataframe
|
230 |
+
df = grouped_X.reset_index()
|
231 |
+
df = df.drop(columns=['SOURCE_WEIGHTED_VALUE', 'TARGET_WEIGHT'])
|
232 |
+
df = df.rename({'FINAL_SOURCE_VALUE' : 'SOURCE_S'}, axis='columns')
|
233 |
+
|
234 |
+
# Add NAN to "missing" x values (base it off hash vector, not target_s)
|
235 |
+
step_size = 1/FPS
|
236 |
+
x_complete = np.round(np.arange(start=0.0, stop = max(df['TARGET_S'])+step_size, step = step_size), 1) # More robust
|
237 |
+
df['TARGET_S'] = np.round(df['TARGET_S'], 1)
|
238 |
+
df_complete = pd.DataFrame(x_complete, columns=['TARGET_S'])
|
239 |
+
|
240 |
+
# Merge dataframes to get NAN values for every missing SOURCE_S
|
241 |
+
df = df_complete.merge(df, on='TARGET_S', how='left')
|
242 |
+
|
243 |
+
# Interpolate between frames since there are missing values
|
244 |
+
df['SOURCE_LIP_S'] = df['SOURCE_S'].interpolate(method='linear', limit_direction='both', axis=0)
|
245 |
+
|
246 |
+
# Add timeshift col and timeshift col with Linearly Interpolated Values
|
247 |
+
df['TIMESHIFT'] = df['SOURCE_S'].shift(1) - df['SOURCE_S']
|
248 |
+
df['TIMESHIFT_LIP'] = df['SOURCE_LIP_S'].shift(1) - df['SOURCE_LIP_S']
|
249 |
+
|
250 |
+
# Add Offset col that assumes the video is played at the same speed as the other to do a "timeshift"
|
251 |
+
df['OFFSET'] = df['SOURCE_S'] - df['TARGET_S'] - np.min(df['SOURCE_S'])
|
252 |
+
df['OFFSET_LIP'] = df['SOURCE_LIP_S'] - df['TARGET_S'] - np.min(df['SOURCE_LIP_S'])
|
253 |
+
return df
|
254 |
+
|
255 |
def get_comparison(url, target, MIN_DISTANCE = 4):
|
256 |
""" Function for Gradio to combine all helper functions"""
|
257 |
video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = MIN_DISTANCE)
|
|
|
264 |
distance = get_decent_distance(url, target, MIN_DISTANCE, MAX_DISTANCE)
|
265 |
video_index, hash_vectors, target_indices = get_video_indices(url, target, MIN_DISTANCE = distance)
|
266 |
lims, D, I, hash_vectors = compare_videos(video_index, hash_vectors, target_indices, MIN_DISTANCE = distance)
|
267 |
+
# fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = distance)
|
268 |
+
df = get_videomatch_df(url, target, min_distance=MIN_DISTANCE, vanilla_df=False)
|
269 |
+
fig = plot_multi_comparison(df)
|
270 |
return fig
|
271 |
|
272 |
video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
|