# inference_server.py
import logging

from fastapi import FastAPI
from sklearn.linear_model import SGDClassifier
import uvicorn

# Configure logging so the INFO-level request/prediction logs below are actually emitted
logging.basicConfig(level=logging.INFO)

app = FastAPI()
def predict(input_text: str):
    # Treat each character as one sample whose single feature is its ASCII code
    data = [[ord(c)] for c in input_text]
    labels = [ord(c) for c in input_text]  # the character codes double as the labels
    model = train(data, labels)
    # Make a prediction for the fixed probe string 'abc'
    prediction = model.predict([[ord(c)] for c in "abc"])
    return {"prediction": prediction.tolist()}
def train(X, y):
    # Fit a fresh linear classifier on the request data; here the ASCII codes are used as the labels
    model = SGDClassifier()
    model.fit(X, y)
    return model
# Here you could also load pre-trained models instead of training on every request
@app.get("/")
def read_root(input_text: str):
    logging.info("Received request with input_text: %s", input_text)
    try:
        result = predict(input_text)
        logging.info("Prediction made: %s", result)
        return {"result": result}
    except Exception as e:
        logging.error("An error occurred: %s", e)
        return {"error": str(e)}