Commit bdcf215 (parent: 159916c) by Abhilashvj

Update app.py

Files changed (1):
  1. app.py +119 -31

app.py CHANGED
@@ -9,16 +9,21 @@ import io
 import cv2
 from insightface.app import FaceAnalysis
 from moviepy.editor import VideoFileClip
+from sklearn.cluster import DBSCAN
+from collections import defaultdict
+import plotly.graph_objs as go
+from sklearn.decomposition import PCA
 
 # Load models
 @st.cache_resource
 def load_models():
-    unified_model = SentenceTransformer("clip-ViT-B-32")
+    text_model = SentenceTransformer("all-MiniLM-L6-v2")
+    image_model = SentenceTransformer("clip-ViT-B-32")
     face_app = FaceAnalysis(providers=['CPUExecutionProvider'])
     face_app.prepare(ctx_id=0, det_size=(640, 640))
-    return unified_model, face_app
+    return text_model, image_model, face_app
 
-unified_model, face_app = load_models()
+text_model, image_model, face_app = load_models()
 
 # Load data
 @st.cache_data
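Note on the hunk above: the single CLIP encoder is split into a dedicated text model and image model. These embed into spaces of different sizes, which is why the data- and index-loading hunks below switch to separate `text_*` and `image_*` files. A minimal sketch of the size difference (the dimensions are assumptions from the models' public cards, not from this repo):

```python
# Sketch, not part of the commit: the two encoders emit different-sized
# vectors, so each needs its own FAISS index.
from sentence_transformers import SentenceTransformer

text_model = SentenceTransformer("all-MiniLM-L6-v2")
image_model = SentenceTransformer("clip-ViT-B-32")

print(text_model.encode(["who plays Thor?"]).shape)   # (1, 384)
print(image_model.encode(["who plays Thor?"]).shape)  # (1, 512)
```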
@@ -27,35 +32,97 @@ def load_data(video_id):
         summary = json.load(f)
     with open(f"{video_id}_transcription.json", "r") as f:
         transcription = json.load(f)
-    with open(f"{video_id}_unified_metadata.json", "r") as f:
-        unified_metadata = json.load(f)
+    with open(f"{video_id}_text_metadata.json", "r") as f:
+        text_metadata = json.load(f)
+    with open(f"{video_id}_image_metadata.json", "r") as f:
+        image_metadata = json.load(f)
     with open(f"{video_id}_face_metadata.json", "r") as f:
         face_metadata = json.load(f)
-    return summary, transcription, unified_metadata, face_metadata
+    return summary, transcription, text_metadata, image_metadata, face_metadata
 
 video_id = "IMFUOexuEXw"
 video_path = "avengers_interview.mp4"
-summary, transcription, unified_metadata, face_metadata = load_data(video_id)
+summary, transcription, text_metadata, image_metadata, face_metadata = load_data(video_id)
 
 # Load FAISS indexes
 @st.cache_resource
 def load_indexes(video_id):
-    unified_index = faiss.read_index(f"{video_id}_unified_index.faiss")
+    text_index = faiss.read_index(f"{video_id}_text_index.faiss")
+    image_index = faiss.read_index(f"{video_id}_image_index.faiss")
     face_index = faiss.read_index(f"{video_id}_face_index.faiss")
-    return unified_index, face_index
+    return text_index, image_index, face_index
 
-unified_index, face_index = load_indexes(video_id)
+text_index, image_index, face_index = load_indexes(video_id)
+
+# Face clustering function
+def cluster_faces(face_embeddings, eps=0.5, min_samples=3):
+    clustering = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine').fit(face_embeddings)
+    return clustering.labels_
+
+# Face clustering visualization
+def plot_face_clusters(face_embeddings, labels, face_metadata):
+    pca = PCA(n_components=3)
+    embeddings_3d = pca.fit_transform(face_embeddings)
+
+    unique_labels = set(labels)
+    colors = [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})'
+              for r, g, b, _ in plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))]
+
+    traces = []
+    for label, color in zip(unique_labels, colors):
+        cluster_points = embeddings_3d[labels == label]
+
+        hover_text = []
+        for i, point in enumerate(cluster_points):
+            face = face_metadata[np.where(labels == label)[0][i]]
+            hover_text.append(f"Cluster {label}<br>Time: {face['start']:.2f}s")
+
+        trace = go.Scatter3d(
+            x=cluster_points[:, 0],
+            y=cluster_points[:, 1],
+            z=cluster_points[:, 2],
+            mode='markers',
+            name=f'Cluster {label}',
+            marker=dict(
+                size=5,
+                color=color,
+                opacity=0.8
+            ),
+            text=hover_text,
+            hoverinfo='text'
+        )
+        traces.append(trace)
+
+    layout = go.Layout(
+        title='Face Clusters Visualization',
+        scene=dict(
+            xaxis_title='PCA Component 1',
+            yaxis_title='PCA Component 2',
+            zaxis_title='PCA Component 3'
+        ),
+        margin=dict(r=0, b=0, l=0, t=40)
+    )
+
+    fig = go.Figure(data=traces, layout=layout)
+    return fig
 
 # Search functions
-def unified_search(query, index, metadata, model, n_results=5):
+def combined_search(query, text_index, image_index, text_metadata, image_metadata, text_model, image_model, n_results=5):
     if isinstance(query, str):
-        query_vector = model.encode([query], convert_to_tensor=True).cpu().numpy()
+        text_vector = text_model.encode([query], convert_to_tensor=True).cpu().numpy()
+        image_vector = image_model.encode([query], convert_to_tensor=True).cpu().numpy()
     else:  # Assume it's an image
-        query_vector = model.encode(query, convert_to_tensor=True).cpu().numpy()
+        image_vector = image_model.encode(query, convert_to_tensor=True).cpu().numpy()
+        text_vector = image_vector  # Use the same vector for text search in this case
 
-    D, I = index.search(query_vector, n_results)
-    results = [{'data': metadata[i], 'distance': d} for i, d in zip(I[0], D[0])]
-    return results
+    text_D, text_I = text_index.search(text_vector, n_results)
+    image_D, image_I = image_index.search(image_vector, n_results)
+
+    text_results = [{'data': text_metadata[i], 'distance': d, 'type': 'text'} for i, d in zip(text_I[0], text_D[0])]
+    image_results = [{'data': image_metadata[i], 'distance': d, 'type': 'image'} for i, d in zip(image_I[0], image_D[0])]
+
+    combined_results = sorted(text_results + image_results, key=lambda x: x['distance'])
+    return combined_results[:n_results]
 
 def face_search(face_embedding, index, metadata, n_results=5):
     D, I = index.search(np.array(face_embedding).reshape(1, -1), n_results)
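The `cluster_faces` helper added above runs scikit-learn's DBSCAN with a cosine metric directly over the stored face embeddings. A self-contained sketch of the behaviour, with synthetic vectors standing in for InsightFace's embeddings:

```python
# Sketch with synthetic data: DBSCAN groups nearby embeddings and labels
# outliers -1 ("noise"); the cluster-selection UI later filters those out.
import numpy as np
from sklearn.cluster import DBSCAN

rng = np.random.default_rng(0)
base = rng.normal(size=(3, 512))           # three fake "identities"
emb = np.repeat(base, 5, axis=0) + rng.normal(scale=0.01, size=(15, 512))
emb /= np.linalg.norm(emb, axis=1, keepdims=True)

labels = DBSCAN(eps=0.5, min_samples=3, metric="cosine").fit(emb).labels_
print(labels)  # expected: five 0s, five 1s, five 2s
```

One thing to verify: `plot_face_clusters` colours clusters via `plt.cm.rainbow`, but this diff adds no matplotlib import, so it only works if the file already imports matplotlib as `plt` above the visible hunks.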
@@ -104,21 +171,43 @@ with col2:
     for theme in summary['themes']:
         st.write(f"Theme ID: {theme['id']}, Keywords: {', '.join(theme['keywords'])}")
 
+# Face Clustering
+st.header("Face Clustering")
+face_embeddings = face_index.reconstruct_n(0, face_index.ntotal)
+face_labels = cluster_faces(face_embeddings)
+
+# Update face clusters in summary
+face_clusters = defaultdict(list)
+for i, label in enumerate(face_labels):
+    face_clusters[label].append(face_metadata[i])
+
+summary['face_clusters'] = [
+    {
+        'cluster_id': f'cluster_{label}',
+        'faces': cluster
+    } for label, cluster in face_clusters.items()
+]
+
+# Visualize face clusters
+st.subheader("Face Cluster Visualization")
+fig = plot_face_clusters(face_embeddings, face_labels, face_metadata)
+st.plotly_chart(fig)
+
 # Search functionality
 st.header("Search")
 
-search_type = st.selectbox("Select search type", ["Unified", "Face"])
+search_type = st.selectbox("Select search type", ["Combined", "Face"])
 
-if search_type == "Unified":
+if search_type == "Combined":
     search_method = st.radio("Choose search method", ["Text", "Image"])
 
     if search_method == "Text":
         query = st.text_input("Enter your search query")
         if st.button("Search"):
-            results = unified_search(query, unified_index, unified_metadata, unified_model)
+            results = combined_search(query, text_index, image_index, text_metadata, image_metadata, text_model, image_model)
             st.subheader("Search Results")
             for result in results:
-                st.write(f"Time: {result['data']['start']:.2f}s - {result['data']['end']:.2f}s, Distance: {result['distance']:.4f}")
+                st.write(f"Type: {result['type']}, Time: {result['data']['start']:.2f}s - {result['data']['end']:.2f}s, Distance: {result['distance']:.4f}")
                 if 'text' in result['data']:
                     st.write(f"Text: {result['data']['text']}")
                 clip_path = create_video_clip(video_path, result['data']['start'], result['data']['end'], f"temp_clip_{result['data']['start']}.mp4")
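The hunk above pulls every stored face vector back out of the FAISS index with `reconstruct_n` instead of keeping a parallel embedding file. A sketch of that call, assuming a flat index like the ones `faiss.read_index` loads here:

```python
# Sketch: recover all stored vectors from a flat FAISS index.
import faiss
import numpy as np

d = 512                                    # assumed face-embedding size
index = faiss.IndexFlatL2(d)
index.add(np.random.rand(10, d).astype("float32"))

vectors = index.reconstruct_n(0, index.ntotal)
print(vectors.shape)                       # (10, 512)
```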
@@ -130,26 +219,25 @@ if search_type == "Unified":
         image = Image.open(uploaded_file)
         st.image(image, caption="Uploaded Image", use_column_width=True)
         if st.button("Search"):
-            results = unified_search(image, unified_index, unified_metadata, unified_model)
+            results = combined_search(image, text_index, image_index, text_metadata, image_metadata, text_model, image_model)
             st.subheader("Image Search Results")
             for result in results:
-                st.write(f"Time: {result['data']['start']:.2f}s - {result['data']['end']:.2f}s, Distance: {result['distance']:.4f}")
+                st.write(f"Type: {result['type']}, Time: {result['data']['start']:.2f}s - {result['data']['end']:.2f}s, Distance: {result['distance']:.4f}")
                 clip_path = create_video_clip(video_path, result['data']['start'], result['data']['end'], f"temp_clip_{result['data']['start']}.mp4")
                 st.video(clip_path)
                 st.write("---")
 
 elif search_type == "Face":
-    face_search_type = st.radio("Choose face search method", ["Select from video", "Upload image"])
+    face_search_type = st.radio("Choose face search method", ["Select from clusters", "Upload image"])
 
-    if face_search_type == "Select from video":
-        face_id = st.selectbox("Select a face", [face['id'] for face in summary['prominent_faces']])
+    if face_search_type == "Select from clusters":
+        cluster_id = st.selectbox("Select a face cluster", [f'cluster_{label}' for label in set(face_labels) if label != -1])
         if st.button("Search"):
-            selected_face = next(face for face in summary['prominent_faces'] if face['id'] == face_id)
-            face_results, face_distances = face_search(selected_face['embedding'], face_index, face_metadata)
-            st.subheader("Face Search Results")
-            for result, distance in zip(face_results, face_distances):
-                st.write(f"Time: {result['start']:.2f}s - {result['end']:.2f}s, Distance: {distance:.4f}")
-                clip_path = create_video_clip(video_path, result['start'], result['end'], f"temp_face_clip_{result['start']}.mp4")
+            selected_cluster = next(cluster for cluster in summary['face_clusters'] if cluster['cluster_id'] == cluster_id)
+            st.subheader("Face Cluster Search Results")
+            for face in selected_cluster['faces']:
+                st.write(f"Time: {face['start']:.2f}s - {face['end']:.2f}s")
+                clip_path = create_video_clip(video_path, face['start'], face['end'], f"temp_face_clip_{face['start']}.mp4")
                 st.video(clip_path)
                 st.write("---")
     else: