Canstralian committed on
Commit
9819ce2
·
verified ·
1 Parent(s): 875196b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -1,29 +1,25 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import torch
5
 
6
# Initialize FastAPI app
app = FastAPI()

# Load the tokenizer and model once at module import so the weights are
# resident before the first request (from_pretrained fetches and caches
# the checkpoint from the Hugging Face Hub).
tokenizer = AutoTokenizer.from_pretrained("canstralian/CyberAttackDetection")
model = AutoModelForCausalLM.from_pretrained("canstralian/CyberAttackDetection")
12
 
13
# Define the input data model
class LogData(BaseModel):
    """Request body for /predict: the raw log text to analyze."""
    # Raw log line/text submitted by the client
    log: str
16
 
17
  @app.post("/predict")
18
- async def predict(data: LogData):
19
- # Tokenize the input log data
20
- inputs = tokenizer(data.log, return_tensors="pt")
21
-
22
- # Generate predictions
23
- with torch.no_grad():
24
- outputs = model.generate(**inputs)
25
-
26
- # Decode the generated tokens to text
27
- prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
-
29
- return {"prediction": prediction}
 
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from src.model_inference import predict
from src.utils import setup_logging, log_info, log_error
5
 
6
# Initialize FastAPI app
app = FastAPI()

# Set up logging at import time so all request handling is logged
# (handlers/format are configured in src.utils.setup_logging)
setup_logging()
 
11
 
12
# Define the input data model
class LogData(BaseModel):
    """Request body for /predict: the raw log text to analyze."""
    # Raw log line/text submitted by the client
    log: str
15
 
16
  @app.post("/predict")
17
+ async def predict_route(data: LogData):
18
+ try:
19
+ # Perform prediction
20
+ prediction = predict(data.log)
21
+ log_info(f'Prediction: {prediction}')
22
+ return {"prediction": prediction}
23
+ except Exception as e:
24
+ log_error(f'An error occurred: {e}')
25
+ return {"error": str(e)}