# app.py — author: Canstralian; commit bf27104 ("Create app.py", verified); 844 bytes.
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# The ASGI application object; the /predict route below is registered on it.
app = FastAPI()

# Hugging Face Hub repository that both the tokenizer and the model come from,
# kept in one place so the two can never point at different checkpoints.
MODEL_ID = "canstralian/CyberAttackDetection"

# Loaded once when this module is imported, then reused by the /predict handler.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# Define the input data model
class LogData(BaseModel):
log: str
# Upper bound on the number of tokens generated per request. Passing it as
# ``max_new_tokens`` sidesteps the generation default ``max_length`` (20 unless
# the model's generation config overrides it), which counts the prompt too and
# leaves little or no room for output once a log line exceeds a few words.
MAX_NEW_TOKENS = 64


@app.post("/predict")
def predict(data: LogData):
    """Run the model on one log entry and return the generated text.

    Declared as a plain ``def`` (not ``async def``) because tokenization and
    ``model.generate`` are blocking calls. FastAPI runs plain ``def`` endpoints
    in a worker threadpool, so a slow generation no longer freezes the event
    loop for every other request.

    Args:
        data: Request body carrying the raw ``log`` string.

    Returns:
        ``{"prediction": text}`` where ``text`` is only the newly generated
        continuation (the input log is not echoed), decoded with special
        tokens removed.

    Raises:
        HTTPException: 422 when the log tokenizes to zero tokens, since there
            is nothing for the model to condition on.
    """
    inputs = tokenizer(data.log, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    if prompt_len == 0:
        raise HTTPException(status_code=422, detail="log must not be empty")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

    # A causal LM's generate() returns prompt tokens followed by new tokens;
    # drop the prompt so the response contains only the model's output.
    new_tokens = outputs[0][prompt_len:]
    prediction = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return {"prediction": prediction}