Spaces:
Sleeping
Sleeping
import streamlit as st | |
import json | |
import faiss | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
import base64 | |
from PIL import Image | |
import io | |
import cv2 | |
from insightface.app import FaceAnalysis | |
from moviepy.editor import VideoFileClip | |
from sklearn.cluster import DBSCAN | |
from collections import defaultdict | |
import plotly.graph_objs as go | |
from sklearn.decomposition import PCA | |
# Load models
@st.cache_resource
def load_models():
    """Load the text, image and face models used by the dashboard.

    Decorated with ``st.cache_resource`` because Streamlit re-executes this whole
    script on every widget interaction; without caching, all three models would
    be reloaded from disk on every click. The cached objects are shared by all
    sessions of the app.

    Returns:
        ``(text_model, image_model, face_app)``: the "all-MiniLM-L6-v2" sentence
        encoder, the "clip-ViT-B-32" encoder, and an InsightFace ``FaceAnalysis``
        app restricted to the CPU execution provider and prepared with a
        640x640 detection size.
    """
    text_model = SentenceTransformer("all-MiniLM-L6-v2")
    image_model = SentenceTransformer("clip-ViT-B-32")
    face_app = FaceAnalysis(providers=['CPUExecutionProvider'])
    face_app.prepare(ctx_id=0, det_size=(640, 640))
    return text_model, image_model, face_app


text_model, image_model, face_app = load_models()
# Load data
def load_data(video_id):
    """Load the five precomputed JSON artefacts for ``video_id``.

    Each file is named ``f"{video_id}_<kind>.json"`` and is resolved relative to
    the current working directory (``video_id`` may therefore carry a directory
    prefix).

    Returns:
        Tuple ``(summary, transcription, text_metadata, image_metadata,
        face_metadata)`` holding the decoded JSON values, in that order.

    Raises:
        FileNotFoundError: if any of the five files is missing.
        json.JSONDecodeError: if a file is not valid JSON.
    """
    kinds = (
        "summary",
        "transcription",
        "text_metadata",
        "image_metadata",
        "face_metadata",
    )
    loaded = []
    for kind in kinds:
        # Read as UTF-8 explicitly: the default text encoding is platform
        # dependent, and transcripts/metadata may contain non-ASCII text.
        with open(f"{video_id}_{kind}.json", "r", encoding="utf-8") as f:
            loaded.append(json.load(f))
    return tuple(loaded)
# Filename prefix shared by every precomputed artefact (JSON metadata and
# FAISS indexes) for the analysed video.
video_id = "IMFUOexuEXw"
# Local copy of the source video; the sidebar player and every result clip
# are taken from this file.
video_path = "avengers_interview.mp4"

(
    summary,
    transcription,
    text_metadata,
    image_metadata,
    face_metadata,
) = load_data(video_id)
# Load FAISS indexes
@st.cache_resource
def load_indexes(video_id):
    """Read the text, image and face FAISS indexes saved for ``video_id``.

    The files are named ``f"{video_id}_<kind>_index.faiss"`` and resolved
    relative to the current working directory.

    Cached with ``st.cache_resource`` (keyed on ``video_id``) so the index files
    are read once per process instead of on every Streamlit rerun. The same index
    objects are shared by all sessions; this script only reads from them
    (``search`` / ``reconstruct_n``).

    Returns:
        ``(text_index, image_index, face_index)``.
    """
    text_index = faiss.read_index(f"{video_id}_text_index.faiss")
    image_index = faiss.read_index(f"{video_id}_image_index.faiss")
    face_index = faiss.read_index(f"{video_id}_face_index.faiss")
    return text_index, image_index, face_index


text_index, image_index, face_index = load_indexes(video_id)
# Face clustering function
def cluster_faces(face_embeddings, eps=0.5, min_samples=3):
    """Group face embeddings with DBSCAN using cosine distance.

    Args:
        face_embeddings: 2-D array-like, one embedding per row.
        eps: maximum cosine distance for two faces to count as neighbours.
        min_samples: neighbourhood size required for a DBSCAN core point.

    Returns:
        1-D array with one cluster label per row; ``-1`` marks noise. When there
        are no embeddings an empty integer array is returned, because DBSCAN
        itself rejects zero-sample input (e.g. a video with no detected faces).
    """
    embeddings = np.asarray(face_embeddings)
    if embeddings.shape[0] == 0:
        return np.empty(0, dtype=int)
    clustering = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine').fit(embeddings)
    return clustering.labels_
# Face clustering visualization
def plot_face_clusters(face_embeddings, labels, face_metadata):
    """Build an interactive 3-D scatter plot of face embeddings coloured by cluster.

    The embeddings are projected to (at most) three dimensions with PCA. Each
    distinct label becomes its own Plotly trace, so clusters can be toggled in
    the legend. Hovering a point shows its cluster label and the ``start`` time
    of the matching ``face_metadata`` record.

    Args:
        face_embeddings: 2-D array-like, one embedding per row.
        labels: one cluster label per row of ``face_embeddings`` (e.g. DBSCAN
            ``labels_``, where ``-1`` marks noise).
        face_metadata: sequence of dicts row-aligned with ``face_embeddings``;
            each needs a numeric ``'start'`` entry.

    Returns:
        A ``plotly.graph_objs.Figure``.
    """
    import colorsys

    embeddings = np.asarray(face_embeddings)
    labels = np.asarray(labels)
    n_samples = embeddings.shape[0]

    # PCA needs n_components <= min(n_samples, n_features), and a single sample
    # has no variance to project. Compute as many axes as the data allows and
    # leave the rest at zero, so tiny inputs still plot instead of raising.
    coords = np.zeros((n_samples, 3))
    n_components = min(3, *embeddings.shape) if embeddings.ndim == 2 else 0
    if n_samples >= 2 and n_components >= 1:
        coords[:, :n_components] = PCA(n_components=n_components).fit_transform(embeddings)

    unique_labels = sorted(set(labels.tolist()))  # sorted -> stable legend order
    # Evenly spaced hues from violet (0.8) to red (0.0), approximating a
    # "rainbow" colormap using only the standard library.
    colors = []
    for hue in np.linspace(0.8, 0.0, len(unique_labels)):
        r, g, b = colorsys.hsv_to_rgb(float(hue), 1.0, 1.0)
        colors.append(f'rgb({int(r*255)},{int(g*255)},{int(b*255)})')

    traces = []
    for label, color in zip(unique_labels, colors):
        # Row indices of this cluster, computed once per cluster.
        member_rows = np.flatnonzero(labels == label)
        cluster_points = coords[member_rows]
        hover_text = [
            f"Cluster {label}<br>Time: {face_metadata[int(row)]['start']:.2f}s"
            for row in member_rows
        ]
        traces.append(go.Scatter3d(
            x=cluster_points[:, 0],
            y=cluster_points[:, 1],
            z=cluster_points[:, 2],
            mode='markers',
            name=f'Cluster {label}',
            marker=dict(
                size=5,
                color=color,
                opacity=0.8,
            ),
            text=hover_text,
            hoverinfo='text',
        ))

    layout = go.Layout(
        title='Face Clusters Visualization',
        scene=dict(
            xaxis_title='PCA Component 1',
            yaxis_title='PCA Component 2',
            zaxis_title='PCA Component 3',
        ),
        margin=dict(r=0, b=0, l=0, t=40),
    )
    return go.Figure(data=traces, layout=layout)
# Search functions
def combined_search(query, text_index, image_index, text_metadata, image_metadata, text_model, image_model, n_results=5):
    """Search the text and image FAISS indexes for ``query`` and merge the hits.

    A ``str`` query is embedded twice: with ``text_model`` for ``text_index`` and
    with ``image_model`` for ``image_index``. Any other query is assumed to be an
    image and is embedded with ``image_model`` only. That vector is also tried
    against ``text_index``, but only when its dimensionality equals
    ``text_index.d``; otherwise FAISS would reject it, so the text index is
    skipped.

    Args:
        query: search text, or an image accepted by ``image_model.encode``.
        text_index / image_index: FAISS indexes row-aligned with
            ``text_metadata`` / ``image_metadata``.
        text_model / image_model: encoders exposing ``encode(list, convert_to_tensor=True)``.
        n_results: neighbours requested from each index and results returned.

    Returns:
        Up to ``n_results`` dicts ``{'data': metadata, 'distance': float,
        'type': 'text' | 'image'}`` sorted by ascending distance. On ties, text
        hits come first. NOTE(review): ascending order assumes an L2-style index
        where smaller means closer; confirm against how the indexes were built.
    """
    def _embed(model, value):
        # Always encode a one-element list: encode() returns a 1-D vector for a
        # single non-list input, but FAISS needs a 2-D (1, dim) float32 matrix.
        vectors = model.encode([value], convert_to_tensor=True).cpu().numpy()
        return np.ascontiguousarray(vectors, dtype=np.float32)

    def _hits(index, vector, metadata, kind):
        distances, ids = index.search(vector, n_results)
        # FAISS pads with id -1 when the index holds fewer than n_results
        # vectors; metadata[-1] would silently return the last record.
        return [
            {'data': metadata[i], 'distance': float(d), 'type': kind}
            for i, d in zip(ids[0], distances[0])
            if i >= 0
        ]

    if isinstance(query, str):
        text_vector = _embed(text_model, query)
        image_vector = _embed(image_model, query)
    else:  # Assume it's an image
        image_vector = _embed(image_model, query)
        text_vector = image_vector

    results = []
    if text_vector.shape[1] == text_index.d:
        results.extend(_hits(text_index, text_vector, text_metadata, 'text'))
    results.extend(_hits(image_index, image_vector, image_metadata, 'image'))
    results.sort(key=lambda r: r['distance'])  # stable: text before image on ties
    return results[:n_results]
def face_search(face_embedding, index, metadata, n_results=5):
    """Return the ``n_results`` nearest indexed faces to ``face_embedding``.

    Args:
        face_embedding: a single embedding (any array-like). It is reshaped into
            the one-row float32 matrix that FAISS expects.
        index: FAISS index over face embeddings, row-aligned with ``metadata``.
        metadata: per-vector metadata records.
        n_results: number of neighbours to request.

    Returns:
        ``(results, distances)``: the matching metadata records and a NumPy
        array of their distances, in FAISS rank order. Slots that FAISS pads
        with id ``-1`` (index smaller than ``n_results``) are dropped from both,
        so the two stay aligned.
    """
    query = np.asarray(face_embedding, dtype=np.float32).reshape(1, -1)
    distances, ids = index.search(query, n_results)
    found = ids[0] >= 0
    results = [metadata[i] for i in ids[0][found]]
    return results, distances[0][found]
def detect_and_embed_face(image, face_app):
    """Detect faces in ``image`` and return the embedding of the largest one.

    Args:
        image: a PIL image (any mode) or an RGB ``numpy`` array of shape
            (H, W) or (H, W, C).
        face_app: a prepared InsightFace ``FaceAnalysis``; any object whose
            ``get(bgr_array)`` returns faces with ``bbox`` and ``embedding``.

    Returns:
        The ``embedding`` of the face with the largest bounding-box area, or
        ``None`` when no face is detected.
    """
    # PIL uploads may be RGBA, palette or greyscale; normalise to 3-channel RGB.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    rgb = np.asarray(image)
    if rgb.ndim == 2:  # greyscale array: replicate into three channels
        rgb = np.stack([rgb] * 3, axis=-1)
    # InsightFace follows the OpenCV convention (its examples feed cv2.imread
    # output) and expects BGR. Passing RGB swaps red/blue and skews embeddings
    # relative to frames indexed with cv2.
    bgr = np.ascontiguousarray(rgb[..., :3][..., ::-1])
    faces = face_app.get(bgr)
    if not faces:
        return None
    largest_face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
    return largest_face.embedding
def create_video_clip(video_path, start_time, end_time, output_path):
    """Cut ``[start_time, end_time]`` (seconds) out of ``video_path`` and save it.

    The segment is encoded with ``libx264`` video and ``aac`` audio and written
    to ``output_path``. ``end_time`` is clamped to the source duration, so
    metadata segments that run past the end of the video do not make MoviePy
    read beyond the last frame. A ``None`` ``end_time`` means "to the end".

    Returns:
        ``output_path``.

    Raises:
        ValueError: if, after clamping, the segment is empty
            (``start_time >= end_time``).
    """
    with VideoFileClip(video_path) as video:
        duration = video.duration
        if duration is not None and (end_time is None or end_time > duration):
            end_time = duration
        if end_time is not None and start_time >= end_time:
            raise ValueError(
                f"Empty clip: start {start_time:.2f}s is not before end {end_time:.2f}s"
            )
        new_clip = video.subclip(start_time, end_time)
        new_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path
# Streamlit UI
st.title("Video Analysis Dashboard")

# Sidebar: the full source video followed by a scrollable transcript.
with st.sidebar:
    st.header("Full Video")
    st.video(video_path)
    st.header("Video Transcript")
    transcript_text = transcription['transcription']
    st.text_area("Full Transcript", transcript_text, height=300)

# Main content: prominent faces on the left, detected themes on the right.
st.header("Video Summary")
faces_column, themes_column = st.columns(2)

with faces_column:
    st.subheader("Prominent Faces")
    for face in summary['prominent_faces']:
        st.write(f"Face ID: {face['id']}, Appearances: {face['appearances']}")
        if 'thumbnail' not in face:
            continue
        # Thumbnails are stored as base64-encoded image bytes.
        thumbnail_bytes = base64.b64decode(face['thumbnail'])
        thumbnail = Image.open(io.BytesIO(thumbnail_bytes))
        st.image(thumbnail, caption=f"Face ID: {face['id']}", width=100)

with themes_column:
    st.subheader("Themes")
    for theme in summary['themes']:
        st.write(f"Theme ID: {theme['id']}, Keywords: {', '.join(theme['keywords'])}")
# Face Clustering
st.header("Face Clustering")

# Recover every stored face vector from the FAISS index, then cluster them.
face_embeddings = face_index.reconstruct_n(0, face_index.ntotal)
face_labels = cluster_faces(face_embeddings)

# Bucket face metadata records by cluster label (DBSCAN noise, if any, is -1).
# Buckets keep the order in which each label is first encountered.
face_clusters = defaultdict(list)
for row, cluster_label in enumerate(face_labels):
    face_clusters[cluster_label].append(face_metadata[row])

# Store the clusters on the summary; the face search looks them up by cluster_id.
summary['face_clusters'] = [
    {'cluster_id': f'cluster_{cluster_label}', 'faces': members}
    for cluster_label, members in face_clusters.items()
]

# Visualize face clusters
st.subheader("Face Cluster Visualization")
fig = plot_face_clusters(face_embeddings, face_labels, face_metadata)
st.plotly_chart(fig)
# Search functionality
def _show_clip(start, end):
    """Render the ``[start, end]`` segment of ``video_path`` plus a divider line.

    The segment is encoded into a private temporary file, its bytes are handed
    to ``st.video``, and the file is deleted afterwards. Unique temp files keep
    concurrent sessions (or results sharing a start time) from overwriting each
    other's clips, and deleting them stops clips from piling up on disk.
    """
    import os
    import tempfile

    fd, clip_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    try:
        create_video_clip(video_path, start, end, clip_path)
        with open(clip_path, "rb") as clip_file:
            st.video(clip_file.read())
    finally:
        try:
            os.remove(clip_path)
        except OSError:
            pass  # best-effort cleanup; never mask the real error
    st.write("---")


def _show_combined_results(results, heading, show_text):
    """Write ``heading`` and then each combined-search hit with its video clip."""
    st.subheader(heading)
    for result in results:
        data = result['data']
        st.write(f"Type: {result['type']}, Time: {data['start']:.2f}s - {data['end']:.2f}s, Distance: {result['distance']:.4f}")
        if show_text and 'text' in data:
            st.write(f"Text: {data['text']}")
        _show_clip(data['start'], data['end'])


st.header("Search")
search_type = st.selectbox("Select search type", ["Combined", "Face"])
if search_type == "Combined":
    search_method = st.radio("Choose search method", ["Text", "Image"])
    if search_method == "Text":
        query = st.text_input("Enter your search query")
        if st.button("Search"):
            if not query.strip():
                # Embedding an empty string would still return "matches".
                st.warning("Please enter a search query.")
            else:
                results = combined_search(query, text_index, image_index, text_metadata, image_metadata, text_model, image_model)
                _show_combined_results(results, "Search Results", show_text=True)
    else:
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)
            if st.button("Search"):
                results = combined_search(image, text_index, image_index, text_metadata, image_metadata, text_model, image_model)
                _show_combined_results(results, "Image Search Results", show_text=False)
elif search_type == "Face":
    face_search_type = st.radio("Choose face search method", ["Select from clusters", "Upload image"])
    if face_search_type == "Select from clusters":
        cluster_options = [f'cluster_{label}' for label in sorted(set(face_labels)) if label != -1]
        if not cluster_options:
            # Every face is noise (-1) or none were indexed. The selectbox would
            # return None, and the cluster lookup would raise StopIteration.
            st.info("No face clusters were found in this video.")
        else:
            cluster_id = st.selectbox("Select a face cluster", cluster_options)
            if st.button("Search"):
                selected_cluster = next(cluster for cluster in summary['face_clusters'] if cluster['cluster_id'] == cluster_id)
                st.subheader("Face Cluster Search Results")
                for face in selected_cluster['faces']:
                    st.write(f"Time: {face['start']:.2f}s - {face['end']:.2f}s")
                    _show_clip(face['start'], face['end'])
    else:
        uploaded_file = st.file_uploader("Choose a face image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)
            if st.button("Search"):
                face_embedding = detect_and_embed_face(image, face_app)
                if face_embedding is not None:
                    face_results, face_distances = face_search(face_embedding, face_index, face_metadata)
                    st.subheader("Face Search Results")
                    for result, distance in zip(face_results, face_distances):
                        st.write(f"Time: {result['start']:.2f}s - {result['end']:.2f}s, Distance: {distance:.4f}")
                        _show_clip(result['start'], result['end'])
                else:
                    st.error("No face detected in the uploaded image. Please try another image.")