Canstralian commited on
Commit
5cda0af
·
verified ·
1 Parent(s): 5d7c58c

Create test_streamlit_app.py

Browse files
Files changed (1) hide show
  1. tests/test_streamlit_app.py +66 -0
tests/test_streamlit_app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import patch, MagicMock
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import streamlit as st
6
+ import io
7
+
8
+ class TestStreamlitApp(unittest.TestCase):
9
+
10
+ @patch("transformers.AutoTokenizer.from_pretrained")
11
+ @patch("transformers.AutoModelForSequenceClassification.from_pretrained")
12
+ def test_load_model_success(self, mock_model, mock_tokenizer):
13
+ # Mock the tokenizer and model loading
14
+ mock_tokenizer.return_value = MagicMock(spec=AutoTokenizer)
15
+ mock_model.return_value = MagicMock(spec=AutoModelForSequenceClassification)
16
+
17
+ tokenizer, model = load_model("Canstralian/CyberAttackDetection")
18
+
19
+ # Assert that the tokenizer and model are not None
20
+ self.assertIsNotNone(tokenizer)
21
+ self.assertIsNotNone(model)
22
+ mock_tokenizer.assert_called_once_with("Canstralian/CyberAttackDetection")
23
+ mock_model.assert_called_once_with("Canstralian/CyberAttackDetection")
24
+
25
+ @patch("transformers.AutoTokenizer.from_pretrained")
26
+ @patch("transformers.AutoModelForSequenceClassification.from_pretrained")
27
+ def test_predict_classification(self, mock_model, mock_tokenizer):
28
+ # Mock the tokenizer and model for inference
29
+ mock_tokenizer.return_value = MagicMock(spec=AutoTokenizer)
30
+ mock_model.return_value = MagicMock(spec=AutoModelForSequenceClassification)
31
+
32
+ # Simulate model outputs
33
+ mock_model.return_value.__call__.return_value = MagicMock(logits=torch.tensor([[1.0, 2.0, 3.0]]))
34
+
35
+ # Call the prediction function
36
+ inputs = mock_tokenizer("Test input", return_tensors="pt", padding=True, truncation=True)
37
+ with torch.no_grad():
38
+ outputs = mock_model.return_value(**inputs)
39
+ logits = outputs.logits
40
+ predicted_class = torch.argmax(logits, dim=-1).item()
41
+
42
+ # Assert that the predicted class is correct
43
+ self.assertEqual(predicted_class, 2) # The class with the highest score (index 2)
44
+
45
+ @patch("transformers.AutoTokenizer.from_pretrained")
46
+ @patch("transformers.AutoModelForSeq2SeqLM.from_pretrained")
47
+ def test_generate_shell_command(self, mock_model, mock_tokenizer):
48
+ # Mock the tokenizer and model for shell command generation
49
+ mock_tokenizer.return_value = MagicMock(spec=AutoTokenizer)
50
+ mock_model.return_value = MagicMock(spec=AutoModelForSeq2SeqLM)
51
+
52
+ # Simulate model output (fake shell command)
53
+ mock_model.return_value.generate.return_value = torch.tensor([[1, 2, 3, 4]])
54
+
55
+ # Simulate text input
56
+ user_input = "Create a directory"
57
+ inputs = mock_tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
58
+ with torch.no_grad():
59
+ outputs = mock_model.return_value.generate(**inputs)
60
+ generated_command = mock_tokenizer.decode(outputs[0], skip_special_tokens=True)
61
+
62
+ # Assert the generated command is as expected
63
+ self.assertEqual(generated_command, "mkdir directory") # Assuming the model generates this
64
+
65
+ if __name__ == "__main__":
66
+ unittest.main()