thomasht86 committed on
Commit
8996eb9
1 Parent(s): 580ca24

deploy at 2024-08-24 17:35:22.783475

Browse files
Files changed (2) hide show
  1. main copy.py +861 -0
  2. main.py +11 -43
main copy.py ADDED
@@ -0,0 +1,861 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fasthtml_hf import setup_hf_backup
2
+ from fasthtml.common import (
3
+ picolink,
4
+ serve,
5
+ Div,
6
+ Title,
7
+ Main,
8
+ Input,
9
+ Button,
10
+ A,
11
+ Section,
12
+ H2,
13
+ Ul,
14
+ Li,
15
+ P,
16
+ Img,
17
+ Details,
18
+ MarkdownJS,
19
+ HighlightJS,
20
+ Summary,
21
+ Script,
22
+ I,
23
+ Form,
24
+ RedirectResponse,
25
+ dataclass,
26
+ Favicon,
27
+ database,
28
+ get_key,
29
+ Table,
30
+ Thead,
31
+ Tr,
32
+ Th,
33
+ Tbody,
34
+ Td,
35
+ FileResponse,
36
+ fast_app,
37
+ Beforeware,
38
+ Hidden,
39
+ Request,
40
+ H3,
41
+ Style,
42
+ )
43
+ from fasthtml.components import Nav, Article, Header, Mark
44
+ from fasthtml.pico import Search, Grid, Fieldset, Label
45
+ from starlette.middleware import Middleware
46
+ from starlette.middleware.base import BaseHTTPMiddleware
47
+ from starlette.middleware.sessions import SessionMiddleware
48
+ from vespa.application import Vespa
49
+ import json
50
+ import os
51
+ import re
52
+ import time
53
+ from hmac import compare_digest
54
+ from io import StringIO
55
+ import csv
56
+ import tempfile
57
+ from enum import Enum
58
+ from typing import Tuple as T
59
+ from urllib.parse import quote
60
+ import uuid
61
+
62
+ DEV_MODE = False
63
+
64
+ if DEV_MODE:
65
+ print("Running in DEV_MODE - Hot reload enabled")
66
+ print("Loading environment variables from .env")
67
+ from dotenv import load_dotenv
68
+
69
+ load_dotenv()
70
+ else:
71
+ print("DEV_MODE disabled - environment variables loaded from system")
72
+
73
+ vespa_app_url = os.getenv("VESPA_APP_URL", None)
74
+ if vespa_app_url is None:
75
+ print("Please set the VESPA_APP_URL environment variable")
76
+ exit(1)
77
+
78
+ ADMIN_NAME = os.getenv("ADMIN_NAME", "admin")
79
+ ADMIN_PWD = os.getenv("ADMIN_PWD", "admin")
80
+
81
+ vespa_app: Vespa = Vespa(
82
+ url=vespa_app_url,
83
+ vespa_cloud_secret_token=os.getenv("VESPA_CLOUD_SECRET_TOKEN"),
84
+ )
85
+ status = vespa_app.get_application_status()
86
+ if status is None:
87
+ print("Could not connect to Vespa application")
88
+ else:
89
+ print("Connected to Vespa application!")
90
+
91
+ fa = Script(src="https://kit.fontawesome.com/664eb1a115.js", crossorigin="anonymous")
92
+ favicon = Favicon(
93
+ "https://search.vespa.ai/favicon.ico",
94
+ "https://search.vespa.ai/favicon.ico",
95
+ )
96
+ DB_FILE = "db/vespa.db"
97
+ db = database(DB_FILE)
98
+ queries = db.t.queries
99
+ if queries not in db.t:
100
+ # You can pass a dict, or kwargs, to most MiniDataAPI methods.
101
+ queries.create(
102
+ dict(qid=int, query=str, ranking=str, sess_id=str, timestamp=int), pk="qid"
103
+ )
104
+ # Add autoincrement to the qid column
105
+ db.query("ALTER TABLE queries ADD COLUMN qid INTEGER PRIMARY KEY AUTOINCREMENT")
106
+ Query = queries.dataclass()
107
+
108
+ # Add an instance method to the Query dataclass that formats the timestamp field as a human-readable string
109
+ Query.get_datetime = lambda self: time.strftime(
110
+ "%Y-%m-%d %H:%M:%S", time.localtime(self.timestamp)
111
+ )
112
+
113
+ # Status code 303 is a redirect that can change POST to GET,
114
+ # so it's appropriate for a login page.
115
+ login_redir = RedirectResponse("/login", status_code=303)
116
+
117
+
118
+ def user_auth_before(req, sess):
119
+ # The `auth` key in the request scope is automatically provided
120
+ # to any handler which requests it, and can not be injected
121
+ # by the user using query params, cookies, etc, so it should
122
+ # be secure to use.
123
+ print(f"Session Data before route: {sess}")
124
+ auth = req.scope["auth"] = sess.get("auth", None)
125
+ print(f"Auth: {auth}")
126
+ if not auth:
127
+ return login_redir
128
+
129
+
130
+ spinner_css = Style("""
131
+ .htmx-indicator {
132
+ display: none; /* Hide spinner by default */
133
+ }
134
+
135
+ .htmx-indicator.htmx-request {
136
+ display: block;
137
+ }
138
+ """)
139
+
140
+ headers = (
141
+ picolink,
142
+ MarkdownJS(),
143
+ HighlightJS(langs=["json", "python"]),
144
+ favicon,
145
+ fa,
146
+ spinner_css,
147
+ )
148
+
149
+ # Read file contents once before starting the server
150
+ with open("README.md") as f:
151
+ README = f.read()
152
+ with open("main.py") as f:
153
+ SOURCE = f.read()
154
+
155
+ # Sesskey
156
+ sess_key_path = "session/.sesskey"
157
+ # Make sure session directory exists
158
+ os.makedirs("session", exist_ok=True)
159
+
160
+
161
+ # Middleware
162
+ class XFrameOptionsMiddleware(BaseHTTPMiddleware):
163
+ async def dispatch(self, request, call_next):
164
+ response = await call_next(request)
165
+ response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
166
+ return response
167
+
168
+ class SessionLoggingMiddleware(BaseHTTPMiddleware):
169
+ async def dispatch(self, request, call_next):
170
+ print(f"Before request: Session data: {request.session}")
171
+ response = await call_next(request)
172
+ print(f"After request: Session data: {request.session}")
173
+ return response
174
+
175
+ class DebugSessionMiddleware(SessionMiddleware):
176
+ async def __call__(self, scope, receive, send):
177
+ print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
178
+ await super().__call__(scope, receive, send)
179
+ print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
180
+
181
+ from starlette.middleware.cors import CORSMiddleware
182
+
183
+ middlewares = [
184
+ Middleware(
185
+ SessionMiddleware,
186
+ secret_key=get_key(fname=sess_key_path),
187
+ max_age=3600,
188
+ #same_site='lax',
189
+ ),
190
+ Middleware(CORSMiddleware, allow_origins=['*']),
191
+ Middleware(XFrameOptionsMiddleware),
192
+ Middleware(SessionLoggingMiddleware),
193
+ #Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
194
+ ]
195
+ bware = Beforeware(
196
+ user_auth_before,
197
+ skip=[
198
+ r"/favicon\.ico",
199
+ r"/static/.*",
200
+ r".*\.css",
201
+ r".*\.js",
202
+ "/",
203
+ "/login",
204
+ "/search",
205
+ "/document/.*",
206
+ "/expand/.*",
207
+ "/source",
208
+ "/about",
209
+ ],
210
+ )
211
+
212
+ app, rt = fast_app(
213
+ before=bware,
214
+ live=DEV_MODE,
215
+ hdrs=headers,
216
+ middleware=middlewares,
217
+ key_fname=sess_key_path,
218
+ same_site="None",
219
+ )
220
+
221
+
222
+ sesskey = get_key(fname=sess_key_path)
223
+ print(f"Session key: {sesskey}")
224
+
225
+
226
+ # enum class for rank profiles
227
+ class RankProfile(str, Enum):
228
+ bm25 = "bm25"
229
+ semantic = "semantic"
230
+ fusion = "fusion"
231
+
232
+
233
+ def get_navbar(admin: bool):
234
+ print(f"In get_navbar: {admin}")
235
+ bar = Nav(
236
+ Ul(
237
+ Li(
238
+ A(
239
+ Img(src="https://vespa.ai/assets/vespa-ai-logo-heather.svg"),
240
+ href="https://cloud.vespa.ai",
241
+ target="_blank",
242
+ style="margin: 10px;",
243
+ ),
244
+ )
245
+ ),
246
+ Ul(H2("Vespa-fastHTML demo")),
247
+ Ul(
248
+ # A question mark icon with link to an about page
249
+ A(
250
+ I(cls="fa fa-question-circle fa-2x"),
251
+ href="/about",
252
+ style="margin: 10px;",
253
+ title="About this app",
254
+ ),
255
+ A(
256
+ I(cls="fab fa-slack fa-2x"),
257
+ href="https://slack.vespa.ai/",
258
+ style="margin: 10px;",
259
+ target="_blank",
260
+ title="Join Vespa Slack channel",
261
+ ),
262
+ A(
263
+ I(cls="fab fa-github fa-2x"),
264
+ href="https://github.com/vespa-engine/sample-apps/tree/master/examples/fasthtml-demo",
265
+ style="margin: 10px;",
266
+ target="_blank",
267
+ title="View source code on GitHub",
268
+ ),
269
+ A(
270
+ I(cls="fa fa-code fa-2x"),
271
+ href="/source",
272
+ style="margin: 10px;",
273
+ title="View source code",
274
+ ),
275
+ # Admin icon with a tooltip on hover: links to /login, or to /admin when an admin is already logged in
276
+ A(
277
+ I(cls="fa fa-shield fa-2x"),
278
+ href="/login" if not admin else "/admin",
279
+ style="margin: 10px;",
280
+ title="Admin login",
281
+ ),
282
+ # Logout icon if admin is logged in
283
+ A(
284
+ I(cls="fa fa-sign-out fa-2x"),
285
+ href="/logout",
286
+ style="margin: 10px;" if admin else "display: none;",
287
+ title="Logout",
288
+ ),
289
+ ),
290
+ # 10px margin to right of navbar
291
+ style="margin-right: 10px;",
292
+ )
293
+ return bar
294
+
295
+
296
+ def spinner_div(hidden: bool = False):
297
+ return Div(
298
+ A(
299
+ id="spinner",
300
+ aria_busy="true",
301
+ cls="htmx-indicator",
302
+ style="font-size: 2em;",
303
+ ),
304
+ style="text-align: center; margin-top: 40px;"
305
+ if not hidden
306
+ else "display: none;",
307
+ )
308
+
309
+
310
+ @app.route("/")
311
+ def get(sess):
312
+ # Can not get auth directly, as it is skipped in beforeware
313
+ auth = sess.get("auth", False)
314
+ queries = [
315
+ "Breast Cancer Cells Feed on Cholesterol",
316
+ "Treating Asthma With Plants vs. Pills",
317
+ "Testing Turmeric on Smokers",
318
+ "The Role of Pesticides in Parkinson's Disease",
319
+ ]
320
+ return (
321
+ Title("Vespa demo"),
322
+ get_navbar(auth),
323
+ Main(
324
+ # Search bar
325
+ Search(
326
+ Input(
327
+ type="search",
328
+ placeholder="Ask/search for medical information?",
329
+ id="userquery",
330
+ ),
331
+ # Get search results on button click with search-input as query parameter
332
+ Button(
333
+ "Search",
334
+ hx_get="/search",
335
+ # include userquery and id of selected ranking radio button
336
+ hx_include="#userquery, input[name=ranking]:checked",
337
+ hx_target="#results",
338
+ hx_indicator="#spinner",
339
+ ),
340
+ style="margin: 10% 10px 0 0;",
341
+ ),
342
+ Fieldset(
343
+ Input(type="radio", id="bm25", name="ranking", value="bm25"),
344
+ Label("BM25", htmlfor="bm25"),
345
+ Input(type="radio", id="semantic", name="ranking", value="semantic"),
346
+ Label("Semantic", htmlfor="semantic"),
347
+ Input(
348
+ type="radio",
349
+ id="fusion",
350
+ name="ranking",
351
+ value="fusion",
352
+ checked="",
353
+ ),
354
+ Label("Reciprocal Rank fusion", htmlfor="fusion"),
355
+ style="margin: 10px; text-align: center;",
356
+ id="ranking",
357
+ ),
358
+ H3("Example queries"),
359
+ # Buttons with predefined search queries
360
+ Grid(
361
+ *[
362
+ Button(
363
+ query,
364
+ hx_get="/search?userquery=" + query,
365
+ hx_include="input[name=ranking]:checked",
366
+ hx_target="#results",
367
+ hx_indicator="#spinner",
368
+ hx_on_click=f"document.getElementById('userquery').value='{query}'",
369
+ style="margin: 10px; padding: 5px;",
370
+ cls="secondary outline",
371
+ id=f"example-{qid}",
372
+ )
373
+ for qid, query in enumerate(queries)
374
+ ],
375
+ # Make the grid buttons have same height and distribute evenly and center align
376
+ style="grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));",
377
+ ),
378
+ # Section(
379
+ # Input(
380
+ # id="suggestion-input",
381
+ # list="search-options",
382
+ # placeholder="Search options",
383
+ # ),
384
+ # Datalist(
385
+ # *[
386
+ # Option(
387
+ # "Covid-19",
388
+ # value="Covid-19",
389
+ # ),
390
+ # Option(
391
+ # "Vaccine",
392
+ # value="Vaccine",
393
+ # ),
394
+ # ],
395
+ # id="search-options",
396
+ # ),
397
+ # id="suggestions",
398
+ # ),
399
+ # Display spinner div only if it #spinner does not exist
400
+ Section(
401
+ spinner_div(),
402
+ id="results",
403
+ hx_swap="innerHTML",
404
+ style="margin: 20px;",
405
+ ),
406
+ style="margin: 0 auto; width: 70%;",
407
+ id="main",
408
+ ),
409
+ )
410
+
411
+
412
+ @dataclass
413
+ class Login:
414
+ name: str
415
+ pwd: str
416
+
417
+
418
+ @app.get("/login")
419
+ def get_login_form(sess, error: bool = False):
420
+ auth = sess.get("auth", False)
421
+ frm = Form(
422
+ Input(id="name", placeholder="Name"),
423
+ Input(id="pwd", type="password", placeholder="Password"),
424
+ Button("login"),
425
+ action="/login",
426
+ method="post",
427
+ )
428
+ err_msg = P("Incorrect password", style="color: red;") if error else ""
429
+ return (
430
+ Title("Admin login"),
431
+ get_navbar(auth),
432
+ Main(
433
+ err_msg,
434
+ frm,
435
+ style="width: 50%; margin: 10% auto;",
436
+ ),
437
+ )
438
+
439
+
440
+ @app.post("/login")
441
+ def post(login: Login, sess):
442
+ if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
443
+ # Incorrect password - add error message
444
+ return RedirectResponse("/login?error=True", status_code=303)
445
+ print(f"Session after setting auth: {sess}")
446
+ response = RedirectResponse("/admin", status_code=303)
447
+ print(f"Cookies being set: {response.headers.get('Set-Cookie')}")
448
+ return response
449
+
450
+
451
+ @app.get("/logout")
452
+ def logout(sess):
453
+ sess["auth"] = False
454
+ return RedirectResponse("/")
455
+
456
+
457
+ def replace_hi_with_strong(text):
458
+ parts = re.split(r"(<hi>|</hi>)", text)
459
+ elements = []
460
+ open_tag = False
461
+ for part in parts:
462
+ if part == "<hi>":
463
+ open_tag = True
464
+ elif part == "</hi>":
465
+ open_tag = False
466
+ elif open_tag:
467
+ elements.append(Mark(part))
468
+ else:
469
+ elements.append(part)
470
+ return elements
471
+
472
+
473
+ def log_query_to_db(query, ranking, sess):
474
+ queries.insert(
475
+ Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
476
+ )
477
+ if 'user_id' not in sess:
478
+ sess['user_id'] = str(uuid.uuid4())
479
+
480
+ if 'queries' not in sess:
481
+ sess['queries'] = []
482
+
483
+ query_data = {
484
+ 'query': query,
485
+ 'ranking': ranking,
486
+ 'timestamp': int(time.time())
487
+ }
488
+ sess['queries'].append(query_data)
489
+
490
+ # Limit the number of queries stored in the session to prevent it from growing too large
491
+ sess['queries'] = sess['queries'][-100:] # Keep only the last 100 queries
492
+
493
+ return query_data
494
+
495
+
496
+ def parse_results(records):
497
+ return [
498
+ Article(
499
+ Header(
500
+ H2(
501
+ A(
502
+ result["title"],
503
+ hx_get=f"/document/{result['id']}",
504
+ hx_target="#results",
505
+ )
506
+ )
507
+ ),
508
+ Div(
509
+ P(
510
+ *replace_hi_with_strong(
511
+ result["body"][:300] + "..."
512
+ ), # Display first 300 characters of body
513
+ ),
514
+ Div(
515
+ # Button with "Show more" - center align
516
+ Button(
517
+ "Show more",
518
+ hx_post=f"/expand/{result['id']}?expand=true",
519
+ hx_target=f"#{result['id']}",
520
+ hx_include=f"#{result['id']}-full",
521
+ cls="outline secondary",
522
+ # Style to fill whole width of parent div
523
+ style="width: 100%;",
524
+ ),
525
+ style="text-align: center;",
526
+ ),
527
+ id=result["id"],
528
+ ),
529
+ Hidden(result["body"], id=f"{result['id']}-full"),
530
+ )
531
+ for result in records
532
+ ]
533
+
534
+
535
+ @app.post("/expand/{docid}")
536
+ async def expand(request: Request, docid: str, expand: bool):
537
+ print(f"Expanding {docid}")
538
+ form_data = await request.form()
539
+ result = form_data.get(f"{docid}-full")
540
+ if not expand:
541
+ result = result[:300] + "..."
542
+ return (
543
+ Div(
544
+ P(
545
+ *replace_hi_with_strong(result), # Display full body
546
+ ),
547
+ Div(
548
+ # Button with "Show less" - center align
549
+ Button(
550
+ "Show less" if expand else "Show more",
551
+ hx_post=f"/expand/{docid}?expand="
552
+ + ("false" if expand else "true"),
553
+ hx_target=f"#{docid}",
554
+ hx_include=f"#{docid}-full",
555
+ cls="outline secondary",
556
+ # Style to fill whole width of parent div
557
+ style="width: 100%;",
558
+ ),
559
+ style="text-align: center;",
560
+ ),
561
+ id=docid,
562
+ ),
563
+ )
564
+
565
+
566
+ # Returns tuple of (yql, body(dict)) based on the ranking profile
567
+ def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
568
+ if ranking == RankProfile.bm25:
569
+ yql = "select * from sources * where userQuery() limit 10"
570
+ body = {}
571
+ elif ranking == RankProfile.semantic:
572
+ yql = "select * from sources * where ({targetHits:10}nearestNeighbor(embedding,q)) limit 10"
573
+ body = {"input.query(q)": f"embed({userquery})"}
574
+ elif ranking == RankProfile.fusion:
575
+ yql = "select * from sources * where rank({targetHits:1000}nearestNeighbor(embedding,q), userQuery()) limit 10"
576
+ body = {"input.query(q)": f"embed({userquery})"}
577
+ return yql, body
578
+
579
+
580
+ @app.get("/search")
581
+ async def search(userquery: str, ranking: str, sess):
582
+ print(sess)
583
+ quoted = quote(userquery) + "&ranking=" + ranking
584
+ log_query_to_db(userquery, ranking, sess)
585
+ yql, body = get_yql(ranking, userquery)
586
+ async with vespa_app.asyncio() as session:
587
+ resp = await session.query(
588
+ yql=yql,
589
+ query=userquery,
590
+ hits=10,
591
+ ranking=str(ranking),
592
+ body=body,
593
+ )
594
+ records = []
595
+ fields = ["id", "title", "body"]
596
+ for hit in resp.hits:
597
+ record = {}
598
+ for field in fields:
599
+ record[field] = hit["fields"][field]
600
+ records.append(record)
601
+ results = parse_results(records)
602
+ json_dump = json.dumps(resp.get_json(), indent=4)
603
+ return Div(
604
+ spinner_div(),
605
+ # Accordion (with Details)
606
+ Details(
607
+ Summary("Full JSON response"),
608
+ Div(
609
+ f"""```json\n{json_dump}\n```""",
610
+ cls="marked",
611
+ ),
612
+ ),
613
+ H2(
614
+ "Search Results",
615
+ ),
616
+ Div(
617
+ *results,
618
+ id="all-searchresults",
619
+ ),
620
+ )
621
+
622
+
623
+ @app.get("/download_csv")
624
+ def download_csv(auth):
625
+ queries_dict = list(db.query("SELECT * FROM queries"))
626
+ queries = [Query(**query) for query in queries_dict]
627
+
628
+ # Create CSV in memory
629
+ csv_file = StringIO()
630
+ csv_writer = csv.writer(csv_file)
631
+ csv_writer.writerow(["Query", "Session ID", "Timestamp"])
632
+ for query in queries:
633
+ csv_writer.writerow([query.query, query.sess_id, query.timestamp])
634
+
635
+ # Move to the beginning of the StringIO object
636
+ csv_file.seek(0)
637
+
638
+ # Save CSV to a temporary file
639
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
640
+ temp_file.write(csv_file.getvalue().encode("utf-8"))
641
+ temp_file.close()
642
+
643
+ return FileResponse(
644
+ temp_file.name,
645
+ filename="queries.csv",
646
+ media_type="text/csv",
647
+ content_disposition_type="attachment",
648
+ )
649
+
650
+
651
+ @app.get("/admin")
652
+ def get_admin(auth, page: int = 1):
653
+ limit = 15
654
+ offset = (page - 1) * limit
655
+ total_queries_result = list(
656
+ db.query("SELECT COUNT(*) AS count FROM queries ORDER BY timestamp DESC")
657
+ )
658
+ total_queries = total_queries_result[0]["count"]
659
+ queries_dict = list(
660
+ db.query(f"SELECT * FROM queries LIMIT {limit} OFFSET {offset}")
661
+ )
662
+ queries = [Query(**query) for query in queries_dict]
663
+
664
+ total_pages = (
665
+ total_queries + limit - 1
666
+ ) // limit # Calculate total number of pages
667
+
668
+ # Define the range of pages to display
669
+ page_window = 5 # Number of pages to display at once
670
+ start_page = max(1, page - page_window // 2)
671
+ end_page = min(total_pages, start_page + page_window - 1)
672
+
673
+ # Adjust the start and end pages if they exceed the limits
674
+ if end_page - start_page < page_window:
675
+ start_page = max(1, end_page - page_window + 1)
676
+
677
+ # Pagination controls with "First", "Previous", "Next", and "Last"
678
+ pagination_controls = Div(
679
+ A(
680
+ "First",
681
+ href="/admin?page=1",
682
+ style="margin: 5px;"
683
+ if page > 1
684
+ else "margin: 5px; color: grey; pointer-events: none;",
685
+ ),
686
+ A(
687
+ "Previous",
688
+ href=f"/admin?page={page - 1}",
689
+ style="margin: 5px;"
690
+ if page > 1
691
+ else "margin: 5px; color: grey; pointer-events: none;",
692
+ ),
693
+ *[
694
+ A(
695
+ f"{i}",
696
+ href=f"/admin?page={i}",
697
+ style="margin: 5px;"
698
+ if i != page
699
+ else "margin: 5px; font-weight: bold;",
700
+ )
701
+ for i in range(start_page, end_page + 1)
702
+ ],
703
+ A(
704
+ "Next",
705
+ href=f"/admin?page={page + 1}",
706
+ style="margin: 5px;"
707
+ if page < total_pages
708
+ else "margin: 5px; color: grey; pointer-events: none;",
709
+ ),
710
+ A(
711
+ "Last",
712
+ href=f"/admin?page={total_pages}",
713
+ style="margin: 5px;"
714
+ if page < total_pages
715
+ else "margin: 5px; color: grey; pointer-events: none;",
716
+ ),
717
+ style="text-align: center; margin: 20px;",
718
+ )
719
+
720
+ # Total pages indication
721
+ total_pages_indicator = Div(
722
+ f"Page {page} of {total_pages}",
723
+ style="text-align: center; margin: 10px;",
724
+ )
725
+
726
+ return (
727
+ Title("Admin"),
728
+ get_navbar(auth),
729
+ Main(
730
+ Div(
731
+ A(
732
+ I(cls="fa fa-arrow-left"),
733
+ "Back",
734
+ href="/",
735
+ title="Back to main page",
736
+ style="margin: 10px;",
737
+ ),
738
+ style="margin: 10px;",
739
+ ),
740
+ H2("Queries"),
741
+ # Table of all queries
742
+ Table(
743
+ Thead(
744
+ Tr(
745
+ Th("Query"),
746
+ Th("Session ID"),
747
+ Th("Datetime"),
748
+ )
749
+ ),
750
+ Tbody(
751
+ *[
752
+ Tr(
753
+ Td(query.query),
754
+ Td(query.sess_id),
755
+ Td(query.get_datetime()),
756
+ )
757
+ for query in queries
758
+ ],
759
+ ),
760
+ cls="striped",
761
+ ),
762
+ total_pages_indicator, # Include the total pages indicator here
763
+ pagination_controls,
764
+ Div(
765
+ A(
766
+ I(cls="fa fa-download fa-2x"),
767
+ " Download CSV",
768
+ href="/download_csv",
769
+ style="margin: 10px; float: right;",
770
+ title="Download queries as CSV",
771
+ ),
772
+ style="text-align: right; margin: 20px;",
773
+ ),
774
+ style="width: 80%; margin: 40px auto;",
775
+ ),
776
+ )
777
+
778
+
779
+ @app.get("/source")
780
+ def get_source(auth, sess):
781
+ # Back icon to go back to main page in top left corner
782
+ return (
783
+ Title("Source code"),
784
+ get_navbar(auth),
785
+ Main(
786
+ Div(
787
+ A(
788
+ I(cls="fa fa-arrow-left"),
789
+ "Back",
790
+ href="/",
791
+ title="Back to main page",
792
+ style="margin: 10px;",
793
+ ),
794
+ Div(
795
+ f"""### `main.py`\n### This is the complete source code for this app \n```python\n{SOURCE}\n```""",
796
+ cls="marked",
797
+ style="margin: 10px;",
798
+ ),
799
+ style="width: 80%; margin: 40px auto;",
800
+ ),
801
+ ),
802
+ )
803
+
804
+
805
+ @app.get("/about")
806
+ def get_about(auth, sess):
807
+ # Strip everything before the "# FastHTML Vespa frontend" heading in the README
808
+ stripped_readme = re.sub(
809
+ r"^.*?(?=# FastHTML Vespa frontend)", "", README, flags=re.DOTALL
810
+ )
811
+
812
+ return (
813
+ Title("About this app"),
814
+ get_navbar(auth),
815
+ Main(
816
+ Div(
817
+ A(
818
+ I(cls="fa fa-arrow-left"),
819
+ "Back",
820
+ href="/",
821
+ title="Back to main page",
822
+ style="margin: 10px;",
823
+ ),
824
+ Div(
825
+ stripped_readme,
826
+ cls="marked",
827
+ style="margin: 10px;",
828
+ ),
829
+ style="width: 80%; margin: 40px auto;",
830
+ ),
831
+ ),
832
+ )
833
+
834
+
835
+ @app.get("/document/{docid}")
836
+ def get_document(docid: str, sess):
837
+ resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
838
+ doc = resp.json
839
+ # Link with Back to search results at top of page
840
+ last_query = sess.get('queries', [{}])[-1].get('query', '')
841
+ return Main(
842
+ Div(
843
+ A(
844
+ I(cls="fa fa-arrow-left"),
845
+ "Back to search results",
846
+ hx_get=f"/search?userquery={last_query}",
847
+ hx_target="#results",
848
+ style="margin: 10px;",
849
+ ),
850
+ H2(doc["fields"]["title"], style="margin: 10px;"),
851
+ P(doc["fields"]["body"], cls="marked"),
852
+ ),
853
+ )
854
+
855
+
856
+ if not DEV_MODE:
857
+ try:
858
+ setup_hf_backup(app)
859
+ except Exception as e:
860
+ print(f"Error setting up hf backup: {e}")
861
+ serve()
main.py CHANGED
@@ -57,7 +57,6 @@ import tempfile
57
  from enum import Enum
58
  from typing import Tuple as T
59
  from urllib.parse import quote
60
- import uuid
61
 
62
  DEV_MODE = False
63
 
@@ -165,32 +164,14 @@ class XFrameOptionsMiddleware(BaseHTTPMiddleware):
165
  response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
166
  return response
167
 
168
- class SessionLoggingMiddleware(BaseHTTPMiddleware):
169
- async def dispatch(self, request, call_next):
170
- print(f"Before request: Session data: {request.session}")
171
- response = await call_next(request)
172
- print(f"After request: Session data: {request.session}")
173
- return response
174
-
175
- class DebugSessionMiddleware(SessionMiddleware):
176
- async def __call__(self, scope, receive, send):
177
- print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
178
- await super().__call__(scope, receive, send)
179
- print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
180
-
181
- from starlette.middleware.cors import CORSMiddleware
182
 
183
  middlewares = [
184
  Middleware(
185
  SessionMiddleware,
186
  secret_key=get_key(fname=sess_key_path),
187
  max_age=3600,
188
- #same_site='lax',
189
  ),
190
- Middleware(CORSMiddleware, allow_origins=['*']),
191
  Middleware(XFrameOptionsMiddleware),
192
- Middleware(SessionLoggingMiddleware),
193
- #Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
194
  ]
195
  bware = Beforeware(
196
  user_auth_before,
@@ -314,6 +295,7 @@ def get(sess):
314
  queries = [
315
  "Breast Cancer Cells Feed on Cholesterol",
316
  "Treating Asthma With Plants vs. Pills",
 
317
  "Testing Turmeric on Smokers",
318
  "The Role of Pesticides in Parkinson's Disease",
319
  ]
@@ -442,10 +424,9 @@ def post(login: Login, sess):
442
  if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
443
  # Incorrect password - add error message
444
  return RedirectResponse("/login?error=True", status_code=303)
445
- print(f"Session after setting auth: {sess}")
446
- response = RedirectResponse("/admin", status_code=303)
447
- print(f"Cookies being set: {response.headers.get('Set-Cookie')}")
448
- return response
449
 
450
 
451
  @app.get("/logout")
@@ -471,26 +452,9 @@ def replace_hi_with_strong(text):
471
 
472
 
473
  def log_query_to_db(query, ranking, sess):
474
- queries.insert(
475
  Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
476
  )
477
- if 'user_id' not in sess:
478
- sess['user_id'] = str(uuid.uuid4())
479
-
480
- if 'queries' not in sess:
481
- sess['queries'] = []
482
-
483
- query_data = {
484
- 'query': query,
485
- 'ranking': ranking,
486
- 'timestamp': int(time.time())
487
- }
488
- sess['queries'].append(query_data)
489
-
490
- # Limit the number of queries stored in the session to prevent it from growing too large
491
- sess['queries'] = sess['queries'][-100:] # Keep only the last 100 queries
492
-
493
- return query_data
494
 
495
 
496
  def parse_results(records):
@@ -580,7 +544,12 @@ def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
580
  @app.get("/search")
581
  async def search(userquery: str, ranking: str, sess):
582
  print(sess)
 
 
583
  quoted = quote(userquery) + "&ranking=" + ranking
 
 
 
584
  log_query_to_db(userquery, ranking, sess)
585
  yql, body = get_yql(ranking, userquery)
586
  async with vespa_app.asyncio() as session:
@@ -837,13 +806,12 @@ def get_document(docid: str, sess):
837
  resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
838
  doc = resp.json
839
  # Link with Back to search results at top of page
840
- last_query = sess.get('queries', [{}])[-1].get('query', '')
841
  return Main(
842
  Div(
843
  A(
844
  I(cls="fa fa-arrow-left"),
845
  "Back to search results",
846
- hx_get=f"/search?userquery={last_query}",
847
  hx_target="#results",
848
  style="margin: 10px;",
849
  ),
 
57
  from enum import Enum
58
  from typing import Tuple as T
59
  from urllib.parse import quote
 
60
 
61
  DEV_MODE = False
62
 
 
164
  response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
165
  return response
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  middlewares = [
169
  Middleware(
170
  SessionMiddleware,
171
  secret_key=get_key(fname=sess_key_path),
172
  max_age=3600,
 
173
  ),
 
174
  Middleware(XFrameOptionsMiddleware),
 
 
175
  ]
176
  bware = Beforeware(
177
  user_auth_before,
 
295
  queries = [
296
  "Breast Cancer Cells Feed on Cholesterol",
297
  "Treating Asthma With Plants vs. Pills",
298
+ "Alkylphenol Endocrine Disruptors",
299
  "Testing Turmeric on Smokers",
300
  "The Role of Pesticides in Parkinson's Disease",
301
  ]
 
424
  if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
425
  # Incorrect password - add error message
426
  return RedirectResponse("/login?error=True", status_code=303)
427
+ sess["auth"] = True
428
+ print(f"Sess after login: {sess}")
429
+ return RedirectResponse("/admin", status_code=303)
 
430
 
431
 
432
  @app.get("/logout")
 
452
 
453
 
454
  def log_query_to_db(query, ranking, sess):
455
+ return queries.insert(
456
  Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
457
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
 
459
 
460
  def parse_results(records):
 
544
  @app.get("/search")
545
  async def search(userquery: str, ranking: str, sess):
546
  print(sess)
547
+ if "queries" not in sess:
548
+ sess["queries"] = []
549
  quoted = quote(userquery) + "&ranking=" + ranking
550
+ sess["queries"].append(quoted)
551
+ print(f"Searching for: {userquery}")
552
+ print(f"Ranking: {ranking}")
553
  log_query_to_db(userquery, ranking, sess)
554
  yql, body = get_yql(ranking, userquery)
555
  async with vespa_app.asyncio() as session:
 
806
  resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
807
  doc = resp.json
808
  # Link with Back to search results at top of page
 
809
  return Main(
810
  Div(
811
  A(
812
  I(cls="fa fa-arrow-left"),
813
  "Back to search results",
814
+ hx_get=f"/search?userquery={sess['queries'][-1]}",
815
  hx_target="#results",
816
  style="margin: 10px;",
817
  ),