awacke1 committed on
Commit
54e3aa1
·
verified ·
1 Parent(s): e3138e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -198
app.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  import json
8
  import os
9
  import glob
 
10
  from pathlib import Path
11
  from datetime import datetime, timedelta
12
  import edge_tts
@@ -20,37 +21,51 @@ from datasets import load_dataset
20
  import base64
21
  import re
22
 
23
- # 🧠 Initialize session state variables
24
- SESSION_VARS = {
25
- 'search_history': [], # Track search history
26
- 'last_voice_input': "", # Last voice input
27
- 'transcript_history': [], # Conversation history
28
- 'should_rerun': False, # Trigger for UI updates
29
- 'search_columns': [], # Available search columns
30
- 'initial_search_done': False, # First search flag
31
- 'tts_voice': "en-US-AriaNeural", # Default voice
32
- 'arxiv_last_query': "", # Last ArXiv search
33
- 'dataset_loaded': False, # Dataset load status
34
- 'current_page': 0, # Current data page
35
- 'data_cache': None, # Data cache
36
- 'dataset_info': None, # Dataset metadata
37
- 'nps_submitted': False, # Track if user submitted NPS
38
- 'nps_last_shown': None, # When NPS was last shown
39
- 'old_val': None, # Previous voice input value
40
- 'voice_text': None # Processed voice text
41
- }
42
 
43
- # Constants
44
  ROWS_PER_PAGE = 100
45
  MIN_SEARCH_SCORE = 0.3
46
  EXACT_MATCH_BOOST = 2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # Initialize session state
49
  for var, default in SESSION_VARS.items():
50
  if var not in st.session_state:
51
  st.session_state[var] = default
52
 
53
- # Voice Component Setup
 
 
 
 
54
  def create_voice_component():
55
  """Create the voice input component"""
56
  mycomponent = components.declare_component(
@@ -59,9 +74,7 @@ def create_voice_component():
59
  )
60
  return mycomponent
61
 
62
- # Utility Functions
63
  def clean_for_speech(text: str) -> str:
64
- """Clean text for speech synthesis"""
65
  text = text.replace("\n", " ")
66
  text = text.replace("</s>", " ")
67
  text = text.replace("#", "")
@@ -82,7 +95,6 @@ async def edge_tts_generate_audio(text, voice="en-US-AriaNeural", rate=0, pitch=
82
  return out_fn
83
 
84
  def speak_with_edge_tts(text, voice="en-US-AriaNeural", rate=0, pitch=0):
85
- """Wrapper for edge TTS generation"""
86
  return asyncio.run(edge_tts_generate_audio(text, voice, rate, pitch))
87
 
88
  def play_and_download_audio(file_path):
@@ -94,12 +106,10 @@ def play_and_download_audio(file_path):
94
 
95
  @st.cache_resource
96
  def get_model():
97
- """Get sentence transformer model"""
98
  return SentenceTransformer('all-MiniLM-L6-v2')
99
 
100
  @st.cache_data
101
  def load_dataset_page(dataset_id, token, page, rows_per_page):
102
- """Load dataset page with caching"""
103
  try:
104
  start_idx = page * rows_per_page
105
  end_idx = start_idx + rows_per_page
@@ -116,7 +126,6 @@ def load_dataset_page(dataset_id, token, page, rows_per_page):
116
 
117
  @st.cache_data
118
  def get_dataset_info(dataset_id, token):
119
- """Get dataset info with caching"""
120
  try:
121
  dataset = load_dataset(dataset_id, token=token, streaming=True)
122
  return dataset['train'].info
@@ -125,7 +134,6 @@ def get_dataset_info(dataset_id, token):
125
  return None
126
 
127
  def fetch_dataset_info(dataset_id):
128
- """Fetch dataset information"""
129
  info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
130
  try:
131
  response = requests.get(info_url, timeout=30)
@@ -136,18 +144,30 @@ def fetch_dataset_info(dataset_id):
136
  return None
137
 
138
  def generate_filename(text):
139
- """Generate unique filename from text"""
140
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
141
  safe_text = re.sub(r'[^\w\s-]', '', text[:50]).strip().lower()
142
  safe_text = re.sub(r'[-\s]+', '-', safe_text)
143
- return f"{timestamp}_{safe_text}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def render_result(result):
146
- """Render a single search result"""
147
  score = result.get('relevance_score', 0)
148
  result_filtered = {k: v for k, v in result.items()
149
  if k not in ['relevance_score', 'video_embed', 'description_embed', 'audio_embed']}
150
-
151
  if 'youtube_id' in result:
152
  st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
153
 
@@ -183,8 +203,6 @@ def render_result(result):
183
  play_and_download_audio(audio_file)
184
 
185
  class FastDatasetSearcher:
186
- """Fast dataset search with semantic and token matching"""
187
-
188
  def __init__(self, dataset_id="tomg-group-umd/cinepile"):
189
  self.dataset_id = dataset_id
190
  self.text_model = get_model()
@@ -197,18 +215,16 @@ class FastDatasetSearcher:
197
  st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)
198
 
199
  def load_page(self, page=0):
200
- """Load a specific page of data"""
201
  return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
202
 
203
  def quick_search(self, query, df):
204
- """Perform quick search with semantic similarity"""
205
  if df.empty or not query.strip():
206
  return df
207
 
208
  try:
209
  searchable_cols = []
210
  for col in df.columns:
211
- sample_val = df[col].iloc[0]
212
  if not isinstance(sample_val, (np.ndarray, bytes)):
213
  searchable_cols.append(col)
214
 
@@ -253,7 +269,7 @@ class FastDatasetSearcher:
253
  if text.strip():
254
  text_tokens = set(text.lower().split())
255
  matching_terms = query_terms.intersection(text_tokens)
256
- keyword_score = len(matching_terms) / len(query_terms)
257
 
258
  text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
259
  semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])
@@ -286,9 +302,13 @@ class FastDatasetSearcher:
286
  st.error(f"Search error: {str(e)}")
287
  return df
288
 
 
289
  def main():
290
  st.title("🎥 Smart Video & Voice Search")
291
 
 
 
 
292
  # Initialize components
293
  voice_component = create_voice_component()
294
  search = FastDatasetSearcher()
@@ -296,176 +316,31 @@ def main():
296
  # Voice input at top level
297
  voice_val = voice_component(my_input_value="Start speaking...")
298
 
299
- # Show voice input if detected
 
 
 
 
 
 
 
 
 
300
  if voice_val:
301
  voice_text = str(voice_val).strip()
302
  edited_input = st.text_area("✏️ Edit Voice Input:", value=voice_text, height=100)
303
 
 
304
  run_option = st.selectbox("Select Search Type:",
305
- ["Quick Search", "Deep Search", "Voice Summary"])
306
 
307
  col1, col2 = st.columns(2)
308
  with col1:
309
- autorun = st.checkbox("⚡ Auto-Run", value=False)
310
  with col2:
311
  full_audio = st.checkbox("🔊 Full Audio", value=False)
312
 
313
  input_changed = (voice_text != st.session_state.get('old_val'))
314
 
315
  if autorun and input_changed:
316
- st.session_state['old_val'] = voice_text
317
- with st.spinner("Processing voice input..."):
318
- if run_option == "Quick Search":
319
- results = search.quick_search(edited_input, search.load_page())
320
- for i, result in enumerate(results.iterrows(), 1):
321
- with st.expander(f"Result {i}", expanded=(i==1)):
322
- render_result(result[1])
323
-
324
- elif run_option == "Deep Search":
325
- with st.spinner("Performing deep search..."):
326
- results = []
327
- for page in range(3): # Search first 3 pages
328
- df = search.load_page(page)
329
- results.extend(search.quick_search(edited_input, df).iterrows())
330
-
331
- for i, result in enumerate(results, 1):
332
- with st.expander(f"Result {i}", expanded=(i==1)):
333
- render_result(result[1])
334
-
335
- elif run_option == "Voice Summary":
336
- audio_file = speak_with_edge_tts(edited_input)
337
- if audio_file:
338
- play_and_download_audio(audio_file)
339
-
340
- elif st.button("🔍 Search", key="voice_input_search"):
341
- st.session_state['old_val'] = voice_text
342
- with st.spinner("Processing..."):
343
- results = search.quick_search(edited_input, search.load_page())
344
- for i, result in enumerate(results.iterrows(), 1):
345
- with st.expander(f"Result {i}", expanded=(i==1)):
346
- render_result(result[1])
347
-
348
- # Create main tabs
349
- tab1, tab2, tab3, tab4 = st.tabs([
350
- "🔍 Search", "🎙️ Voice", "💾 History", "⚙️ Settings"
351
- ])
352
-
353
- with tab1:
354
- st.subheader("🔍 Search")
355
- col1, col2 = st.columns([3, 1])
356
- with col1:
357
- query = st.text_input("Enter search query:",
358
- value="" if st.session_state['initial_search_done'] else "")
359
- with col2:
360
- search_column = st.selectbox("Search in:",
361
- ["All Fields"] + st.session_state['search_columns'])
362
-
363
- col3, col4 = st.columns(2)
364
- with col3:
365
- num_results = st.slider("Max results:", 1, 100, 20)
366
- with col4:
367
- search_button = st.button("🔍 Search", key="main_search_button")
368
-
369
- if (search_button or not st.session_state['initial_search_done']) and query:
370
- st.session_state['initial_search_done'] = True
371
- selected_column = None if search_column == "All Fields" else search_column
372
-
373
- with st.spinner("Searching..."):
374
- df = search.load_page()
375
- results = search.quick_search(query, df)
376
-
377
- if len(results) > 0:
378
- st.session_state['search_history'].append({
379
- 'query': query,
380
- 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
381
- 'results': results[:5]
382
- })
383
-
384
- st.write(f"Found {len(results)} results:")
385
- for i, (_, result) in enumerate(results.iterrows(), 1):
386
- if i > num_results:
387
- break
388
- with st.expander(f"Result {i}", expanded=(i==1)):
389
- render_result(result)
390
- else:
391
- st.warning("No matching results found.")
392
-
393
- with tab2:
394
- st.subheader("🎙️ Voice Input")
395
- st.write("Use the voice input above to start speaking, or record a new message:")
396
-
397
- col1, col2 = st.columns(2)
398
- with col1:
399
- if st.button("🎙️ Start New Recording", key="start_recording_button"):
400
- st.session_state['recording'] = True
401
- st.experimental_rerun()
402
- with col2:
403
- if st.button("🛑 Stop Recording", key="stop_recording_button"):
404
- st.session_state['recording'] = False
405
- st.experimental_rerun()
406
-
407
- if st.session_state.get('recording', False):
408
- voice_component = create_voice_component()
409
- new_val = voice_component(my_input_value="Recording...")
410
- if new_val:
411
- st.text_area("Recorded Text:", value=new_val, height=100)
412
- if st.button("🔍 Search with Recording", key="recording_search_button"):
413
- with st.spinner("Processing recording..."):
414
- df = search.load_page()
415
- results = search.quick_search(new_val, df)
416
- for i, (_, result) in enumerate(results.iterrows(), 1):
417
- with st.expander(f"Result {i}", expanded=(i==1)):
418
- render_result(result)
419
-
420
- with tab3:
421
- st.subheader("💾 Search History")
422
- if not st.session_state['search_history']:
423
- st.info("No search history yet. Try searching for something!")
424
- else:
425
- for entry in reversed(st.session_state['search_history']):
426
- with st.expander(f"🕒 {entry['timestamp']} - {entry['query']}", expanded=False):
427
- for i, result in enumerate(entry['results'], 1):
428
- st.write(f"**Result {i}:**")
429
- if isinstance(result, pd.Series):
430
- render_result(result)
431
- else:
432
- st.write(result)
433
-
434
- with tab4:
435
- st.subheader("⚙️ Settings")
436
- st.write("Voice Settings:")
437
- default_voice = st.selectbox(
438
- "Default Voice:",
439
- [
440
- "en-US-AriaNeural",
441
- "en-US-GuyNeural",
442
- "en-GB-SoniaNeural",
443
- "en-GB-TonyNeural"
444
- ],
445
- index=0,
446
- key="default_voice_setting"
447
- )
448
-
449
- st.write("Search Settings:")
450
- st.slider("Minimum Search Score:", 0.0, 1.0, MIN_SEARCH_SCORE, 0.1, key="min_search_score")
451
- st.slider("Exact Match Boost:", 1.0, 3.0, EXACT_MATCH_BOOST, 0.1, key="exact_match_boost")
452
-
453
- if st.button("🗑️ Clear Search History", key="clear_history_button"):
454
- st.session_state['search_history'] = []
455
- st.success("Search history cleared!")
456
- st.experimental_rerun()
457
-
458
- # Sidebar with metrics
459
- with st.sidebar:
460
- st.subheader("📊 Search Metrics")
461
- total_searches = len(st.session_state['search_history'])
462
- st.metric("Total Searches", total_searches)
463
-
464
- if total_searches > 0:
465
- recent_searches = st.session_state['search_history'][-5:]
466
- st.write("Recent Searches:")
467
- for entry in reversed(recent_searches):
468
- st.write(f"🔍 {entry['query']}")
469
-
470
- if __name__ == "__main__":
471
- main()
 
7
  import json
8
  import os
9
  import glob
10
+ import random
11
  from pathlib import Path
12
  from datetime import datetime, timedelta
13
  import edge_tts
 
21
  import base64
22
  import re
23
 
# -------------------- Configuration & Constants --------------------
# Pool of display names; one is drawn with random.choice for a session that
# has no 'user_name' yet.
USER_NAMES = [
    "Alex", "Jordan", "Taylor", "Morgan", "Rowan", "Avery", "Riley", "Quinn",
    "Casey", "Jesse", "Reese", "Skyler", "Ellis", "Devon", "Aubrey", "Kendall",
    "Parker", "Dakota", "Sage", "Finley",
]

# Number of dataset rows fetched per page (passed to load_dataset_page).
ROWS_PER_PAGE = 100
# NOTE(review): the two scoring constants below are presumably consumed by the
# search scoring code, which is not visible in this view -- confirm.
MIN_SEARCH_SCORE = 0.3
EXACT_MATCH_BOOST = 2.0

# Directory that receives the saved-input markdown files; created at import
# time so later writes do not have to check for it.
SAVED_INPUTS_DIR = "saved_inputs"
os.makedirs(SAVED_INPUTS_DIR, exist_ok=True)
37
+
# -------------------- Session State Initialization --------------------
# Default value for every key the app keeps in st.session_state.
SESSION_VARS = {
    'search_history': [],              # Track search history
    'last_voice_input': "",            # Last voice input
    'transcript_history': [],          # Conversation history
    'should_rerun': False,             # Trigger for UI updates
    'search_columns': [],              # Available search columns
    'initial_search_done': False,      # First search flag
    'tts_voice': "en-US-AriaNeural",   # Default voice
    'arxiv_last_query': "",            # Last ArXiv search
    'dataset_loaded': False,           # Dataset load status
    'current_page': 0,                 # Current data page
    'data_cache': None,                # Data cache
    'dataset_info': None,              # Dataset metadata
    'nps_submitted': False,            # Track if user submitted NPS
    'nps_last_shown': None,            # When NPS was last shown
    'old_val': None,                   # Previous voice input value
    'voice_text': None,                # Processed voice text
    'user_name': None,                 # New: Track user name
    'max_items': 100,                  # Default max items
}

# Seed only the keys that are not present yet; existing values are left alone.
for var, default in SESSION_VARS.items():
    if var not in st.session_state:
        st.session_state[var] = default

# Assign user name if not assigned
if st.session_state['user_name'] is None:
    st.session_state['user_name'] = random.choice(USER_NAMES)
67
+
68
+ # -------------------- Utility Functions --------------------
69
  def create_voice_component():
70
  """Create the voice input component"""
71
  mycomponent = components.declare_component(
 
74
  )
75
  return mycomponent
76
 
 
77
  def clean_for_speech(text: str) -> str:
 
78
  text = text.replace("\n", " ")
79
  text = text.replace("</s>", " ")
80
  text = text.replace("#", "")
 
95
  return out_fn
96
 
def speak_with_edge_tts(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Blocking wrapper around the ``edge_tts_generate_audio`` coroutine.

    Runs the coroutine to completion with ``asyncio.run`` and returns whatever
    it returns. All four arguments are forwarded unchanged.
    """
    return asyncio.run(edge_tts_generate_audio(text, voice, rate, pitch))
99
 
100
  def play_and_download_audio(file_path):
 
106
 
@st.cache_resource
def get_model():
    """Return the 'all-MiniLM-L6-v2' SentenceTransformer model.

    Wrapped in ``st.cache_resource`` so the constructor call is cached by
    Streamlit.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
110
 
111
  @st.cache_data
112
  def load_dataset_page(dataset_id, token, page, rows_per_page):
 
113
  try:
114
  start_idx = page * rows_per_page
115
  end_idx = start_idx + rows_per_page
 
126
 
127
  @st.cache_data
128
  def get_dataset_info(dataset_id, token):
 
129
  try:
130
  dataset = load_dataset(dataset_id, token=token, streaming=True)
131
  return dataset['train'].info
 
134
  return None
135
 
136
  def fetch_dataset_info(dataset_id):
 
137
  info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
138
  try:
139
  response = requests.get(info_url, timeout=30)
 
144
  return None
145
 
def generate_filename(text):
    """Build a timestamped ``.md`` filename from the start of *text*.

    The first 50 characters of *text* are turned into a lowercase slug:
    characters other than word characters, whitespace and ``-`` are dropped,
    and each run of whitespace/hyphens becomes a single ``-``.

    Args:
        text: Text the slug is derived from.

    Returns:
        ``"<YYYYmmdd_HHMMSS>_<slug>.md"``. The slug is ``"untitled"`` when no
        usable character remains (e.g. punctuation-only input).
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    slug = re.sub(r'[^\w\s-]', '', text[:50]).strip().lower()
    # Trim separators left at either end (e.g. "hi -") so the name never
    # ends in "-.md" or starts with "_-".
    slug = re.sub(r'[-\s]+', '-', slug).strip('-')
    # NOTE(review): the ".md" suffix is always appended; any caller that needs
    # a different extension has to account for that.
    return f"{timestamp}_{slug or 'untitled'}.md"
151
+
def save_input_as_md(text):
    """Write *text* to a new markdown file under ``SAVED_INPUTS_DIR``.

    The file starts with a ``# User: <name>`` heading taken from
    ``st.session_state['user_name']`` and a ``**Timestamp:**`` line, followed
    by a blank line and the text itself. The filename comes from
    ``generate_filename(text)``.

    Args:
        text: The input to persist.

    Returns:
        The path of the written file, or ``None`` when *text* is empty or
        whitespace-only (nothing is written in that case).
    """
    if not text.strip():
        return None

    full_path = os.path.join(SAVED_INPUTS_DIR, generate_filename(text))
    header = (
        f"# User: {st.session_state['user_name']}\n"
        f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
    )
    with open(full_path, 'w', encoding='utf-8') as f:
        f.write(header + text)
    return full_path
162
+
def list_saved_inputs():
    """Return the paths of all ``*.md`` files in ``SAVED_INPUTS_DIR``, sorted.

    Sorting is plain string order on the path. Names produced by
    ``generate_filename`` begin with a ``YYYYmmdd_HHMMSS`` timestamp, so for
    those files this is oldest first.
    """
    return sorted(glob.glob(os.path.join(SAVED_INPUTS_DIR, "*.md")))
166
 
167
  def render_result(result):
 
168
  score = result.get('relevance_score', 0)
169
  result_filtered = {k: v for k, v in result.items()
170
  if k not in ['relevance_score', 'video_embed', 'description_embed', 'audio_embed']}
 
171
  if 'youtube_id' in result:
172
  st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
173
 
 
203
  play_and_download_audio(audio_file)
204
 
205
  class FastDatasetSearcher:
 
 
206
  def __init__(self, dataset_id="tomg-group-umd/cinepile"):
207
  self.dataset_id = dataset_id
208
  self.text_model = get_model()
 
215
  st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)
216
 
217
  def load_page(self, page=0):
 
218
  return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
219
 
220
  def quick_search(self, query, df):
 
221
  if df.empty or not query.strip():
222
  return df
223
 
224
  try:
225
  searchable_cols = []
226
  for col in df.columns:
227
+ sample_val = df[col].iloc[0] if len(df) > 0 else ""
228
  if not isinstance(sample_val, (np.ndarray, bytes)):
229
  searchable_cols.append(col)
230
 
 
269
  if text.strip():
270
  text_tokens = set(text.lower().split())
271
  matching_terms = query_terms.intersection(text_tokens)
272
+ keyword_score = len(matching_terms) / len(query_terms) if len(query_terms) > 0 else 0.0
273
 
274
  text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
275
  semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])
 
302
  st.error(f"Search error: {str(e)}")
303
  return df
304
 
305
+ # -------------------- Main App --------------------
306
  def main():
307
  st.title("🎥 Smart Video & Voice Search")
308
 
309
+ # Load saved inputs (conversation history)
310
+ saved_files = list_saved_inputs()
311
+
312
  # Initialize components
313
  voice_component = create_voice_component()
314
  search = FastDatasetSearcher()
 
316
  # Voice input at top level
317
  voice_val = voice_component(my_input_value="Start speaking...")
318
 
319
+ # User can override max items
320
+ with st.sidebar:
321
+ st.write(f"**Current User:** {st.session_state['user_name']}")
322
+ st.session_state['max_items'] = st.number_input("Max Items per search iteration:", min_value=1, max_value=1000, value=st.session_state['max_items'])
323
+ st.subheader("📝 Saved Inputs:")
324
+ # Show saved md files in order
325
+ for fpath in saved_files:
326
+ fname = os.path.basename(fpath)
327
+ st.write(f"- [{fname}]({fpath})")
328
+
329
  if voice_val:
330
  voice_text = str(voice_val).strip()
331
  edited_input = st.text_area("✏️ Edit Voice Input:", value=voice_text, height=100)
332
 
333
+ # Auto-run default True now
334
  run_option = st.selectbox("Select Search Type:",
335
+ ["Quick Search", "Deep Search", "Voice Summary"])
336
 
337
  col1, col2 = st.columns(2)
338
  with col1:
339
+ autorun = st.checkbox("⚡ Auto-Run", value=True)
340
  with col2:
341
  full_audio = st.checkbox("🔊 Full Audio", value=False)
342
 
343
  input_changed = (voice_text != st.session_state.get('old_val'))
344
 
345
  if autorun and input_changed:
346
+