Spaces:
Build error
Build error
Iskaj
commited on
Commit
•
2a1a736
1
Parent(s):
0112deb
add segment based decision plotting
Browse files
app.py
CHANGED
@@ -5,13 +5,13 @@ import gradio as gr
|
|
5 |
from config import *
|
6 |
from videomatch import index_hashes_for_video, get_decent_distance, \
|
7 |
get_video_index, compare_videos, get_change_points, get_videomatch_df
|
8 |
-
from plot import plot_comparison, plot_multi_comparison
|
9 |
|
10 |
logging.basicConfig()
|
11 |
logging.getLogger().setLevel(logging.INFO)
|
12 |
|
13 |
|
14 |
-
def get_comparison(url, target, MIN_DISTANCE = 4):
|
15 |
""" Function for Gradio to combine all helper functions"""
|
16 |
video_index, hash_vectors = get_video_index(url)
|
17 |
target_index, _ = get_video_index(target)
|
@@ -31,7 +31,7 @@ def get_auto_comparison(url, target, smoothing_window_size=10, method="CUSUM"):
|
|
31 |
# fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = distance)
|
32 |
df = get_videomatch_df(url, target, min_distance=MIN_DISTANCE, vanilla_df=False)
|
33 |
change_points = get_change_points(df, smoothing_window_size=smoothing_window_size, method=method)
|
34 |
-
fig =
|
35 |
return fig
|
36 |
|
37 |
def get_auto_edit_decision(url, target, smoothing_window_size=10):
|
|
|
5 |
from config import *
|
6 |
from videomatch import index_hashes_for_video, get_decent_distance, \
|
7 |
get_video_index, compare_videos, get_change_points, get_videomatch_df
|
8 |
+
from plot import plot_comparison, plot_multi_comparison, plot_segment_comparison
|
9 |
|
10 |
logging.basicConfig()
|
11 |
logging.getLogger().setLevel(logging.INFO)
|
12 |
|
13 |
|
14 |
+
def get_comparison(url, target, MIN_DISTANCE = 4):
|
15 |
""" Function for Gradio to combine all helper functions"""
|
16 |
video_index, hash_vectors = get_video_index(url)
|
17 |
target_index, _ = get_video_index(target)
|
|
|
31 |
# fig = plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = distance)
|
32 |
df = get_videomatch_df(url, target, min_distance=MIN_DISTANCE, vanilla_df=False)
|
33 |
change_points = get_change_points(df, smoothing_window_size=smoothing_window_size, method=method)
|
34 |
+
fig = plot_segment_comparison(df, change_points)
|
35 |
return fig
|
36 |
|
37 |
def get_auto_edit_decision(url, target, smoothing_window_size=10):
|
plot.py
CHANGED
@@ -55,4 +55,67 @@ def plot_multi_comparison(df, change_points):
|
|
55 |
rand_y_pos = np.random.uniform(low=np.min(df['OFFSET_LIP']), high=np.max(df['OFFSET_LIP']), size=None)
|
56 |
plt.text(x=cp_time, y=rand_y_pos, s=str(np.round(x.confidence, 2)), color='r', rotation=-0.0, fontsize=14)
|
57 |
plt.xticks(rotation=90)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
return fig
|
|
|
55 |
rand_y_pos = np.random.uniform(low=np.min(df['OFFSET_LIP']), high=np.max(df['OFFSET_LIP']), size=None)
|
56 |
plt.text(x=cp_time, y=rand_y_pos, s=str(np.round(x.confidence, 2)), color='r', rotation=-0.0, fontsize=14)
|
57 |
plt.xticks(rotation=90)
|
58 |
+
return fig
|
59 |
+
|
60 |
+
def change_points_to_segments(df, change_points):
|
61 |
+
""" Convert change points from kats detector to segment indicators """
|
62 |
+
return [pd.to_datetime(0.0, unit='s').to_datetime64()] + [cp.start_time for cp in change_points] + [pd.to_datetime(df.iloc[-1]['TARGET_S'], unit='s').to_datetime64()]
|
63 |
+
|
64 |
+
def add_seconds_to_datetime64(datetime64, seconds, subtract=False):
|
65 |
+
"""Add or substract a number of seconds to a np.datetime64 object """
|
66 |
+
s, m = divmod(seconds, 1.0)
|
67 |
+
if subtract:
|
68 |
+
return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
|
69 |
+
return datetime64 + np.timedelta64(int(s), 's') + np.timedelta64(int(m * 1000), 'ms')
|
70 |
+
|
71 |
+
def plot_segment_comparison(df, change_points):
|
72 |
+
""" From the dataframe plot the current set of plots, where the bottom right is most indicative """
|
73 |
+
fig, ax_arr = plt.subplots(2, 2, figsize=(12, 4), dpi=100, sharex=True)
|
74 |
+
sns.scatterplot(data = df, x='time', y='SOURCE_S', ax=ax_arr[0,0])
|
75 |
+
sns.lineplot(data = df, x='time', y='SOURCE_LIP_S', ax=ax_arr[0,1])
|
76 |
+
|
77 |
+
# Plot change point as lines
|
78 |
+
sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[1,0])
|
79 |
+
sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[1,1])
|
80 |
+
timestamps = change_points_to_segments(df, change_points)
|
81 |
+
|
82 |
+
# To plot the detected segment lines
|
83 |
+
for x in timestamps:
|
84 |
+
plt.vlines(x=x, ymin=np.min(df['OFFSET_LIP']), ymax=np.max(df['OFFSET_LIP']), colors='black', lw=2)
|
85 |
+
rand_y_pos = np.random.uniform(low=np.min(df['OFFSET_LIP']), high=np.max(df['OFFSET_LIP']), size=None)
|
86 |
+
|
87 |
+
# To get each detected segment and their mean?
|
88 |
+
threshold_diff = 1.5 # Average diff threshold
|
89 |
+
# threshold = 3.0 # s diff threshold
|
90 |
+
for start_time, end_time in zip(timestamps[:-1], timestamps[1:]):
|
91 |
+
|
92 |
+
add_offset = np.min(df['SOURCE_S'])
|
93 |
+
|
94 |
+
# Cut out the segment between the segment lines
|
95 |
+
segment = df[(df['time'] > start_time) & (df['time'] < end_time)] # Not offset LIP
|
96 |
+
segment_no_nan = segment[~np.isnan(segment['OFFSET'])] # Remove NaNs
|
97 |
+
seg_mean = np.mean(segment_no_nan['OFFSET'])
|
98 |
+
|
99 |
+
# Get average difference from mean of the segment to see if it is a "straight line" or not
|
100 |
+
# segment_no_nan = segment['OFFSET'][~np.isnan(segment['OFFSET'])] # Remove NaNs
|
101 |
+
average_diff = np.mean(np.abs(segment_no_nan['OFFSET'] - seg_mean))
|
102 |
+
|
103 |
+
# If the time where the segment comes from (origin time) is close to the start_time, it's a "good match", so no editing
|
104 |
+
prefix = "GOOD" if average_diff < threshold_diff else "BAD"
|
105 |
+
origin_time = add_seconds_to_datetime64(start_time, seg_mean + add_offset)
|
106 |
+
# prefix = "BAD"
|
107 |
+
# if (start_time < add_seconds_to_datetime64(origin_time, threshold) and (start_time > add_seconds_to_datetime64(origin_time, threshold, subtract=True))):
|
108 |
+
# prefix = "GOOD"
|
109 |
+
|
110 |
+
# Plot green for a confident prediction (straight line), red otherwise
|
111 |
+
if prefix == "GOOD":
|
112 |
+
plt.text(x=start_time, y=seg_mean, s=str(np.round(average_diff, 1)), color='g', rotation=-0.0, fontsize=14)
|
113 |
+
else:
|
114 |
+
plt.text(x=start_time, y=seg_mean, s=str(np.round(average_diff, 1)), color='r', rotation=-0.0, fontsize=14)
|
115 |
+
|
116 |
+
print(f"[{prefix}] DIFF={average_diff:.1f} MEAN={seg_mean:.1f} {start_time} -> {end_time} comes from video X, from {origin_time}")
|
117 |
+
|
118 |
+
|
119 |
+
# Return figure
|
120 |
+
plt.xticks(rotation=90)
|
121 |
return fig
|