sms_filter / app.py
Thamaraikannan's picture
latest commit
efa6a7b
raw
history blame
926 Bytes
import asyncio
from typing import List, Dict, Any

from fastapi import FastAPI
from pydantic import BaseModel
from setfit import SetFitModel
class TextInput(BaseModel):
    """Request body accepted by ``POST /predict``: the messages to classify."""

    # One entry per message; the endpoint returns one prediction per entry.
    text: List[str]
# Initialize the FastAPI app
app = FastAPI()
# Load the model once, when the app starts
model = SetFitModel.from_pretrained("assets")
@app.post("/predict")
async def predict(input: TextInput):
    """Classify every message in the request body.

    Args:
        input: Request body; its ``text`` field holds the messages to classify.

    Returns:
        A list with one ``{"Message": <message>, "label": <prediction>}`` dict
        per input message, in the same order as ``input.text``. An empty
        ``text`` list yields an empty response without calling the model.
    """
    messages = input.text
    response: List[Dict[str, Any]] = []

    # Nothing to classify: skip the model call entirely.
    if not messages:
        return response

    # Predict the whole batch in a single call rather than one model call per
    # message. The call is blocking, so run it in the default executor;
    # invoking it directly inside an ``async def`` endpoint would stall the
    # event loop for all other in-flight requests.
    loop = asyncio.get_running_loop()
    preds = await loop.run_in_executor(None, model.predict, messages)

    for message, pred in zip(messages, preds):
        # NOTE(review): depending on the SetFit version and whether the model
        # has label names, predictions may be str or numpy/torch scalars or
        # arrays. The latter are not JSON-serialisable by FastAPI, so convert
        # anything exposing ``tolist()`` into plain Python values.
        label = pred.tolist() if hasattr(pred, "tolist") else pred
        response.append({"Message": message, "label": label})
    return response
@app.get("/")
def read_root():
    """Root route: respond with a static welcome payload."""
    greeting = "Welcome"
    return {"message": greeting}