Avril Lalaine commited on
Commit
0ad9aa8
·
1 Parent(s): 28bff37

Add flask app with dockerfire

Browse files
Files changed (5) hide show
  1. Dockerfile +11 -0
  2. app.py +133 -0
  3. model.py +87 -0
  4. requirements.txt +12 -0
  5. templates/index.html +155 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . .
6
+
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ EXPOSE 8080
10
+
11
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from flask import Flask, render_template, request, jsonify
3
+ from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+
6
+ app = Flask(__name__)
7
+
8
+ # Configuration # Directory containing model files
9
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
+ MAX_LENGTH = 512
11
+ BERT_TOKENIZER = 'bert-base-uncased'
12
+ ROBERTA_TOKENIZER = 'jcblaise/roberta-tagalog-base'
13
+ ELECTRA_TOKENIZER = 'google/electra-base-discriminator'
14
+
15
+
16
+ LABELS = ["fake", "real"]
17
+
18
+ class Classifier:
19
+ def __init__(self, model_path, device, tokenizer_name):
20
+ self.device = device
21
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
22
+ self.model = AutoModelForSequenceClassification.from_pretrained(
23
+ model_path,
24
+ local_files_only=True,
25
+ device_map=device
26
+ )
27
+ self.model.eval()
28
+
29
+ def predict(self, text):
30
+ """Make prediction for a single text"""
31
+ # Tokenize
32
+ inputs = self.tokenizer(
33
+ text,
34
+ truncation=True,
35
+ max_length=MAX_LENGTH,
36
+ padding=True,
37
+ return_tensors="pt"
38
+ ).to(self.device)
39
+
40
+ # Get prediction
41
+ with torch.no_grad():
42
+ outputs = self.model(**inputs)
43
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
44
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
45
+ confidence_scores = probabilities[0].tolist()
46
+
47
+ # Format results
48
+ result = {
49
+ 'predicted_class': LABELS[predicted_class],
50
+ 'confidence_scores': {
51
+ label: score
52
+ for label, score in zip(LABELS, confidence_scores)
53
+ }
54
+ }
55
+ return result
56
+
57
+
58
+
59
+ @app.route('/')
60
+ def home():
61
+ return render_template('index.html')
62
+
63
+ @app.route('/detect', methods=['POST'])
64
+ def detect():
65
+
66
+ try:
67
+ data = request.get_json()
68
+ news_text = data.get('text')
69
+
70
+ model_chosen = data.get('model')
71
+
72
+ print(model_chosen)
73
+
74
+ if not news_text:
75
+ return jsonify({
76
+ 'status': 'error',
77
+ 'message': 'No text provided'
78
+ }), 400
79
+
80
+ switch={
81
+ 'nonaug-bert':'bert-nonaug',
82
+ 'aug-bert':'bert-aug',
83
+ 'nonaug-tagbert':'tagbert-nonaug',
84
+ 'aug-tagbert':'tagbert-aug',
85
+ 'nonaug-electra':'electra-nonaug',
86
+ 'aug-electra':'electra-aug'
87
+ }
88
+
89
+ model_p = switch.get(model_chosen)
90
+
91
+ print("model",model_p)
92
+
93
+ MODEL_PATH = Path("D:\\Aplil\\skibidi-thesis\\webapp", model_p)
94
+
95
+
96
+ print(MODEL_PATH)
97
+
98
+ tokenizer = model_chosen.split("-")[1]
99
+
100
+ tokenizer_chosen = {
101
+ 'bert':BERT_TOKENIZER,
102
+ 'tagbert':ROBERTA_TOKENIZER,
103
+ 'electra':ELECTRA_TOKENIZER
104
+ }
105
+
106
+ print(tokenizer)
107
+
108
+ classifier = Classifier(MODEL_PATH,DEVICE,tokenizer_chosen.get(tokenizer))
109
+
110
+ result = classifier.predict(news_text)
111
+ print(result['confidence_scores'])
112
+
113
+
114
+ if result['predicted_class'] == "fake":
115
+ out = "News Needs Further Validation"
116
+ else:
117
+ out = "News is Real"
118
+
119
+
120
+ return jsonify({
121
+ 'status': 'success',
122
+ 'prediction': out,
123
+ 'confidence':result['confidence_scores']
124
+ })
125
+
126
+ except Exception as e:
127
+ return jsonify({
128
+ 'status': 'error',
129
+ 'message': str(e)
130
+ }), 400
131
+
132
+ if __name__ == '__main__':
133
+ app.run(debug=True)
model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import numpy as np
5
+ from sklearn.metrics import accuracy_score,recall_score,precision_score,f1_score
6
+ from transformers import BertTokenizer, BertForSequenceClassification, Trainer,TrainingArguments
7
+
8
+ # no augment dataset
9
+ # df = df = pd.read_csv(r".\train_set.csv")
10
+
11
+ # with augment training dataset
12
+ df = pd.read_csv(r".\cleaned_combined_aug_set.csv")
13
+ # df.info()
14
+ value_counts = df['label'].value_counts()
15
+ print(value_counts)
16
+
17
+
18
+ test_df = pd.read_csv(r".\test_set.csv")
19
+ # test_df.info()
20
+ test_df['label'].value_counts()
21
+
22
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
23
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2)
24
+
25
+ model = model.to('cuda')
26
+
27
+ # independent var
28
+ X = list(df['article'])
29
+ X_test = list(test_df['article'])
30
+
31
+ #dependent
32
+ y= list(df['label'])
33
+ y_test = list(test_df['label'])
34
+
35
+ max_length = 512
36
+ train_encodings = tokenizer(X, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')
37
+ test_encodings = tokenizer(X_test, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')
38
+
39
+ class CustomDataset(Dataset):
40
+ def __init__(self, encodings, labels):
41
+ self.encodings = encodings
42
+ self.labels = labels
43
+
44
+ def __getitem__(self, idx):
45
+ item = {key: val[idx] for key, val in self.encodings.items()}
46
+ item['labels'] = torch.tensor(self.labels[idx])
47
+ return item
48
+
49
+ def __len__(self):
50
+ return len(self.labels)
51
+
52
+ torch_train_dataset = CustomDataset(train_encodings,y)
53
+ torch_test_dataset = CustomDataset(test_encodings,y_test)
54
+
55
+ training_args = TrainingArguments(
56
+ output_dir='./results/fake-news-bert-aug',
57
+ evaluation_strategy='epoch',
58
+ learning_rate=2e-5,
59
+ per_device_train_batch_size=16,
60
+ per_device_eval_batch_size=16,
61
+ num_train_epochs=3
62
+ )
63
+
64
+ def compute_metrics(p):
65
+ print(type(p))
66
+ pred, labels = p
67
+ pred = np.argmax(pred,axis=1)
68
+
69
+ accuracy = accuracy_score(y_true=labels,y_pred=pred)
70
+ recall = recall_score(y_true=labels,y_pred=pred)
71
+ precision = precision_score(y_true=labels,y_pred=pred)
72
+ f1 = f1_score(y_true=labels,y_pred=pred)
73
+
74
+ return {"accuracy":accuracy,"precision":precision,"recall":recall,"f1":f1}
75
+
76
+ trainer = Trainer(
77
+ model=model,
78
+ args=training_args,
79
+ train_dataset=torch_train_dataset,
80
+ eval_dataset=torch_test_dataset,
81
+ compute_metrics=compute_metrics
82
+ )
83
+
84
+ trainer.train()
85
+
86
+ def predict(text):
87
+ return trainer.predict(text)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ blinker==1.8.2
2
+ click==8.1.7
3
+ colorama==0.4.6
4
+ Flask==3.0.3
5
+ importlib_metadata==8.5.0
6
+ itsdangerous==2.2.0
7
+ Jinja2==3.1.4
8
+ MarkupSafe==2.1.5
9
+ Werkzeug==3.0.4
10
+ zipp==3.20.2
11
+ transformers==4.33.3
12
+ torch==2.4.1+cu118
templates/index.html ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>Fake News Detection using AugTagalog-BERT</title>
7
+ <link rel="icon" type="image/png" href="{{ url_for('static', filename='bert.png') }}" />
8
+ <script src="https://cdn.tailwindcss.com"></script>
9
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css" />
10
+ </head>
11
+ <body class="min-h-screen flex flex-col justify-between font-sans bg-gray-100">
12
+ <header class="bg-gray-800 w-full py-2 flex items-center justify-start">
13
+ <div class="flex items-center space-x-4 ml-4">
14
+ <img src="{{ url_for('static', filename='bert.png') }}" alt="BERT Logo" class="w-8 h-8" />
15
+ <h1 class="text-white text-md font-bold">Fake News Detection using AugTagalog-BERT</h1>
16
+ </div>
17
+ </header>
18
+
19
+ <div class="flex-grow flex items-center justify-center px-10 py-12">
20
+ <div class="grid grid-cols-1 md:grid-cols-2 w-full gap-12 max-w-6xl">
21
+ <div class="flex flex-col p-12 space-y-8 bg-white rounded-lg shadow-lg">
22
+ <h2 class="text-2xl font-semibold text-gray-800">Tagalog Fake News Classifier</h2>
23
+
24
+ <div>
25
+ <label for="models" class="block text-lg font-medium text-gray-600 mb-2">Choose a model:</label>
26
+ <div class="relative">
27
+ <select id="models" name="models" class="w-full bg-white text-gray-900 text-lg rounded-md border border-gray-300 focus:border-gray-500 focus:ring focus:ring-gray-200 py-3 pl-4 pr-10 appearance-none transition duration-200">
28
+ <option value="nonaug-bert">Non-Augmented BERT Model</option>
29
+ <option value="aug-bert">Augmented BERT Model</option>
30
+ <option value="nonaug-tagbert">Non-Augmented Tagalog-RoBERTa Model</option>
31
+ <option value="aug-tagbert" selected>Augmented Tagalog-RoBERTa Model</option>
32
+ <option value="nonaug-electra">Non-Augmented ELECTRA</option>
33
+ <option value="aug-electra">Augmented ELECTRA</option>
34
+ </select>
35
+ <div class="absolute inset-y-0 right-0 flex items-center pr-4 pointer-events-none">
36
+ <i class="fas fa-chevron-down text-gray-500"></i>
37
+ </div>
38
+ </div>
39
+ </div>
40
+
41
+ <div class="relative w-full">
42
+ <label for="newsInput" class="block text-lg font-medium text-gray-600 mb-2">Input News:</label>
43
+ <textarea
44
+ id="newsInput"
45
+ class="h-40 w-full border-2 border-gray-300 rounded-lg pl-4 pr-4 py-3 focus:outline-none focus:ring-2 focus:ring-gray-500 focus:border-gray-500 transition duration-200"
46
+ placeholder="Paste your text here..."
47
+ rows="6"
48
+ ></textarea>
49
+ </div>
50
+
51
+ <div class="flex justify-center">
52
+ <button
53
+ id="detectBtn"
54
+ class="bg-gray-800 text-white font-semibold py-3 px-8 rounded-lg hover:bg-gray-600 transition duration-300"
55
+ >
56
+ Detect
57
+ </button>
58
+ </div>
59
+ </div>
60
+
61
+ <div class="flex flex-col justify-center p-12 bg-white rounded-lg shadow-lg">
62
+ <div id="resultContainer" class="opacity-0 transition-opacity duration-500 h-full flex flex-col justify-center">
63
+ <div class="p-8 bg-gradient-to-b from-blue-50 to-white rounded-lg shadow-md">
64
+ <h2 class="text-3xl font-semibold mb-6 text-center text-gray-700">Result</h2>
65
+ <p id="result" class="text-center text-lg font-semibold p-4 rounded-lg border text-gray-800"></p>
66
+
67
+ <div class="mt-8">
68
+ <h3 class="text-lg font-bold text-center text-gray-700 mb-4">Confidence Levels</h3>
69
+ <div class="grid grid-cols-2 gap-4">
70
+ <div class="p-4 bg-red-100 rounded-lg shadow-sm text-center">
71
+ <h4 class="font-semibold text-red-600">Fake</h4>
72
+ <p id="fake" class="text-lg font-bold text-red-700">0%</p>
73
+ </div>
74
+ <div class="p-4 bg-green-100 rounded-lg shadow-sm text-center">
75
+ <h4 class="font-semibold text-green-600">Real</h4>
76
+ <p id="real" class="text-lg font-bold text-green-700">0%</p>
77
+ </div>
78
+ </div>
79
+ </div>
80
+ </div>
81
+ </div>
82
+
83
+ <div id="loadingSpinner" class="hidden flex justify-center items-center h-full">
84
+ <div class="flex flex-col items-center">
85
+ <div class="animate-spin rounded-full h-12 w-12 border-b-4 border-gray-600"></div>
86
+ <p class="mt-4 text-gray-600 font-semibold">Detecting...</p>
87
+ </div>
88
+ </div>
89
+ </div>
90
+ </div>
91
+ </div>
92
+
93
+ <footer class="text-center py-4 bg-gray-800 w-full shadow-inner">
94
+ <p class="text-white text-sm">
95
+ © 2024 | <span class="font-semibold">J. Embolode, A. Kuan, A. Linaza</span>
96
+ </p>
97
+ </footer>
98
+
99
+ <script>
100
+ document.getElementById("detectBtn").addEventListener("click", function () {
101
+ const newsInput = document.getElementById("newsInput").value;
102
+ const model = document.getElementById("models").value;
103
+ const loadingSpinner = document.getElementById("loadingSpinner");
104
+ const resultContainer = document.getElementById("resultContainer");
105
+ const resultText = document.getElementById("result");
106
+ const confidenceFake = document.getElementById("fake");
107
+ const confidenceReal = document.getElementById("real");
108
+
109
+ if (newsInput.trim() === "") {
110
+ alert("Please enter text.");
111
+ return;
112
+ }
113
+
114
+ loadingSpinner.classList.remove("hidden");
115
+ resultContainer.style.opacity = 0;
116
+
117
+ fetch("/detect", {
118
+ method: "POST",
119
+ headers: {
120
+ "Content-Type": "application/json",
121
+ },
122
+ body: JSON.stringify({ text: newsInput, model: model }),
123
+ })
124
+ .then((response) => response.json())
125
+ .then((data) => {
126
+ loadingSpinner.classList.add("hidden");
127
+ resultContainer.style.opacity = 1;
128
+
129
+ if (data.status === "error") {
130
+ resultText.textContent = data.message;
131
+ resultText.classList.add("text-red-500");
132
+ resultText.classList.remove("text-green-500");
133
+ } else {
134
+ resultText.innerHTML = data.prediction;
135
+
136
+ if (data.prediction === "News Needs Further Validation") {
137
+ resultText.classList.add("text-red-500");
138
+ resultText.classList.remove("text-green-500");
139
+ } else {
140
+ resultText.classList.add("text-green-500");
141
+ resultText.classList.remove("text-red-500");
142
+ }
143
+
144
+ confidenceFake.textContent = (data.confidence.fake * 100).toFixed(2) + "%";
145
+ confidenceReal.textContent = (data.confidence.real * 100).toFixed(2) + "%";
146
+ }
147
+ })
148
+ .catch((error) => {
149
+ loadingSpinner.classList.add("hidden");
150
+ console.error("Error:", error);
151
+ });
152
+ });
153
+ </script>
154
+ </body>
155
+ </html>