awacke1 committed
Commit 9b95cb7 · verified · 1 Parent(s): bdefc08

Update app.py

Files changed (1)
  1. app.py +100 -62
app.py CHANGED
@@ -18,13 +18,47 @@ if 'current_page' not in st.session_state:
     st.session_state['current_page'] = 0
 if 'data_cache' not in st.session_state:
     st.session_state['data_cache'] = None
+if 'dataset_info' not in st.session_state:
+    st.session_state['dataset_info'] = None
 
 ROWS_PER_PAGE = 100 # Number of rows to load at a time
 
 @st.cache_resource
 def get_model():
+    """Cache the model loading"""
     return SentenceTransformer('all-MiniLM-L6-v2')
 
+@st.cache_data
+def load_dataset_page(dataset_id, token, page, rows_per_page):
+    """Load and cache a specific page of data"""
+    try:
+        start_idx = page * rows_per_page
+        end_idx = start_idx + rows_per_page
+        dataset = load_dataset(
+            dataset_id,
+            token=token,
+            streaming=False,
+            split=f'train[{start_idx}:{end_idx}]'
+        )
+        return pd.DataFrame(dataset)
+    except Exception as e:
+        st.error(f"Error loading page {page}: {str(e)}")
+        return pd.DataFrame()
+
+@st.cache_data
+def get_dataset_info(dataset_id, token):
+    """Load and cache dataset information"""
+    try:
+        dataset = load_dataset(
+            dataset_id,
+            token=token,
+            streaming=True
+        )
+        return dataset['train'].info
+    except Exception as e:
+        st.error(f"Error loading dataset info: {str(e)}")
+        return None
+
 class FastDatasetSearcher:
     def __init__(self, dataset_id="tomg-group-umd/cinepile"):
         self.dataset_id = dataset_id
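
Both new loaders are plain module-level functions decorated with @st.cache_data, so Streamlit keys the cache on their arguments (dataset id, token, page, rows per page) rather than on any instance state. For reference, the same slice-split pattern can be exercised outside the app; a minimal sketch with an illustrative page index, assuming the datasets and pandas packages and the same DATASET_KEY token the app reads:

import os
import pandas as pd
from datasets import load_dataset

# Illustrative values: page 0 of 100 rows, matching ROWS_PER_PAGE in app.py.
page, rows_per_page = 0, 100
subset = load_dataset(
    "tomg-group-umd/cinepile",
    token=os.environ.get("DATASET_KEY"),  # same env var the app expects
    split=f"train[{page * rows_per_page}:{(page + 1) * rows_per_page}]",
)
df = pd.DataFrame(subset)
print(df.shape)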
@@ -33,54 +67,35 @@ class FastDatasetSearcher:
         if not self.token:
             st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
             st.stop()
-        self.load_dataset_info()
-
-    @st.cache_data
-    def load_dataset_info(self):
-        """Load dataset metadata only"""
-        try:
-            dataset = load_dataset(
-                self.dataset_id,
-                token=self.token,
-                streaming=True
-            )
-            self.dataset_info = dataset['train'].info
-            return True
-        except Exception as e:
-            st.error(f"Error loading dataset: {str(e)}")
-            return False
+
+        # Load dataset info if not already loaded
+        if st.session_state['dataset_info'] is None:
+            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)
 
     def load_page(self, page=0):
-        """Load a specific page of data"""
-        if st.session_state['data_cache'] is not None and st.session_state['current_page'] == page:
-            return st.session_state['data_cache']
-
-        try:
-            dataset = load_dataset(
-                self.dataset_id,
-                token=self.token,
-                streaming=False,
-                split=f'train[{page*ROWS_PER_PAGE}:{(page+1)*ROWS_PER_PAGE}]'
-            )
-            df = pd.DataFrame(dataset)
-            st.session_state['data_cache'] = df
-            st.session_state['current_page'] = page
-            return df
-        except Exception as e:
-            st.error(f"Error loading page {page}: {str(e)}")
-            return pd.DataFrame()
+        """Load a specific page of data using cached function"""
+        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
 
     def quick_search(self, query, df):
         """Fast search on current page"""
+        if df.empty:
+            return df
+
         scores = []
         query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
 
         for _, row in df.iterrows():
             # Combine all searchable text fields
-            text = ' '.join(str(v) for v in row.values() if isinstance(v, (str, int, float)))
+            text_values = []
+            for v in row.values():
+                if isinstance(v, (str, int, float)):
+                    text_values.append(str(v))
+                elif isinstance(v, (list, dict)):
+                    text_values.append(str(v))
+            text = ' '.join(text_values)
 
             # Quick keyword match
-            keyword_score = text.lower().count(query.lower()) / len(text.split())
+            keyword_score = text.lower().count(query.lower()) / (len(text.split()) + 1)  # Add 1 to avoid division by zero
 
             # Semantic search on combined text
             text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
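
The line that blends keyword_score with the embedding similarity into combined_score sits in unchanged code outside this hunk, so the weighting below is only an assumed 50/50 split; the sketch shows one plausible way the two signals can be combined with the same model:

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
query = "car chase"
text = "A car chase through the city at night"

q_emb = model.encode([query], show_progress_bar=False)[0]
t_emb = model.encode([text], show_progress_bar=False)[0]

# Cosine similarity between the query and row-text embeddings.
semantic_score = float(np.dot(q_emb, t_emb) / (np.linalg.norm(q_emb) * np.linalg.norm(t_emb)))
# Same keyword heuristic as quick_search above.
keyword_score = text.lower().count(query.lower()) / (len(text.split()) + 1)
# Assumed equal weighting; app.py's actual blend is outside the shown hunk.
combined_score = 0.5 * semantic_score + 0.5 * keyword_score
print(round(combined_score, 3))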
@@ -91,8 +106,29 @@ class FastDatasetSearcher:
             scores.append(combined_score)
 
         # Get top results
-        df['score'] = scores
-        return df.sort_values('score', ascending=False)
+        results_df = df.copy()
+        results_df['score'] = scores
+        return results_df.sort_values('score', ascending=False)
+
+def render_result(result):
+    """Render a single search result"""
+    score = result.pop('score', 0)
+
+    # Display video if available
+    if 'youtube_id' in result:
+        st.video(
+            f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
+        )
+
+    # Display other fields
+    cols = st.columns([2, 1])
+    with cols[0]:
+        for key, value in result.items():
+            if isinstance(value, (str, int, float)):
+                st.write(f"**{key}:** {value}")
+
+    with cols[1]:
+        st.metric("Relevance Score", f"{score:.2%}")
 
 def main():
     st.title("🎥 Fast Video Dataset Search")
@@ -100,19 +136,31 @@ def main():
     # Initialize search class
     searcher = FastDatasetSearcher()
 
-    # Page navigation
-    page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])
+    # Show dataset info
+    if st.session_state['dataset_info']:
+        st.sidebar.write("### Dataset Info")
+        st.sidebar.write(f"Total examples: {st.session_state['dataset_info'].splits['train'].num_examples:,}")
+
+        total_pages = st.session_state['dataset_info'].splits['train'].num_examples // ROWS_PER_PAGE
+        current_page = st.number_input("Page", min_value=0, max_value=total_pages, value=st.session_state['current_page'])
+    else:
+        current_page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])
 
     # Load current page
-    with st.spinner(f"Loading page {page}..."):
-        df = searcher.load_page(page)
+    with st.spinner(f"Loading page {current_page}..."):
+        df = searcher.load_page(current_page)
 
     if df.empty:
         st.warning("No data available for this page.")
         return
 
     # Search interface
-    query = st.text_input("Search in current page:", help="Searches within currently loaded data")
+    col1, col2 = st.columns([3, 1])
+    with col1:
+        query = st.text_input("Search in current page:",
+                              help="Searches within currently loaded data")
+    with col2:
+        max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)
 
     if query:
         with st.spinner("Searching..."):
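
The sidebar page count reads split metadata cached by get_dataset_info. The same information can be inspected on its own via a streaming load, which avoids downloading the data; note that num_examples may be missing when the dataset card does not record split sizes (dataset id and env var as in app.py):

import os
from datasets import load_dataset

ds = load_dataset(
    "tomg-group-umd/cinepile",
    token=os.environ.get("DATASET_KEY"),
    streaming=True,
)
info = ds["train"].info
split = info.splits.get("train") if info.splits else None
print(split.num_examples if split else "split sizes not recorded")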
@@ -120,33 +168,23 @@ def main():
 
         # Display results
         st.write(f"Found {len(results)} results on this page:")
-        for i, (_, result) in enumerate(results.iterrows(), 1):
-            score = result.pop('score')
-            with st.expander(f"Result {i} (Score: {score:.2%})", expanded=i==1):
-                # Display video if available
-                if 'youtube_id' in result:
-                    st.video(
-                        f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
-                    )
-
-                # Display other fields
-                for key, value in result.items():
-                    if isinstance(value, (str, int, float)):
-                        st.write(f"**{key}:** {value}")
+        for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
+            with st.expander(f"Result {i}", expanded=i==1):
+                render_result(result)
 
     # Show raw data
-    st.subheader("Raw Data")
-    st.dataframe(df)
+    with st.expander("Show Raw Data"):
+        st.dataframe(df)
 
     # Navigation buttons
     cols = st.columns(2)
    with cols[0]:
-        if st.button("Previous Page") and page > 0:
-            st.session_state['current_page'] -= 1
+        if st.button("⬅️ Previous Page") and current_page > 0:
+            st.session_state['current_page'] = current_page - 1
             st.rerun()
     with cols[1]:
-        if st.button("Next Page"):
-            st.session_state['current_page'] += 1
+        if st.button("Next Page ➡️"):
+            st.session_state['current_page'] = current_page + 1
             st.rerun()
 
 if __name__ == "__main__":
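
The navigation buttons work by writing the new index into st.session_state and calling st.rerun(); stripped of the dataset logic, the pattern looks roughly like this self-contained sketch (run with streamlit run):

import streamlit as st

if 'current_page' not in st.session_state:
    st.session_state['current_page'] = 0

# Widget reflects the stored page; buttons update the stored page and rerun.
page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])
cols = st.columns(2)
with cols[0]:
    if st.button("⬅️ Previous Page") and page > 0:
        st.session_state['current_page'] = page - 1
        st.rerun()
with cols[1]:
    if st.button("Next Page ➡️"):
        st.session_state['current_page'] = page + 1
        st.rerun()

st.write(f"Current page: {st.session_state['current_page']}")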
 