|
import torch |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
# Locations of the fine-tuned artifacts. Both are relative paths, so they are
# resolved against the current working directory the app is launched from.
# State dict later read with torch.load and applied via load_state_dict.
MODEL_PATH: str = "trained_model2/distilroberta_model.pth"
# Directory later passed to AutoTokenizer.from_pretrained.
TOKENIZER_DIR: str = "trained_model2/distilroberta_tokenizer"
|
|
|
|
|
# Tokenizer saved alongside the fine-tuned model.
tokenizer_rl = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

# Build a DistilRoBERTa sequence-classification model with a two-way head,
# then replace its weights with the fine-tuned state dict from MODEL_PATH.
# Mapping storages to "cpu" lets a checkpoint saved on a GPU load on a
# CPU-only host.
# NOTE(review): torch.load unpickles the file; only point MODEL_PATH at
# checkpoints you trust.
model_rl = AutoModelForSequenceClassification.from_pretrained(
    "distilroberta-base",
    num_labels=2,
)
model_rl.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

# Switch to evaluation mode (disables dropout) before serving predictions.
model_rl.eval()
|
|
|
|
|
def classify_with_rl(text):
    """Return the classifier's spam probability for one piece of text.

    The text is tokenized with ``tokenizer_rl`` (truncated to at most 128
    tokens) and scored by ``model_rl`` without gradient tracking. The logits
    are converted to class probabilities with a softmax over the last
    dimension.

    Args:
        text: The message to classify. ``None`` (for example, a null value
            sent through the Gradio API) is treated as the empty string,
            because the tokenizer rejects ``None``.

    Returns:
        ``{"spam_probability": p}``, where ``p`` is the softmax probability
        of class index 1 as a Python ``float`` in [0, 1].
    """
    if text is None:
        text = ""

    inputs = tokenizer_rl(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )

    with torch.no_grad():
        outputs = model_rl(**inputs)

    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Class index 1 is reported as "spam"; this presumably matches the label
    # order used during training -- TODO confirm.
    # Softmax output is already in [0, 1], so no clamp is needed. The former
    # max(0, min(1, x)) clamp also returned the *int* 0/1 whenever the
    # probability saturated to exactly 0.0/1.0 in float32.
    return {"spam_probability": float(probs[0][1])}
|
|
|
|
|
# Web UI: a single free-text input box; the dict returned by the classifier
# is rendered as JSON.
iface = gr.Interface(
    fn=classify_with_rl,
    inputs=gr.Textbox(),
    outputs="json",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    iface.launch()
|
|