Update app.py
Browse files
app.py
CHANGED
@@ -18,13 +18,47 @@ if 'current_page' not in st.session_state:
|
|
18 |
st.session_state['current_page'] = 0
|
19 |
if 'data_cache' not in st.session_state:
|
20 |
st.session_state['data_cache'] = None
|
|
|
|
|
21 |
|
22 |
ROWS_PER_PAGE = 100 # Number of rows to load at a time
|
23 |
|
24 |
@st.cache_resource
|
25 |
def get_model():
|
|
|
26 |
return SentenceTransformer('all-MiniLM-L6-v2')
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
class FastDatasetSearcher:
|
29 |
def __init__(self, dataset_id="tomg-group-umd/cinepile"):
|
30 |
self.dataset_id = dataset_id
|
@@ -33,54 +67,35 @@ class FastDatasetSearcher:
|
|
33 |
if not self.token:
|
34 |
st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
|
35 |
st.stop()
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
"""Load dataset metadata only"""
|
41 |
-
try:
|
42 |
-
dataset = load_dataset(
|
43 |
-
self.dataset_id,
|
44 |
-
token=self.token,
|
45 |
-
streaming=True
|
46 |
-
)
|
47 |
-
self.dataset_info = dataset['train'].info
|
48 |
-
return True
|
49 |
-
except Exception as e:
|
50 |
-
st.error(f"Error loading dataset: {str(e)}")
|
51 |
-
return False
|
52 |
|
53 |
def load_page(self, page=0):
|
54 |
-
"""Load a specific page of data"""
|
55 |
-
|
56 |
-
return st.session_state['data_cache']
|
57 |
-
|
58 |
-
try:
|
59 |
-
dataset = load_dataset(
|
60 |
-
self.dataset_id,
|
61 |
-
token=self.token,
|
62 |
-
streaming=False,
|
63 |
-
split=f'train[{page*ROWS_PER_PAGE}:{(page+1)*ROWS_PER_PAGE}]'
|
64 |
-
)
|
65 |
-
df = pd.DataFrame(dataset)
|
66 |
-
st.session_state['data_cache'] = df
|
67 |
-
st.session_state['current_page'] = page
|
68 |
-
return df
|
69 |
-
except Exception as e:
|
70 |
-
st.error(f"Error loading page {page}: {str(e)}")
|
71 |
-
return pd.DataFrame()
|
72 |
|
73 |
def quick_search(self, query, df):
|
74 |
"""Fast search on current page"""
|
|
|
|
|
|
|
75 |
scores = []
|
76 |
query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
|
77 |
|
78 |
for _, row in df.iterrows():
|
79 |
# Combine all searchable text fields
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
# Quick keyword match
|
83 |
-
keyword_score = text.lower().count(query.lower()) / len(text.split())
|
84 |
|
85 |
# Semantic search on combined text
|
86 |
text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
|
@@ -91,8 +106,29 @@ class FastDatasetSearcher:
|
|
91 |
scores.append(combined_score)
|
92 |
|
93 |
# Get top results
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def main():
|
98 |
st.title("🎥 Fast Video Dataset Search")
|
@@ -100,19 +136,31 @@ def main():
|
|
100 |
# Initialize search class
|
101 |
searcher = FastDatasetSearcher()
|
102 |
|
103 |
-
#
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
# Load current page
|
107 |
-
with st.spinner(f"Loading page {
|
108 |
-
df = searcher.load_page(
|
109 |
|
110 |
if df.empty:
|
111 |
st.warning("No data available for this page.")
|
112 |
return
|
113 |
|
114 |
# Search interface
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
if query:
|
118 |
with st.spinner("Searching..."):
|
@@ -120,33 +168,23 @@ def main():
|
|
120 |
|
121 |
# Display results
|
122 |
st.write(f"Found {len(results)} results on this page:")
|
123 |
-
for i, (_, result) in enumerate(results.iterrows(), 1):
|
124 |
-
|
125 |
-
|
126 |
-
# Display video if available
|
127 |
-
if 'youtube_id' in result:
|
128 |
-
st.video(
|
129 |
-
f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
|
130 |
-
)
|
131 |
-
|
132 |
-
# Display other fields
|
133 |
-
for key, value in result.items():
|
134 |
-
if isinstance(value, (str, int, float)):
|
135 |
-
st.write(f"**{key}:** {value}")
|
136 |
|
137 |
# Show raw data
|
138 |
-
st.
|
139 |
-
|
140 |
|
141 |
# Navigation buttons
|
142 |
cols = st.columns(2)
|
143 |
with cols[0]:
|
144 |
-
if st.button("Previous Page") and
|
145 |
-
st.session_state['current_page']
|
146 |
st.rerun()
|
147 |
with cols[1]:
|
148 |
-
if st.button("Next Page"):
|
149 |
-
st.session_state['current_page']
|
150 |
st.rerun()
|
151 |
|
152 |
if __name__ == "__main__":
|
|
|
18 |
st.session_state['current_page'] = 0
|
19 |
# One-time initialization of per-session state slots. Streamlit re-executes
# the script on every interaction, so each assignment is guarded by a
# membership test to avoid clobbering state across reruns.
for _slot in ('data_cache', 'dataset_info'):
    if _slot not in st.session_state:
        st.session_state[_slot] = None

ROWS_PER_PAGE = 100  # Number of rows to load at a time
|
25 |
|
26 |
@st.cache_resource
def get_model():
    """Return the shared sentence-embedding model.

    Wrapped in ``st.cache_resource`` so the SentenceTransformer is
    constructed once per process and reused across Streamlit reruns
    and sessions.
    """
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)
|
30 |
|
31 |
+
@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the dataset as a DataFrame, memoized by ``st.cache_data``.

    Args:
        dataset_id: Hugging Face dataset identifier.
        token: Hugging Face access token used to authorize the download.
        page: zero-based page index.
        rows_per_page: number of rows in each page.

    Returns:
        A DataFrame holding the requested row slice, or an empty DataFrame
        on failure (the error is also surfaced in the UI via ``st.error``).
    """
    try:
        lo = page * rows_per_page
        hi = lo + rows_per_page
        # Non-streaming load of just the requested slice of the train split.
        rows = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{lo}:{hi}]',
        )
        return pd.DataFrame(rows)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()
|
47 |
+
|
48 |
+
@st.cache_data
def get_dataset_info(dataset_id, token):
    """Fetch dataset metadata, memoized by ``st.cache_data``.

    Streaming mode is used so only metadata is pulled, never the row data.

    Args:
        dataset_id: Hugging Face dataset identifier.
        token: Hugging Face access token.

    Returns:
        The ``train`` split's info object, or None on failure (the error is
        surfaced in the UI via ``st.error``).
    """
    try:
        streamed = load_dataset(
            dataset_id,
            token=token,
            streaming=True,
        )
        return streamed['train'].info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None
|
61 |
+
|
62 |
class FastDatasetSearcher:
|
63 |
def __init__(self, dataset_id="tomg-group-umd/cinepile"):
|
64 |
self.dataset_id = dataset_id
|
|
|
67 |
if not self.token:
|
68 |
st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
|
69 |
st.stop()
|
70 |
+
|
71 |
+
# Load dataset info if not already loaded
|
72 |
+
if st.session_state['dataset_info'] is None:
|
73 |
+
st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
def load_page(self, page=0):
|
76 |
+
"""Load a specific page of data using cached function"""
|
77 |
+
return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
def quick_search(self, query, df):
|
80 |
"""Fast search on current page"""
|
81 |
+
if df.empty:
|
82 |
+
return df
|
83 |
+
|
84 |
scores = []
|
85 |
query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
|
86 |
|
87 |
for _, row in df.iterrows():
|
88 |
# Combine all searchable text fields
|
89 |
+
text_values = []
|
90 |
+
for v in row.values():
|
91 |
+
if isinstance(v, (str, int, float)):
|
92 |
+
text_values.append(str(v))
|
93 |
+
elif isinstance(v, (list, dict)):
|
94 |
+
text_values.append(str(v))
|
95 |
+
text = ' '.join(text_values)
|
96 |
|
97 |
# Quick keyword match
|
98 |
+
keyword_score = text.lower().count(query.lower()) / (len(text.split()) + 1) # Add 1 to avoid division by zero
|
99 |
|
100 |
# Semantic search on combined text
|
101 |
text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
|
|
|
106 |
scores.append(combined_score)
|
107 |
|
108 |
# Get top results
|
109 |
+
results_df = df.copy()
|
110 |
+
results_df['score'] = scores
|
111 |
+
return results_df.sort_values('score', ascending=False)
|
112 |
+
|
113 |
+
def render_result(result):
    """Render one search result row in the Streamlit UI.

    Args:
        result: a row produced by the search — a pandas Series from
            ``DataFrame.iterrows`` (or any mapping). An optional ``'score'``
            entry is removed and shown as a relevance metric; remaining
            scalar fields are listed, and a video player is embedded when a
            ``'youtube_id'`` field is present.
    """
    # BUG FIX: pandas Series.pop() takes a single argument and has no
    # default parameter (unlike dict.pop), so result.pop('score', 0) raised
    # TypeError for every row coming from iterrows(). Test for the key
    # first and fall back to 0 when it is absent.
    score = result.pop('score') if 'score' in result else 0

    # Display video if available; seek to the clip's start time (0 when absent).
    if 'youtube_id' in result:
        st.video(
            f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
        )

    # Scalar fields on the left, relevance metric on the right.
    cols = st.columns([2, 1])
    with cols[0]:
        for key, value in result.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")

    with cols[1]:
        st.metric("Relevance Score", f"{score:.2%}")
|
132 |
|
133 |
def main():
|
134 |
st.title("🎥 Fast Video Dataset Search")
|
|
|
136 |
# Initialize search class
|
137 |
searcher = FastDatasetSearcher()
|
138 |
|
139 |
+
# Show dataset info
|
140 |
+
if st.session_state['dataset_info']:
|
141 |
+
st.sidebar.write("### Dataset Info")
|
142 |
+
st.sidebar.write(f"Total examples: {st.session_state['dataset_info'].splits['train'].num_examples:,}")
|
143 |
+
|
144 |
+
total_pages = st.session_state['dataset_info'].splits['train'].num_examples // ROWS_PER_PAGE
|
145 |
+
current_page = st.number_input("Page", min_value=0, max_value=total_pages, value=st.session_state['current_page'])
|
146 |
+
else:
|
147 |
+
current_page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])
|
148 |
|
149 |
# Load current page
|
150 |
+
with st.spinner(f"Loading page {current_page}..."):
|
151 |
+
df = searcher.load_page(current_page)
|
152 |
|
153 |
if df.empty:
|
154 |
st.warning("No data available for this page.")
|
155 |
return
|
156 |
|
157 |
# Search interface
|
158 |
+
col1, col2 = st.columns([3, 1])
|
159 |
+
with col1:
|
160 |
+
query = st.text_input("Search in current page:",
|
161 |
+
help="Searches within currently loaded data")
|
162 |
+
with col2:
|
163 |
+
max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)
|
164 |
|
165 |
if query:
|
166 |
with st.spinner("Searching..."):
|
|
|
168 |
|
169 |
# Display results
|
170 |
st.write(f"Found {len(results)} results on this page:")
|
171 |
+
for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
|
172 |
+
with st.expander(f"Result {i}", expanded=i==1):
|
173 |
+
render_result(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
# Show raw data
|
176 |
+
with st.expander("Show Raw Data"):
|
177 |
+
st.dataframe(df)
|
178 |
|
179 |
# Navigation buttons
|
180 |
cols = st.columns(2)
|
181 |
with cols[0]:
|
182 |
+
if st.button("⬅️ Previous Page") and current_page > 0:
|
183 |
+
st.session_state['current_page'] = current_page - 1
|
184 |
st.rerun()
|
185 |
with cols[1]:
|
186 |
+
if st.button("Next Page ➡️"):
|
187 |
+
st.session_state['current_page'] = current_page + 1
|
188 |
st.rerun()
|
189 |
|
190 |
if __name__ == "__main__":
|