Spaces:
Build error
Build error
Iskaj
committed on
Commit
•
1991773
1
Parent(s):
9807395
changed to work with apb files, plotting separated from decision
Browse files- app.py +74 -29
- clip_data.ipynb +3 -3
- config.py +2 -1
- plot.py +27 -18
- videomatch.py +7 -1
app.py
CHANGED
@@ -1,16 +1,31 @@
|
|
1 |
import logging
|
|
|
|
|
|
|
2 |
|
3 |
import gradio as gr
|
|
|
4 |
|
5 |
from config import *
|
6 |
from videomatch import index_hashes_for_video, get_decent_distance, \
|
7 |
-
get_video_index, compare_videos, get_change_points, get_videomatch_df
|
|
|
8 |
from plot import plot_comparison, plot_multi_comparison, plot_segment_comparison
|
9 |
|
10 |
logging.basicConfig()
|
11 |
logging.getLogger().setLevel(logging.INFO)
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def get_comparison(url, target, MIN_DISTANCE = 4):
|
15 |
""" Function for Gradio to combine all helper functions"""
|
16 |
video_index, hash_vectors = get_video_index(url)
|
@@ -19,25 +34,53 @@ def get_comparison(url, target, MIN_DISTANCE = 4):
|
|
19 |
fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = MIN_DISTANCE)
|
20 |
return fig
|
21 |
|
22 |
-
def
|
23 |
-
|
24 |
source_index, source_hash_vectors = get_video_index(url)
|
25 |
target_index, _ = get_video_index(target)
|
|
|
|
|
26 |
distance = get_decent_distance(source_index, source_hash_vectors, target_index, MIN_DISTANCE, MAX_DISTANCE)
|
27 |
if distance == None:
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
for i in range(0, 1):
|
33 |
lims, D, I, hash_vectors = compare_videos(source_hash_vectors, target_index, MIN_DISTANCE = distance)
|
|
|
|
|
34 |
df = get_videomatch_df(lims, D, I, hash_vectors, distance)
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
|
|
|
|
|
|
39 |
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
41 |
"https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",
|
42 |
"https://www.dropbox.com/s/wcot34ldmb84071/Baudet%20ontmaskert%20Omtzigt_%20u%20bent%20door%20de%20mand%20gevallen%21.mp4?dl=1",
|
43 |
"https://drive.google.com/uc?id=1XW0niHR1k09vPNv1cp6NvdGXe7FHJc1D&export=download",
|
@@ -46,22 +89,24 @@ video_urls = ["https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
|
|
46 |
index_iface = gr.Interface(fn=lambda url: index_hashes_for_video(url).ntotal,
|
47 |
inputs="text",
|
48 |
outputs="text",
|
49 |
-
examples=
|
50 |
-
|
51 |
-
compare_iface = gr.Interface(fn=get_comparison,
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
inputs=["text",
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
65 |
|
66 |
if __name__ == "__main__":
|
67 |
import matplotlib
|
|
|
1 |
import logging
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
|
6 |
import gradio as gr
|
7 |
+
from faiss import read_index_binary, write_index_binary
|
8 |
|
9 |
from config import *
|
10 |
from videomatch import index_hashes_for_video, get_decent_distance, \
|
11 |
+
get_video_index, compare_videos, get_change_points, get_videomatch_df, \
|
12 |
+
get_target_urls
|
13 |
from plot import plot_comparison, plot_multi_comparison, plot_segment_comparison
|
14 |
|
15 |
logging.basicConfig()
|
16 |
logging.getLogger().setLevel(logging.INFO)
|
17 |
|
18 |
+
def transfer_data_indices_to_temp(temp_path = VIDEO_DIRECTORY, data_path='./data'):
    """ Copy the binary indices stored under data_path into the temporary directory.

    The indices created from the .json file live in a static location instead of the
    temporary directory. Writing them into temp_path keeps the way statically shipped
    files are located the same as for dynamically downloaded files.

    args:
    - temp_path: directory the indices are written to (defaults to VIDEO_DIRECTORY)
    - data_path: directory whose entries are each read as a faiss binary index
    """
    for file_name in os.listdir(data_path):
        # Read from the static location and write to temp storage under the same file name
        static_index = read_index_binary(os.path.join(data_path, file_name))
        write_index_binary(static_index, f'{temp_path}/{file_name}')
|
28 |
+
|
29 |
def get_comparison(url, target, MIN_DISTANCE = 4):
|
30 |
""" Function for Gradio to combine all helper functions"""
|
31 |
video_index, hash_vectors = get_video_index(url)
|
|
|
34 |
fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = MIN_DISTANCE)
|
35 |
return fig
|
36 |
|
37 |
+
def compare(url, target):
    """ Compare a source video with a single target video and decide on matching segments.

    args:
    - url: url of the source video
    - target: url of the target video, also passed to the plot as video_id

    returns:
    - (fig, segment_decision) as returned by plot_segment_comparison, or an empty
      matplotlib figure and an empty list when get_decent_distance yields None
    """
    # Get source and target indices
    source_index, source_hash_vectors = get_video_index(url)
    target_index, _ = get_video_index(target)

    # Get decent distance by comparing url index with the target hash vectors + target index
    distance = get_decent_distance(source_index, source_hash_vectors, target_index, MIN_DISTANCE, MAX_DISTANCE)

    # Identity check: `== None` dispatches to the value's own __eq__, which is not
    # guaranteed to return a plain bool for every type.
    if distance is None:
        logging.info(f"No matches found between {url} and {target}!")
        return plt.figure(), []

    # Compare videos with heuristic distance
    lims, D, I, hash_vectors = compare_videos(source_hash_vectors, target_index, MIN_DISTANCE = distance)

    # Get dataframe holding all information
    df = get_videomatch_df(lims, D, I, hash_vectors, distance)

    # Determine change point using ROBUST method based on column ROLL_OFFSET_MODE
    change_points = get_change_points(df, metric="ROLL_OFFSET_MODE", method="ROBUST")

    # Plot and get figure and .json-style segment decision
    fig, segment_decision = plot_segment_comparison(df, change_points, video_id=target)
    return fig, segment_decision
|
60 |
+
|
61 |
+
def multiple_comparison(url, return_figure=False):
    """ Compare the video at url against every target url from get_target_urls().

    args:
    - url: url of the source video
    - return_figure: when truthy, return the figures instead of the decisions

    returns:
    - list with one figure per target if return_figure is truthy, otherwise one flat
      list holding the segment decisions of all targets
    """
    # One (figure, segment decisions) pair per target, in target order
    per_target = [compare(url, target_url) for target_url in get_target_urls()]

    if return_figure:
        return [figure for figure, _ in per_target]

    # Flatten the per-target decision lists into a single list
    return [decision for _, target_decisions in per_target for decision in target_decisions]
|
77 |
|
78 |
+
def plot_multiple_comparison(url):
    """ Run multiple_comparison for url and return only the figures. """
    figures = multiple_comparison(url, return_figure=True)
    return figures
|
80 |
+
|
81 |
+
transfer_data_indices_to_temp() # NOTE: Only works after doing 'git lfs pull' to actually obtain the .index files
|
82 |
+
example_video_urls = ["https://drive.google.com/uc?id=1Y1-ypXOvLrp1x0cjAe_hMobCEdA0UbEo&export=download",
|
83 |
+
"https://www.dropbox.com/s/8c89a9aba0w8gjg/Ploumen.mp4?dl=1",
|
84 |
"https://www.dropbox.com/s/rzmicviu1fe740t/Bram%20van%20Ojik%20krijgt%20reprimande.mp4?dl=1",
|
85 |
"https://www.dropbox.com/s/wcot34ldmb84071/Baudet%20ontmaskert%20Omtzigt_%20u%20bent%20door%20de%20mand%20gevallen%21.mp4?dl=1",
|
86 |
"https://drive.google.com/uc?id=1XW0niHR1k09vPNv1cp6NvdGXe7FHJc1D&export=download",
|
|
|
89 |
# "Index" tab: index the hashes of one video url and show the resulting index's ntotal
index_iface = gr.Interface(
    fn=lambda url: index_hashes_for_video(url).ntotal,
    inputs="text",
    outputs="text",
    examples=example_video_urls,
    cache_examples=True,
)

# compare_iface = gr.Interface(fn=get_comparison,
# inputs=["text", "text", gr.Slider(2, 30, 4, step=2)],
# outputs="plot",
# examples=[[x, example_video_urls[-1]] for x in example_video_urls[:-1]])

# "PlotAutoCompare" tab: one plot output per target url
plot_compare_iface = gr.Interface(
    fn=plot_multiple_comparison,
    inputs=["text"],
    outputs=[gr.Plot() for _ in get_target_urls()],
    examples=example_video_urls,
)

# "AutoCompare" tab: segment decisions of all targets as json
auto_compare_iface = gr.Interface(
    fn=multiple_comparison,
    inputs=["text"],
    outputs=["json"],
    examples=example_video_urls,
)

# Combine the three interfaces into a single tabbed app
iface = gr.TabbedInterface(
    [auto_compare_iface, plot_compare_iface, index_iface],
    ["AutoCompare", "PlotAutoCompare", "Index"],
)
|
110 |
|
111 |
if __name__ == "__main__":
|
112 |
import matplotlib
|
clip_data.ipynb
CHANGED
@@ -395,7 +395,7 @@
|
|
395 |
],
|
396 |
"metadata": {
|
397 |
"kernelspec": {
|
398 |
-
"display_name": "Python 3.9.
|
399 |
"language": "python",
|
400 |
"name": "python3"
|
401 |
},
|
@@ -409,12 +409,12 @@
|
|
409 |
"name": "python",
|
410 |
"nbconvert_exporter": "python",
|
411 |
"pygments_lexer": "ipython3",
|
412 |
-
"version": "3.9.
|
413 |
},
|
414 |
"orig_nbformat": 4,
|
415 |
"vscode": {
|
416 |
"interpreter": {
|
417 |
-
"hash": "
|
418 |
}
|
419 |
}
|
420 |
},
|
|
|
395 |
],
|
396 |
"metadata": {
|
397 |
"kernelspec": {
|
398 |
+
"display_name": "Python 3.9.13 64-bit",
|
399 |
"language": "python",
|
400 |
"name": "python3"
|
401 |
},
|
|
|
409 |
"name": "python",
|
410 |
"nbconvert_exporter": "python",
|
411 |
"pygments_lexer": "ipython3",
|
412 |
+
"version": "3.9.13"
|
413 |
},
|
414 |
"orig_nbformat": 4,
|
415 |
"vscode": {
|
416 |
"interpreter": {
|
417 |
+
"hash": "397704579725e15f5c7cb49fe5f0341eb7531c82d19f2c29d197e8b64ab5776b"
|
418 |
}
|
419 |
}
|
420 |
},
|
config.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
import tempfile
|
2 |
|
3 |
VIDEO_DIRECTORY = tempfile.gettempdir()
|
|
|
4 |
|
5 |
FPS = 5
|
6 |
MIN_DISTANCE = 4
|
7 |
-
MAX_DISTANCE = 30
|
8 |
ROLLING_WINDOW_SIZE = 10
|
|
|
1 |
import tempfile
|
2 |
|
3 |
VIDEO_DIRECTORY = tempfile.gettempdir()
|
4 |
+
# VIDEO_DIRECTORY = './data/'
|
5 |
|
6 |
FPS = 5
|
7 |
MIN_DISTANCE = 4
|
8 |
+
MAX_DISTANCE = 30 # Used to be 30
|
9 |
ROLLING_WINDOW_SIZE = 10
|
plot.py
CHANGED
@@ -69,33 +69,40 @@ def add_seconds_to_datetime64(datetime64, seconds, subtract=False):
|
|
69 |
return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
|
70 |
return datetime64 + np.timedelta64(int(s), 's') + np.timedelta64(int(m * 1000), 'ms')
|
71 |
|
72 |
-
def plot_segment_comparison(df, change_points, video_id="Placeholder_Video_ID"):
|
73 |
-
"""
|
74 |
-
|
|
|
75 |
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
# Plot
|
80 |
-
sns.
|
81 |
|
82 |
-
# Plot
|
83 |
metric = 'ROLL_OFFSET_MODE' # 'OFFSET'
|
84 |
-
sns.
|
|
|
85 |
|
86 |
-
# Plot
|
87 |
-
sns.scatterplot(data = df, x='time', y=metric, ax=ax_arr[
|
88 |
timestamps = change_points_to_segments(df, change_points)
|
89 |
-
|
|
|
|
|
90 |
# To store "decisions" about segments
|
91 |
segment_decisions = []
|
92 |
seg_i = 0
|
93 |
|
94 |
-
#
|
95 |
-
for x in timestamps:
|
96 |
-
plt.vlines(x=x, ymin=np.min(df[metric]), ymax=np.max(df[metric]), colors='black', lw=2, alpha=0.5)
|
97 |
-
|
98 |
-
threshold_diff = 1.5 # Average segment difference threshold for plotting
|
99 |
for start_time, end_time in zip(timestamps[:-1], timestamps[1:]):
|
100 |
|
101 |
# Time to add to each origin time to get the correct time back since it is offset by add_offset
|
@@ -149,4 +156,6 @@ def plot_segment_comparison(df, change_points, video_id="Placeholder_Video_ID"):
|
|
149 |
|
150 |
# Return figure
|
151 |
plt.xticks(rotation=90)
|
152 |
-
return fig, segment_decisions
|
|
|
|
|
|
69 |
return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
|
70 |
return datetime64 + np.timedelta64(int(s), 's') + np.timedelta64(int(m * 1000), 'ms')
|
71 |
|
72 |
+
def plot_segment_comparison(df, change_points, video_id="Placeholder_Video_ID", threshold_diff = 1.5):
|
73 |
+
""" Based on the dataframe and detected change points do two things:
|
74 |
+
1. Make a decision on where each segment belongs in time and return that info as a list of dicts
|
75 |
+
2. Plot how this decision got made as an informative plot
|
76 |
|
77 |
+
args:
|
78 |
+
- df: dataframe
|
79 |
+
- change_points: detected points in time where the average metric value changes
|
80 |
+
- video_id: the unique identifier for the video currently being compared
|
81 |
+
- threshold_diff: to plot which segments are likely bad matches
|
82 |
+
"""
|
83 |
+
fig, ax_arr = plt.subplots(4, 1, figsize=(16, 6), dpi=300, sharex=True)
|
84 |
+
ax_arr[0].set_title(video_id)
|
85 |
+
sns.scatterplot(data = df, x='time', y='SOURCE_S', ax=ax_arr[0], label="SOURCE_S", color='blue', alpha=1.0)
|
86 |
|
87 |
+
# Plot original datapoints without linear interpolation, offset by target video time
|
88 |
+
sns.scatterplot(data = df, x='time', y='OFFSET', ax=ax_arr[1], label="OFFSET", color='orange', alpha=1.0)
|
89 |
|
90 |
+
# Plot linearly interpolated values next to metric values
|
91 |
metric = 'ROLL_OFFSET_MODE' # 'OFFSET'
|
92 |
+
sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[2], label="OFFSET_LIP", color='orange')
|
93 |
+
sns.scatterplot(data = df, x='time', y=metric, ax=ax_arr[2], label=metric, alpha=0.5)
|
94 |
|
95 |
+
# Plot detected change points as lines which will indicate the segments
|
96 |
+
sns.scatterplot(data = df, x='time', y=metric, ax=ax_arr[3], label=metric, s=20)
|
97 |
timestamps = change_points_to_segments(df, change_points)
|
98 |
+
for x in timestamps:
|
99 |
+
plt.vlines(x=x, ymin=np.min(df[metric]), ymax=np.max(df[metric]), colors='black', lw=2, alpha=0.5)
|
100 |
+
|
101 |
# To store "decisions" about segments
|
102 |
segment_decisions = []
|
103 |
seg_i = 0
|
104 |
|
105 |
+
# Average segment difference threshold for plotting
|
|
|
|
|
|
|
|
|
106 |
for start_time, end_time in zip(timestamps[:-1], timestamps[1:]):
|
107 |
|
108 |
# Time to add to each origin time to get the correct time back since it is offset by add_offset
|
|
|
156 |
|
157 |
# Return figure
|
158 |
plt.xticks(rotation=90)
|
159 |
+
return fig, segment_decisions
|
160 |
+
|
161 |
+
|
videomatch.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
-
|
4 |
import faiss
|
5 |
|
6 |
from kats.detectors.cusum_detection import CUSUMDetector
|
@@ -15,6 +15,12 @@ import pandas as pd
|
|
15 |
from videohash import compute_hashes, filepath_from_url
|
16 |
from config import FPS, MIN_DISTANCE, MAX_DISTANCE, ROLLING_WINDOW_SIZE
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def index_hashes_for_video(url: str) -> faiss.IndexBinaryIVF:
|
19 |
""" Compute hashes of a video and index the video using faiss indices and return the index. """
|
20 |
filepath = filepath_from_url(url)
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
+
import json
|
4 |
import faiss
|
5 |
|
6 |
from kats.detectors.cusum_detection import CUSUMDetector
|
|
|
15 |
from videohash import compute_hashes, filepath_from_url
|
16 |
from config import FPS, MIN_DISTANCE, MAX_DISTANCE, ROLLING_WINDOW_SIZE
|
17 |
|
18 |
+
def get_target_urls(json_file='apb2022.json'):
    """ Obtain target urls for the target videos of a json file containing .mp4 files

    args:
    - json_file: path to a json file holding a list of objects that each have an 'mp4' key

    returns:
    - list with the 'mp4' value of every entry, in file order
    """
    # Open the path that was passed in. Previously the literal 'apb2022.json' was opened
    # and the parameter name was reused for the file handle, so the argument was ignored.
    with open(json_file, "r", encoding="utf-8") as file_handle:
        target_videos = json.load(file_handle)
    return [video['mp4'] for video in target_videos]
|
23 |
+
|
24 |
def index_hashes_for_video(url: str) -> faiss.IndexBinaryIVF:
|
25 |
""" Compute hashes of a video and index the video using faiss indices and return the index. """
|
26 |
filepath = filepath_from_url(url)
|