AavV4 committed on
Commit
87bdd54
·
verified ·
1 Parent(s): 6287031

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -3,28 +3,25 @@ import xgboost as xgb
3
  import tensorflow as tf
4
  from transformers import RobertaTokenizer, TFRobertaModel
5
 
6
- # Define model paths
7
- PU_MODEL_PATH = "trained_model_pu/pu_model.json"
 
 
8
 
9
- # Load PU model from Hugging Face
10
- pu_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
11
- roberta_model = TFRobertaModel.from_pretrained('roberta-base')
12
- roberta_model.trainable = False
13
-
14
- # Load XGBoost classifier
15
  pu_classifier = xgb.Booster()
16
- pu_classifier.load_model(PU_MODEL_PATH)
17
 
18
- # PU model classification function
19
  def classify_with_pu(text):
20
  inputs = pu_tokenizer(text, return_tensors="tf", truncation=True, max_length=128)
21
  embeddings = roberta_model(inputs).last_hidden_state[:, 0, :]
22
  dmatrix = xgb.DMatrix(embeddings.numpy())
23
  pu_probs = pu_classifier.predict(dmatrix)
24
- return {"spam_probability": max(0, min(1, float(pu_probs[0]) if pu_probs.size > 0 else 0.5))}
25
 
26
  # Create API
27
- iface = gr.Interface(fn=classify_with_pu, inputs=gr.Textbox(), outputs="json")
28
 
29
  # Launch API
30
  if __name__ == "__main__":
 
3
  import tensorflow as tf
4
  from transformers import RobertaTokenizer, TFRobertaModel
5
 
6
+ # Load PU model
7
+ pu_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
8
+ roberta_model = TFRobertaModel.from_pretrained("roberta-base")
9
+ roberta_model.trainable = False # Freeze RoBERTa model
10
 
11
+ # Load XGBoost classifier directly from the uploaded file
 
 
 
 
 
12
  pu_classifier = xgb.Booster()
13
+ pu_classifier.load_model("xgboost_spam_filter.model") # Use the filename directly
14
 
15
+ # Classification function
16
  def classify_with_pu(text):
17
  inputs = pu_tokenizer(text, return_tensors="tf", truncation=True, max_length=128)
18
  embeddings = roberta_model(inputs).last_hidden_state[:, 0, :]
19
  dmatrix = xgb.DMatrix(embeddings.numpy())
20
  pu_probs = pu_classifier.predict(dmatrix)
21
+ return {"prediction": "Spam" if pu_probs[0] > 0.5 else "Not Spam", "probability": float(pu_probs[0])}
22
 
23
  # Create API
24
+ iface = gr.Interface(fn=classify_with_pu, inputs=gr.Textbox(label="Enter Message"), outputs="json")
25
 
26
  # Launch API
27
  if __name__ == "__main__":