Update llm_handler.py
llm_handler.py  +117 -117
llm_handler.py
CHANGED
@@ -1,140 +1,140 @@
-from huggingface_hub import InferenceClient
import pandas as pd
import re
-from pathlib import Path
import os
from dotenv import load_dotenv

-class
    def __init__(self):
        load_dotenv()
        self.client = InferenceClient(api_key=os.getenv("HF_TOK_KEY"))
        self.model = "Qwen/Qwen2.5-Coder-32B-Instruct"
-        self.
-        self.
-
        """
-
        """
        # Generate knowledge prompt
-        know_prompt = self.
-
-        #
-
-        #
-
        """
-
        """
-
-        prompt = ("Supposed that you are now an [[ HIGHLY EXPERIENCED NETWORK TRAFFIC DATA ANALYSIS EXPERT ]] . "
-                  "You need to help me analyze the data in the DDoS dataset and determine whether the data is [[ DDoS traffic ]] or [[ normal traffic ]] ."
-                  "Next, I will give you the maximum, minimum, median, mean, and variance of all the data under each label or columns present in the data set, which may help you make your judgment."
-                  "DO DEEP ANALYSIS ITS YOUR WORKS AND PRROIVDE ACCURATE ANSWERS ALONGWITH GIVEN TASK ::\n\n")
-
-        for column in df_without_label.columns:
-            prompt += (f"{column}: max={stats['max'][column]:.1f}, "
-                       f"min={stats['min'][column]:.1f}, "
-                       f"median={stats['median'][column]:.1f}, "
-                       f"mean={stats['mean'][column]:.1f}, "
-                       f"variance={stats['variance'][column]:.1f}\n")
-
        return prompt
-
-    def
        """
-
        """
-
-        # Get LLM response
-        messages = [
-            {'role': 'user', 'content': know_prompt},
-            {'role': 'user', 'content': row_prompt}
-        ]
-
-        completion = self.client.chat.completions.create(
-            model=self.model,
-            messages=messages,
-            max_tokens=10000
        )
-
-            'attack': probabilities[0],
-            'benign': probabilities[1],
-            'original': response
-        }
-
-    def _generate_row_prompt(self, row):
        """
-
        """
-
        """
-        Extract probabilities from LLM response
        """
        pattern = r'\[(.*?)\]'
-
-        if
-            probs =
-            return [float(p) for p in probs]
-        return
-
-        completion = self.client.chat.completions.create(
-            model=self.model,
-            messages=messages,
-            max_tokens=10000
-        )
-        return completion.choices[0].message.content
-
-    def _get_label_column(self, df):
-        """
-        Get the label column name
-        """
-        potential_labels = [col for col in df.columns if ' Label' in col.lower()]
-        return potential_labels[0] if potential_labels else None

import pandas as pd
import re
import os
+from pathlib import Path
+from huggingface_hub import InferenceClient
from dotenv import load_dotenv

+class DDoSInference:
    def __init__(self):
        load_dotenv()
        self.client = InferenceClient(api_key=os.getenv("HF_TOK_KEY"))
        self.model = "Qwen/Qwen2.5-Coder-32B-Instruct"
+        self.dataset_path = Path("~/.dataset/original.csv").expanduser()
+        self.results_path = Path("~/.dataset/PROBABILITY_OF_EACH_ROW_DDOS_AND_BENGNIN.csv").expanduser()
+        self.results_path.parent.mkdir(parents=True, exist_ok=True)
+
+    def process_dataset(self):
        """
+        Processes the dataset row by row and performs inference using the LLM.
        """
+        if not self.dataset_path.exists():
+            raise FileNotFoundError("The preprocessed dataset file does not exist. Ensure it is generated using the processor.")
+
+        ddos_data = pd.read_csv(self.dataset_path)
+
+        label_column = " Label"
+        if label_column not in ddos_data.columns:
+            label_column = input("Enter the label column name in your dataset: ").strip()
+            if label_column not in ddos_data.columns:
+                raise ValueError(f"Label column '{label_column}' not found in the dataset.")
+
+        ddos_data_without_label = ddos_data.drop([label_column], axis=1)
+        stats = {
+            'Max': ddos_data_without_label.max(),
+            'Min': ddos_data_without_label.min(),
+            'Median': ddos_data_without_label.median(),
+            'Mean': ddos_data_without_label.mean(),
+            'Variance': ddos_data_without_label.var()
+        }
+
        # Generate knowledge prompt
+        know_prompt = self.generate_knowledge_prompt(stats)
+
+        # Prepare results DataFrame
+        if self.results_path.exists():
+            predict_df = pd.read_csv(self.results_path)
+        else:
+            predict_df = pd.DataFrame(columns=["index", "attack", "benign", "original"])
+
+        start_index = predict_df.shape[0]
+        print(f"Starting inference from row {start_index}")
+
+        # Process rows for inference; use the label-free row so the model never sees the ground truth
+        for i in range(start_index, ddos_data.shape[0]):
+            row_prompt = self.generate_row_prompt(ddos_data_without_label.iloc[i])
+            probabilities, response = self.infer_row(know_prompt, row_prompt)
+            if probabilities and len(probabilities) == 2:
+                predict_df.loc[i] = [i, *probabilities, response]
+            else:
+                predict_df.loc[i] = [i, None, None, response or "No valid response"]
+
+            # Save after each row for resilience
+            predict_df.to_csv(self.results_path, index=False)
+
+            print(f"Processed row {i}: {predict_df.loc[i].to_dict()}")
+
+        print("Inference complete. Results saved at:", self.results_path)
+
|
69 |
+
def generate_knowledge_prompt(self, stats):
|
70 |
"""
|
71 |
+
Generates the knowledge prompt based on dataset statistics.
|
72 |
"""
|
73 |
+
prompt = (
|
74 |
+
"Supposed that you are now an [[ HIGHLY EXPERIENCED NETWORK TRAFFIC DATA ANALYSIS EXPERT ]]. "
|
75 |
+
"You need to help me analyze the data in the DDoS dataset and determine whether the data is [[ DDoS traffic ]] or [[ normal traffic ]]. "
|
76 |
+
"Here are the maximum, minimum, median, mean, and variance of each column in the dataset to help your judgment:\n"
|
77 |
+
)
|
78 |
+
|
79 |
+
for col, values in stats.items():
|
80 |
+
prompt += f"{col}: max={values['Max']:.2f}, min={values['Min']:.2f}, median={values['Median']:.2f}, mean={values['Mean']:.2f}, variance={values['Variance']:.2f}\n"
|
81 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
return prompt
|
+
+    def generate_row_prompt(self, row):
        """
+        Generates a row-specific prompt for the LLM.
        """
+        row_prompt = (
+            "Next, I will give you a piece of data about network traffic information. "
+            "You need to tell me the probability of this data being DDoS traffic or normal traffic. "
+            "Express the probability in the format [0.xxx, 0.xxx], where the first number represents DDoS probability and the second represents normal traffic probability. "
+            "Ensure that the sum of probabilities is exactly 1.\n"
+        )
+
+        for col, val in row.items():
+            row_prompt += f"{col}: {val}, "
+
+        return row_prompt.strip(', ')
+
+    def infer_row(self, know_prompt, row_prompt):
        """
+        Performs inference for a single row using the LLM.
+        Returns (probabilities, raw response) so the caller can store both.
        """
+        try:
+            messages = [
+                {'role': 'user', 'content': know_prompt},
+                {'role': 'user', 'content': row_prompt}
+            ]
+
+            completion = self.client.chat.completions.create(
+                model=self.model,
+                messages=messages,
+                max_tokens=1000
+            )
+
+            response = completion.choices[0].message.content
+            probabilities = self.extract_probabilities(response)
+            return probabilities, response
+
+        except Exception as e:
+            print(f"Error during inference for row: {e}")
+            return None, None
+
+    def extract_probabilities(self, response):
        """
+        Extract probabilities from the LLM response using regex.
        """
        pattern = r'\[(.*?)\]'
+        match = re.search(pattern, response)
+
+        if match:
+            probs = match.group(1).split(',')
+            return [float(p.strip()) for p in probs if p.strip()]
+        return None
+
+# Example usage
+if __name__ == "__main__":
+    handler = DDoSInference()
+    handler.process_dataset()
+    print("You can now interact with the model for mitigation steps or download the results.")
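
For reference, the statistics block that generate_knowledge_prompt emits can be previewed without the real dataset. A minimal sketch; the two column names are invented stand-ins for actual flow features:

import pandas as pd

# Hypothetical two-column frame standing in for the preprocessed dataset
df = pd.DataFrame({"Flow Duration": [1.0, 2.0, 3.0],
                   "Total Fwd Packets": [10.0, 20.0, 30.0]})
stats = {'Max': df.max(), 'Min': df.min(), 'Median': df.median(),
         'Mean': df.mean(), 'Variance': df.var()}
for column in stats['Max'].index:
    print(f"{column}: max={stats['Max'][column]:.2f}, min={stats['Min'][column]:.2f}, "
          f"median={stats['Median'][column]:.2f}, mean={stats['Mean'][column]:.2f}, "
          f"variance={stats['Variance'][column]:.2f}")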
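
The [0.xxx, 0.xxx] convention requested in the row prompt is what extract_probabilities depends on. A standalone sketch of the same regex logic; the sample response string is invented for illustration:

import re

def extract_probabilities(response):
    # Same pattern as above: take the first bracketed group
    match = re.search(r'\[(.*?)\]', response)
    if match:
        return [float(p.strip()) for p in match.group(1).split(',') if p.strip()]
    return None

sample = "Given the flow statistics, my estimate is [0.912, 0.088]."
print(extract_probabilities(sample))  # -> [0.912, 0.088]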
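
Because results are written back to CSV after every row, a restart resumes where the previous run stopped. A quick sketch of that resume arithmetic, using the same results path as __init__:

import pandas as pd
from pathlib import Path

results = Path("~/.dataset/PROBABILITY_OF_EACH_ROW_DDOS_AND_BENGNIN.csv").expanduser()
# start_index in process_dataset is simply the number of rows already saved
done = pd.read_csv(results).shape[0] if results.exists() else 0
print(f"Next run starts at row {done}")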