Thamaraikannan commited on
Commit
efa6a7b
·
1 Parent(s): 0b975e1

latest commit

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -1,20 +1,34 @@
1
- from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
 
3
  from setfit import SetFitModel
4
 
5
- app = FastAPI()
6
-
7
-
8
  class TextInput(BaseModel):
9
- text: str
 
 
 
10
 
 
 
11
 
12
  @app.post("/predict")
13
  async def predict(input: TextInput):
 
14
  text_input = input.text
15
- model = SetFitModel.from_pretrained("assets")
16
- preds = model.predict(text_input)
17
- return {"label": preds}
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  @app.get("/")
 
1
+ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from typing import List, Dict, Any
4
  from setfit import SetFitModel
5
 
6
+ # Define the input model
 
 
7
  class TextInput(BaseModel):
8
+ text: List[str]
9
+
10
+ # Initialize the FastAPI app
11
+ app = FastAPI()
12
 
13
+ # Load the model once, when the app starts
14
+ model = SetFitModel.from_pretrained("assets")
15
 
16
  @app.post("/predict")
17
  async def predict(input: TextInput):
18
+ # Get the text input from the request
19
  text_input = input.text
20
+
21
+ # Initialize a list to store the predictions
22
+ response: List[Dict[str, Any]] = []
23
+
24
+ # Predict using the loaded model for each message
25
+ for message in text_input:
26
+ pred = model.predict([message])[0] # Predict expects a list, so wrap the message in a list and take the first result
27
+ response.append({"Message": message, "label": pred})
28
+
29
+ return response
30
+
31
+
32
 
33
 
34
  @app.get("/")