Spaces: Runtime error

Avril Lalaine committed
Commit 0ad9aa8 · 1 Parent(s): 28bff37

Add Flask app with Dockerfile

Files changed:
- Dockerfile +11 -0
- app.py +133 -0
- model.py +87 -0
- requirements.txt +12 -0
- templates/index.html +155 -0
Dockerfile
ADDED
@@ -0,0 +1,11 @@
FROM python:3.9-slim

WORKDIR /app

COPY . .

RUN pip install --no-cache-dir -r requirements.txt

EXPOSE 8080

CMD ["python", "app.py"]
app.py
ADDED
@@ -0,0 +1,133 @@
from pathlib import Path
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

app = Flask(__name__)

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512
BERT_TOKENIZER = 'bert-base-uncased'
ROBERTA_TOKENIZER = 'jcblaise/roberta-tagalog-base'
ELECTRA_TOKENIZER = 'google/electra-base-discriminator'

LABELS = ["fake", "real"]


class Classifier:
    def __init__(self, model_path, device, tokenizer_name):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            local_files_only=True,
            device_map=device
        )
        self.model.eval()

    def predict(self, text):
        """Make a prediction for a single text."""
        # Tokenize
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        # Get prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence_scores = probabilities[0].tolist()

        # Format results
        result = {
            'predicted_class': LABELS[predicted_class],
            'confidence_scores': {
                label: score
                for label, score in zip(LABELS, confidence_scores)
            }
        }
        return result


@app.route('/')
def home():
    return render_template('index.html')


@app.route('/detect', methods=['POST'])
def detect():
    try:
        data = request.get_json()
        news_text = data.get('text')
        model_chosen = data.get('model')
        print(model_chosen)

        if not news_text:
            return jsonify({
                'status': 'error',
                'message': 'No text provided'
            }), 400

        # Map the UI's model choice to the directory holding the fine-tuned weights
        switch = {
            'nonaug-bert': 'bert-nonaug',
            'aug-bert': 'bert-aug',
            'nonaug-tagbert': 'tagbert-nonaug',
            'aug-tagbert': 'tagbert-aug',
            'nonaug-electra': 'electra-nonaug',
            'aug-electra': 'electra-aug'
        }

        model_p = switch.get(model_chosen)
        print("model", model_p)

        # NOTE: absolute local path; the chosen model directory must exist here
        MODEL_PATH = Path("D:\\Aplil\\skibidi-thesis\\webapp", model_p)
        print(MODEL_PATH)

        # Pick the matching tokenizer from the architecture part of the choice,
        # e.g. 'aug-tagbert' -> 'tagbert'
        tokenizer = model_chosen.split("-")[1]
        tokenizer_chosen = {
            'bert': BERT_TOKENIZER,
            'tagbert': ROBERTA_TOKENIZER,
            'electra': ELECTRA_TOKENIZER
        }
        print(tokenizer)

        classifier = Classifier(MODEL_PATH, DEVICE, tokenizer_chosen.get(tokenizer))
        result = classifier.predict(news_text)
        print(result['confidence_scores'])

        if result['predicted_class'] == "fake":
            out = "News Needs Further Validation"
        else:
            out = "News is Real"

        return jsonify({
            'status': 'success',
            'prediction': out,
            'confidence': result['confidence_scores']
        })

    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 400


if __name__ == '__main__':
    # Bind to all interfaces on the port exposed in the Dockerfile
    app.run(host="0.0.0.0", port=8080, debug=True)
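For reference, the /detect route accepts a JSON body with "text" and "model" keys and responds with "status", "prediction", and "confidence". Below is a minimal client sketch using only the standard library; it assumes the app is reachable at http://localhost:8080 (the port exposed in the Dockerfile) and that the selected model directory actually exists on disk. The sample text and the "aug-tagbert" choice are illustrative values, not part of the commit.

import json
import urllib.request

# Illustrative payload: "model" may be any value accepted by the <select> in templates/index.html
payload = {
    "text": "Halimbawang balita na susuriin ng classifier.",
    "model": "aug-tagbert",
}

req = urllib.request.Request(
    "http://localhost:8080/detect",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

# Expected shape: {"status": "success", "prediction": "...", "confidence": {"fake": ..., "real": ...}}
print(result["prediction"], result["confidence"])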
model.py
ADDED
@@ -0,0 +1,87 @@
import pandas as pd
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# non-augmented training dataset
# df = pd.read_csv(r".\train_set.csv")

# augmented training dataset
df = pd.read_csv(r".\cleaned_combined_aug_set.csv")
# df.info()
value_counts = df['label'].value_counts()
print(value_counts)


test_df = pd.read_csv(r".\test_set.csv")
# test_df.info()
test_df['label'].value_counts()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

model = model.to('cuda')

# independent variable
X = list(df['article'])
X_test = list(test_df['article'])

# dependent variable
y = list(df['label'])
y_test = list(test_df['label'])

max_length = 512
train_encodings = tokenizer(X, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')
test_encodings = tokenizer(X_test, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')

class CustomDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

torch_train_dataset = CustomDataset(train_encodings, y)
torch_test_dataset = CustomDataset(test_encodings, y_test)

training_args = TrainingArguments(
    output_dir='./results/fake-news-bert-aug',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3
)

def compute_metrics(p):
    print(type(p))
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred)
    precision = precision_score(y_true=labels, y_pred=pred)
    f1 = f1_score(y_true=labels, y_pred=pred)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=torch_train_dataset,
    eval_dataset=torch_test_dataset,
    compute_metrics=compute_metrics
)

trainer.train()

def predict(text):
    # Trainer.predict expects a Dataset rather than raw text, so tokenize and wrap the input first
    encodings = tokenizer([text], truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')
    return trainer.predict(CustomDataset(encodings, [0]))  # placeholder label; only the predictions are meaningful
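model.py fine-tunes the checkpoint but never writes it to disk, while app.py loads each classifier from a local directory with local_files_only=True. A minimal sketch of that hand-off, continuing from the trainer and tokenizer objects above; the target directory name "bert-aug" is an assumption chosen to mirror the 'aug-bert' entry in app.py's switch dict, so adjust the path to your setup.

# Persist the fine-tuned weights and the tokenizer so that
# AutoModelForSequenceClassification.from_pretrained(save_dir, local_files_only=True)
# in app.py can load them without contacting the Hub.
save_dir = "./webapp/bert-aug"  # hypothetical location matching app.py's expectations
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)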
requirements.txt
ADDED
@@ -0,0 +1,12 @@
blinker==1.8.2
click==8.1.7
colorama==0.4.6
Flask==3.0.3
importlib_metadata==8.5.0
itsdangerous==2.2.0
Jinja2==3.1.4
MarkupSafe==2.1.5
Werkzeug==3.0.4
zipp==3.20.2
transformers==4.33.3
torch==2.4.1+cu118
templates/index.html
ADDED
@@ -0,0 +1,155 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Fake News Detection using AugTagalog-BERT</title>
  <link rel="icon" type="image/png" href="{{ url_for('static', filename='bert.png') }}" />
  <script src="https://cdn.tailwindcss.com"></script>
  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css" />
</head>
<body class="min-h-screen flex flex-col justify-between font-sans bg-gray-100">
  <header class="bg-gray-800 w-full py-2 flex items-center justify-start">
    <div class="flex items-center space-x-4 ml-4">
      <img src="{{ url_for('static', filename='bert.png') }}" alt="BERT Logo" class="w-8 h-8" />
      <h1 class="text-white text-md font-bold">Fake News Detection using AugTagalog-BERT</h1>
    </div>
  </header>

  <div class="flex-grow flex items-center justify-center px-10 py-12">
    <div class="grid grid-cols-1 md:grid-cols-2 w-full gap-12 max-w-6xl">
      <div class="flex flex-col p-12 space-y-8 bg-white rounded-lg shadow-lg">
        <h2 class="text-2xl font-semibold text-gray-800">Tagalog Fake News Classifier</h2>

        <div>
          <label for="models" class="block text-lg font-medium text-gray-600 mb-2">Choose a model:</label>
          <div class="relative">
            <select id="models" name="models" class="w-full bg-white text-gray-900 text-lg rounded-md border border-gray-300 focus:border-gray-500 focus:ring focus:ring-gray-200 py-3 pl-4 pr-10 appearance-none transition duration-200">
              <option value="nonaug-bert">Non-Augmented BERT Model</option>
              <option value="aug-bert">Augmented BERT Model</option>
              <option value="nonaug-tagbert">Non-Augmented Tagalog-RoBERTa Model</option>
              <option value="aug-tagbert" selected>Augmented Tagalog-RoBERTa Model</option>
              <option value="nonaug-electra">Non-Augmented ELECTRA</option>
              <option value="aug-electra">Augmented ELECTRA</option>
            </select>
            <div class="absolute inset-y-0 right-0 flex items-center pr-4 pointer-events-none">
              <i class="fas fa-chevron-down text-gray-500"></i>
            </div>
          </div>
        </div>

        <div class="relative w-full">
          <label for="newsInput" class="block text-lg font-medium text-gray-600 mb-2">Input News:</label>
          <textarea
            id="newsInput"
            class="h-40 w-full border-2 border-gray-300 rounded-lg pl-4 pr-4 py-3 focus:outline-none focus:ring-2 focus:ring-gray-500 focus:border-gray-500 transition duration-200"
            placeholder="Paste your text here..."
            rows="6"
          ></textarea>
        </div>

        <div class="flex justify-center">
          <button
            id="detectBtn"
            class="bg-gray-800 text-white font-semibold py-3 px-8 rounded-lg hover:bg-gray-600 transition duration-300"
          >
            Detect
          </button>
        </div>
      </div>

      <div class="flex flex-col justify-center p-12 bg-white rounded-lg shadow-lg">
        <div id="resultContainer" class="opacity-0 transition-opacity duration-500 h-full flex flex-col justify-center">
          <div class="p-8 bg-gradient-to-b from-blue-50 to-white rounded-lg shadow-md">
            <h2 class="text-3xl font-semibold mb-6 text-center text-gray-700">Result</h2>
            <p id="result" class="text-center text-lg font-semibold p-4 rounded-lg border text-gray-800"></p>

            <div class="mt-8">
              <h3 class="text-lg font-bold text-center text-gray-700 mb-4">Confidence Levels</h3>
              <div class="grid grid-cols-2 gap-4">
                <div class="p-4 bg-red-100 rounded-lg shadow-sm text-center">
                  <h4 class="font-semibold text-red-600">Fake</h4>
                  <p id="fake" class="text-lg font-bold text-red-700">0%</p>
                </div>
                <div class="p-4 bg-green-100 rounded-lg shadow-sm text-center">
                  <h4 class="font-semibold text-green-600">Real</h4>
                  <p id="real" class="text-lg font-bold text-green-700">0%</p>
                </div>
              </div>
            </div>
          </div>
        </div>

        <div id="loadingSpinner" class="hidden flex justify-center items-center h-full">
          <div class="flex flex-col items-center">
            <div class="animate-spin rounded-full h-12 w-12 border-b-4 border-gray-600"></div>
            <p class="mt-4 text-gray-600 font-semibold">Detecting...</p>
          </div>
        </div>
      </div>
    </div>
  </div>

  <footer class="text-center py-4 bg-gray-800 w-full shadow-inner">
    <p class="text-white text-sm">
      © 2024 | <span class="font-semibold">J. Embolode, A. Kuan, A. Linaza</span>
    </p>
  </footer>

  <script>
    document.getElementById("detectBtn").addEventListener("click", function () {
      const newsInput = document.getElementById("newsInput").value;
      const model = document.getElementById("models").value;
      const loadingSpinner = document.getElementById("loadingSpinner");
      const resultContainer = document.getElementById("resultContainer");
      const resultText = document.getElementById("result");
      const confidenceFake = document.getElementById("fake");
      const confidenceReal = document.getElementById("real");

      if (newsInput.trim() === "") {
        alert("Please enter text.");
        return;
      }

      loadingSpinner.classList.remove("hidden");
      resultContainer.style.opacity = 0;

      fetch("/detect", {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({ text: newsInput, model: model }),
      })
        .then((response) => response.json())
        .then((data) => {
          loadingSpinner.classList.add("hidden");
          resultContainer.style.opacity = 1;

          if (data.status === "error") {
            resultText.textContent = data.message;
            resultText.classList.add("text-red-500");
            resultText.classList.remove("text-green-500");
          } else {
            resultText.innerHTML = data.prediction;

            if (data.prediction === "News Needs Further Validation") {
              resultText.classList.add("text-red-500");
              resultText.classList.remove("text-green-500");
            } else {
              resultText.classList.add("text-green-500");
              resultText.classList.remove("text-red-500");
            }

            confidenceFake.textContent = (data.confidence.fake * 100).toFixed(2) + "%";
            confidenceReal.textContent = (data.confidence.real * 100).toFixed(2) + "%";
          }
        })
        .catch((error) => {
          loadingSpinner.classList.add("hidden");
          console.error("Error:", error);
        });
    });
  </script>
</body>
</html>