Spaces:
Configuration error
Configuration error
Create new file
Browse files
app.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BATCH_SIZE = 64
|
2 |
+
DOWNSAMPLE = 24
|
3 |
+
|
4 |
+
import phash_jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from PIL import Image
|
8 |
+
import statistics
|
9 |
+
import gradio
|
10 |
+
|
11 |
+
def binary_array_to_hex(arr):
|
12 |
+
"""
|
13 |
+
Function to make a hex string out of a binary array.
|
14 |
+
"""
|
15 |
+
bit_string = ''.join(str(b) for b in 1 * arr.flatten())
|
16 |
+
width = int(jnp.ceil(len(bit_string) / 4))
|
17 |
+
return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
|
18 |
+
|
19 |
+
def compute_batch_hashes(vid_path):
|
20 |
+
kwargs={"width": 64, "height":64}
|
21 |
+
vr = VideoReader(vid_path, ctx=cpu(0), **kwargs)
|
22 |
+
hashes = []
|
23 |
+
h_prev = None
|
24 |
+
batch = []
|
25 |
+
for i in range(0, len(vr), DOWNSAMPLE * BATCH_SIZE):
|
26 |
+
ids = [id for id in range(i, min(i + DOWNSAMPLE * BATCH_SIZE, len(vr)), DOWNSAMPLE)]
|
27 |
+
vr.seek(0)
|
28 |
+
batch = jnp.array(vr.get_batch(ids).asnumpy())
|
29 |
+
batch_h = phash_jax.batch_phash(batch)
|
30 |
+
for i in range(len(ids)):
|
31 |
+
h = batch_h[i]
|
32 |
+
if h_prev == None:
|
33 |
+
h_prev=h
|
34 |
+
hashes.append({"frame_id":ids[i], "hash": binary_array_to_hex(h), "distance": int(phash_jax.hash_dist(h, h_prev))})
|
35 |
+
h_prev = h
|
36 |
+
return gradio.update(value=hashes, visible=False)
|
37 |
+
|
38 |
+
def plot_hash_distance(hashes, threshold):
|
39 |
+
fig = plt.figure()
|
40 |
+
ids = [h["frame_id"] for h in hashes]
|
41 |
+
distances = [h["distance"] for h in hashes]
|
42 |
+
plt.plot(ids, distances, ".")
|
43 |
+
plt.plot(ids, [threshold]* len(ids), "r-")
|
44 |
+
return fig
|
45 |
+
|
46 |
+
def compute_threshold(hashes):
|
47 |
+
min_length = 24 * 3
|
48 |
+
ids = [h["frame_id"] for h in hashes]
|
49 |
+
distances = [h["distance"] for h in hashes]
|
50 |
+
thrs_ = sorted(list(set(distances)),reverse=True)
|
51 |
+
best = thrs_[0] - 1
|
52 |
+
for threshold in thrs_[1:]:
|
53 |
+
durations = []
|
54 |
+
i_start=0
|
55 |
+
for i, h in enumerate(hashes):
|
56 |
+
if h["distance"] > threshold and hashes[i-1]["frame_id"] - hashes[i_start]["frame_id"] > min_length:
|
57 |
+
durations.append(hashes[i-1]["frame_id"] - hashes[i_start]["frame_id"])
|
58 |
+
i_start=i
|
59 |
+
if len(durations) < (len(hashes) * DOWNSAMPLE / 24) / 20:
|
60 |
+
best = threshold
|
61 |
+
return best
|
62 |
+
|
63 |
+
def get_slides(vid_path, hashes, threshold):
|
64 |
+
min_length = 24 * 1.5
|
65 |
+
vr = VideoReader(vid_path, ctx=cpu(0))
|
66 |
+
slideshow = []
|
67 |
+
i_start = 0
|
68 |
+
for i, h in enumerate(hashes):
|
69 |
+
if h["distance"] > threshold and hashes[i-1]["frame_id"] - hashes[i_start]["frame_id"] > min_length:
|
70 |
+
path=f'{FOLDER_PATH}/{vid_path.split("/")[-1].split(".")[0]}_{i_start}_{i-1}.png'
|
71 |
+
Image.fromarray(vr[hashes[i-1]["frame_id"]].asnumpy()).save(path)
|
72 |
+
slideshow.append({"slide": path, "start": i_start, "end": i-1})
|
73 |
+
i_start=i
|
74 |
+
path=f'{FOLDER_PATH}/{vid_path.split("/")[-1].split(".")[0]}_{i_start}_{len(vr)-1}.png'
|
75 |
+
Image.fromarray(vr[-1].asnumpy()).save(path)
|
76 |
+
slideshow.append({"slide": path, "start": i_start, "end": len(vr)-1})
|
77 |
+
return [s["slide"] for s in slideshow]
|
78 |
+
|
79 |
+
def trigger_plots(f2f_distance_plot, hashes, threshold):
|
80 |
+
# if not hist_plot.get_config()["visible"] and len(hashes.get_config()["value"]) > 0 :
|
81 |
+
return gradio.update(value=plot_hash_distance(hashes, threshold))
|
82 |
+
|
83 |
+
def set_visible():
|
84 |
+
return gradio.update(visible=True)
|
85 |
+
|
86 |
+
demo = gradio.Blocks(analytics_enabled=True)
|
87 |
+
|
88 |
+
with demo:
|
89 |
+
with gradio.Row():
|
90 |
+
with gradio.Column():
|
91 |
+
with gradio.Row():
|
92 |
+
vid=gradio.Video(mirror_webcam=False)
|
93 |
+
with gradio.Row():
|
94 |
+
btn_vid_proc = gradio.Button("Compute hashes")
|
95 |
+
with gradio.Row():
|
96 |
+
hist_plot = gradio.Plot(label="Frame to frame hash distance histogram", visible=False)
|
97 |
+
with gradio.Column():
|
98 |
+
hashes = gradio.JSON()
|
99 |
+
with gradio.Column(visible=False) as result_row:
|
100 |
+
btn_plot = gradio.Button("Plot & compute optimal threshold")
|
101 |
+
threshold = gradio.Slider(minimum=1, maximum=30, value=5, label="Threshold")
|
102 |
+
f2f_distance_plot = gradio.Plot(label="Frame to frame hash distance")
|
103 |
+
btn_slides = gradio.Button("Extract Slides")
|
104 |
+
with gradio.Row():
|
105 |
+
slideshow = gradio.Gallery(label="Extracted slides")
|
106 |
+
slideshow.style(grid=6)
|
107 |
+
btn_vid_proc.click(fn=compute_batch_hashes, inputs=[vid], outputs=[hashes])
|
108 |
+
hashes.change(fn=set_visible, inputs=[], outputs=[result_row])
|
109 |
+
btn_plot.click(fn=compute_threshold, inputs=[hashes], outputs=[threshold])
|
110 |
+
btn_plot.click(fn=trigger_plots, inputs=[f2f_distance_plot, hashes, threshold], outputs=[f2f_distance_plot])
|
111 |
+
threshold.change(fn=plot_hash_distance, inputs=[hashes, threshold], outputs=f2f_distance_plot)
|
112 |
+
btn_slides.click(fn=get_slides, inputs=[vid, hashes, threshold], outputs=[slideshow])
|
113 |
+
|
114 |
+
demo.launch(debug=True)
|