jrno commited on
Commit
9c3a55c
·
1 Parent(s): 49d8290

num_recommendations and placeholder for track history

Browse files
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. server.py +17 -3
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.csv filter=lfs diff=lfs merge=lfs -text
server.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastai.collab import load_learner
2
- from fastapi import FastAPI
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from custom_models import DotProductBias
5
  import asyncio
@@ -34,9 +34,22 @@ async def startup_event():
34
  loop = asyncio.get_event_loop() # get event loop
35
  tasks = [asyncio.ensure_future(setup_learner())] # assign some task
36
  learn = (await asyncio.gather(*tasks))[0]
37
-
 
 
 
 
 
 
 
 
 
 
 
38
  @app.get("/recommend/{user_id}")
39
- async def analyze(user_id: str):
 
 
40
  not_listened_songs = ["Revelry, Kings of Leon, 2008", "Gears, Miss May I, 2010", "Sexy Bitch, David Guetta, 2009"]
41
  input_dataframe = pd.DataFrame({'user_id': ["440abe26940ae9d9268157222a4a3d5735d44ed8"] * len(not_listened_songs), 'entry': not_listened_songs})
42
  test_dl = learn.dls.test_dl(input_dataframe)
@@ -47,3 +60,4 @@ async def analyze(user_id: str):
47
 
48
  if __name__ == "__main__":
49
  uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
 
 
1
  from fastai.collab import load_learner
2
+ from fastapi import FastAPI, Query
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from custom_models import DotProductBias
5
  import asyncio
 
34
  loop = asyncio.get_event_loop() # get event loop
35
  tasks = [asyncio.ensure_future(setup_learner())] # assign some task
36
  learn = (await asyncio.gather(*tasks))[0]
37
+
38
+ @app.get('/user/{user_id}/history')
39
+ async def get_user_track_history(user_id: str):
40
+ return {
41
+ "user_id": user_id,
42
+ "history": [
43
+ {"track_id": "1", "genre": "Rock", "year": "2008", "artist": "Kings of Leon", "name": "Revelry"},
44
+ {"track_id": "2", "genre": "Metalcore", "year": "2010", "artist": "Miss May I", "name": "Gears"},
45
+ {"track_id": "3", "genre": "Electro", "year": "2009", "artist": "David Guetta", "name": "Sexy Bitch"}
46
+ ]
47
+ }
48
+
49
  @app.get("/recommend/{user_id}")
50
+ async def get_recommendations_for_user(user_id: str, num_recommendations: int = Query(5)):
51
+ print(num_recommendations)
52
+ print(user_id)
53
  not_listened_songs = ["Revelry, Kings of Leon, 2008", "Gears, Miss May I, 2010", "Sexy Bitch, David Guetta, 2009"]
54
  input_dataframe = pd.DataFrame({'user_id': ["440abe26940ae9d9268157222a4a3d5735d44ed8"] * len(not_listened_songs), 'entry': not_listened_songs})
55
  test_dl = learn.dls.test_dl(input_dataframe)
 
60
 
61
  if __name__ == "__main__":
62
  uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
63
+