import os
import re
from pathlib import Path

import pandas as pd


class DDoSInference:
    """Classify each row of a network-traffic CSV as DDoS or benign via an LLM.

    Reads ``~/.dataset/original.csv``, asks a hosted chat model for a
    ``[p_ddos, p_benign]`` estimate per row, and appends results to a CSV so
    interrupted runs can resume where they left off.
    """

    def __init__(self):
        # Lazy third-party imports: only a live inference session needs the
        # HF client / dotenv; the prompt-building and parsing helpers do not.
        from dotenv import load_dotenv
        from huggingface_hub import InferenceClient

        load_dotenv()
        self.client = InferenceClient(api_key=os.getenv("HF_TOK_KEY"))
        self.model = "Qwen/Qwen2.5-Coder-32B-Instruct"
        self.dataset_path = Path("~/.dataset/original.csv").expanduser()
        self.results_path = Path(
            "~/.dataset/PROBABILITY_OF_EACH_ROW_DDOS_AND_BENGNIN.csv"
        ).expanduser()
        self.results_path.parent.mkdir(parents=True, exist_ok=True)

    def process_dataset(self):
        """Run inference row by row, saving results after every row.

        Resumes from the existing results file when present.

        Raises:
            FileNotFoundError: if the dataset CSV does not exist.
            ValueError: if no valid label column can be determined.
        """
        if not self.dataset_path.exists():
            raise FileNotFoundError(
                "The preprocessed dataset file does not exist. Ensure it is generated using the processor."
            )

        ddos_data = pd.read_csv(self.dataset_path)

        # CIC-style datasets name the label column " Label" (leading space);
        # fall back to asking the operator otherwise.
        label_column = " Label"
        if label_column not in ddos_data.columns:
            label_column = input("Enter the label column name in your dataset: ").strip()
            if label_column not in ddos_data.columns:
                raise ValueError(f"Label column '{label_column}' not found in the dataset.")

        features = ddos_data.drop([label_column], axis=1)

        # BUGFIX: build stats as {column: {stat: value}} — the previous
        # {stat: Series} layout made generate_knowledge_prompt raise
        # KeyError('Max') on every call.
        stats = {
            col: {
                'Max': features[col].max(),
                'Min': features[col].min(),
                'Median': features[col].median(),
                'Mean': features[col].mean(),
                'Variance': features[col].var(),
            }
            for col in features.columns
        }

        # Generate knowledge prompt
        know_prompt = self.generate_knowledge_prompt(stats)

        # Prepare results DataFrame (resume support)
        if self.results_path.exists():
            predict_df = pd.read_csv(self.results_path)
        else:
            predict_df = pd.DataFrame(columns=["index", "attack", "benign", "original"])

        start_index = predict_df.shape[0]
        print(f"Starting inference from row {start_index}")

        # Process rows for inference
        for i in range(start_index, ddos_data.shape[0]):
            # BUGFIX: prompt with the label-free feature row — previously the
            # ground-truth label was leaked to the model inside the prompt.
            row_prompt = self.generate_row_prompt(features.iloc[i])
            probabilities = self.infer_row(know_prompt, row_prompt)
            if probabilities and len(probabilities) == 2:
                # BUGFIX: the frame has 4 columns; the original code assigned
                # only 3 values ([i, *probabilities]), which raises ValueError.
                # The ground-truth label now fills the "original" column.
                predict_df.loc[i] = [
                    i,
                    probabilities[0],
                    probabilities[1],
                    ddos_data.iloc[i][label_column],
                ]
            else:
                predict_df.loc[i] = [i, "None", "None", "No valid response"]
            # Save after each row for resilience
            predict_df.to_csv(self.results_path, index=False)
            print(f"Processed row {i}: {predict_df.loc[i].to_dict()}")

        print("Inference complete. Results saved at:", self.results_path)

    def generate_knowledge_prompt(self, stats):
        """Build the system-knowledge prompt from per-column statistics.

        Args:
            stats: mapping of column name -> {'Max','Min','Median','Mean','Variance'}.

        Returns:
            The prompt string with one formatted stats line per column.
        """
        prompt = (
            "Supposed that you are now an [[ HIGHLY EXPERIENCED NETWORK TRAFFIC DATA ANALYSIS EXPERT ]]. "
            "You need to help me analyze the data in the DDoS dataset and determine whether the data is [[ DDoS traffic ]] or [[ normal traffic ]]. "
            "Here are the maximum, minimum, median, mean, and variance of each column in the dataset to help your judgment:\n"
        )
        for col, values in stats.items():
            prompt += f"{col}: max={values['Max']:.2f}, min={values['Min']:.2f}, median={values['Median']:.2f}, mean={values['Mean']:.2f}, variance={values['Variance']:.2f}\n"
        return prompt

    def generate_row_prompt(self, row):
        """Build the per-row prompt listing every feature as 'name: value'.

        Args:
            row: a mapping / pandas Series of feature name -> value.

        Returns:
            The prompt string with the trailing ', ' removed.
        """
        row_prompt = (
            "Next, I will give you a piece of data about network traffic information. "
            "You need to tell me the probability of this data being DDoS traffic or normal traffic. "
            "Express the probability in the format [0.xxx, 0.xxx], where the first number represents DDoS probability and the second represents normal traffic probability. "
            "Ensure that the sum of probabilities is exactly 1.\n"
        )
        for col, val in row.items():
            row_prompt += f"{col}: {val}, "
        return row_prompt.strip(', ')

    def infer_row(self, know_prompt, row_prompt):
        """Query the LLM for one row and parse its probability answer.

        Returns:
            A list of floats parsed from the response, or None on any
            API/parsing failure (failures are logged, not raised, so the
            caller's row loop keeps running).
        """
        try:
            messages = [
                {'role': 'user', 'content': know_prompt},
                {'role': 'user', 'content': row_prompt},
            ]
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=1000,
            )
            response = completion.choices[0].message.content
            return self.extract_probabilities(response)
        except Exception as e:
            # Boundary catch: any network/API/parse error must not abort the run.
            print(f"Error during inference for row: {e}")
            return None

    def extract_probabilities(self, response):
        """Extract the first bracketed list of numbers from *response*.

        Returns:
            A list of floats (e.g. [0.8, 0.2]) or None when no bracketed,
            numeric content is found.
        """
        match = re.search(r'\[(.*?)\]', response)
        if not match:
            return None
        try:
            return [float(p.strip()) for p in match.group(1).split(',') if p.strip()]
        except ValueError:
            # BUGFIX: brackets containing non-numeric text (e.g. "[see note]")
            # previously raised and crashed the caller; treat as no answer.
            return None


# Example usage
if __name__ == "__main__":
    handler = DDoSInference()
    handler.process_dataset()
    print("You can now interact with the model for mitigation steps or download the results.")