Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -3,28 +3,25 @@ import xgboost as xgb
|
|
3 |
import tensorflow as tf
|
4 |
from transformers import RobertaTokenizer, TFRobertaModel
|
5 |
|
6 |
-
#
|
7 |
-
|
|
|
|
|
8 |
|
9 |
-
# Load
|
10 |
-
pu_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
11 |
-
roberta_model = TFRobertaModel.from_pretrained('roberta-base')
|
12 |
-
roberta_model.trainable = False
|
13 |
-
|
14 |
-
# Load XGBoost classifier
|
15 |
pu_classifier = xgb.Booster()
|
16 |
-
pu_classifier.load_model(
|
17 |
|
18 |
-
#
|
19 |
def classify_with_pu(text):
|
20 |
inputs = pu_tokenizer(text, return_tensors="tf", truncation=True, max_length=128)
|
21 |
embeddings = roberta_model(inputs).last_hidden_state[:, 0, :]
|
22 |
dmatrix = xgb.DMatrix(embeddings.numpy())
|
23 |
pu_probs = pu_classifier.predict(dmatrix)
|
24 |
-
return {"
|
25 |
|
26 |
# Create API
|
27 |
-
iface = gr.Interface(fn=classify_with_pu, inputs=gr.Textbox(), outputs="json")
|
28 |
|
29 |
# Launch API
|
30 |
if __name__ == "__main__":
|
|
|
3 |
import tensorflow as tf
|
4 |
from transformers import RobertaTokenizer, TFRobertaModel
|
5 |
|
6 |
+
# Load PU model
|
7 |
+
pu_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
8 |
+
roberta_model = TFRobertaModel.from_pretrained("roberta-base")
|
9 |
+
roberta_model.trainable = False # Freeze RoBERTa model
|
10 |
|
11 |
+
# Load XGBoost classifier directly from the uploaded file
|
|
|
|
|
|
|
|
|
|
|
12 |
pu_classifier = xgb.Booster()
|
13 |
+
pu_classifier.load_model("xgboost_spam_filter.model") # Use the filename directly
|
14 |
|
15 |
+
# Classification function
|
16 |
def classify_with_pu(text):
|
17 |
inputs = pu_tokenizer(text, return_tensors="tf", truncation=True, max_length=128)
|
18 |
embeddings = roberta_model(inputs).last_hidden_state[:, 0, :]
|
19 |
dmatrix = xgb.DMatrix(embeddings.numpy())
|
20 |
pu_probs = pu_classifier.predict(dmatrix)
|
21 |
+
return {"prediction": "Spam" if pu_probs[0] > 0.5 else "Not Spam", "probability": float(pu_probs[0])}
|
22 |
|
23 |
# Create API
|
24 |
+
iface = gr.Interface(fn=classify_with_pu, inputs=gr.Textbox(label="Enter Message"), outputs="json")
|
25 |
|
26 |
# Launch API
|
27 |
if __name__ == "__main__":
|