Canstralian commited on
Commit
6a46dba
·
verified ·
1 Parent(s): d8536bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -15
app.py CHANGED
@@ -1,10 +1,23 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
  # Sidebar for user input
6
  st.sidebar.header("Model Configuration")
7
- model_name = st.sidebar.text_input("Enter model name", "huggingface/transformers")
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Load model and tokenizer on demand
10
  @st.cache_resource
@@ -12,7 +25,10 @@ def load_model(model_name):
12
  try:
13
  # Load the model and tokenizer
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
16
  return tokenizer, model
17
  except Exception as e:
18
  st.error(f"Error loading model: {e}")
@@ -22,20 +38,28 @@ def load_model(model_name):
22
  tokenizer, model = load_model(model_name)
23
 
24
  # Input text box in the main panel
25
- st.title("Text Classification with Hugging Face Models")
26
- user_input = st.text_area("Enter text for classification:")
27
 
28
  # Make prediction if user input is provided
29
  if user_input and model and tokenizer:
30
- inputs = tokenizer(user_input, return_tensors="pt")
31
- with torch.no_grad():
32
- outputs = model(**inputs)
 
 
 
 
33
 
34
- # Display results (e.g., classification logits)
35
- logits = outputs.logits
36
- predicted_class = torch.argmax(logits, dim=-1).item()
37
- st.write(f"Predicted Class: {predicted_class}")
38
- st.write(f"Logits: {logits}")
39
- else:
40
- st.info("Please enter some text to classify.")
 
 
41
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
  # Sidebar for user input
6
  st.sidebar.header("Model Configuration")
7
+ model_choice = st.sidebar.selectbox("Select a model", [
8
+ "CyberAttackDetection",
9
+ "text2shellcommands",
10
+ "pentest_ai"
11
+ ])
12
+
13
+ # Define the model names
14
+ model_mapping = {
15
+ "CyberAttackDetection": "Canstralian/CyberAttackDetection",
16
+ "text2shellcommands": "Canstralian/text2shellcommands",
17
+ "pentest_ai": "Canstralian/pentest_ai"
18
+ }
19
+
20
+ model_name = model_mapping.get(model_choice, "Canstralian/CyberAttackDetection")
21
 
22
  # Load model and tokenizer on demand
23
  @st.cache_resource
 
25
  try:
26
  # Load the model and tokenizer
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
+ if model_name == "Canstralian/text2shellcommands":
29
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
30
+ else:
31
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
32
  return tokenizer, model
33
  except Exception as e:
34
  st.error(f"Error loading model: {e}")
 
38
  tokenizer, model = load_model(model_name)
39
 
40
  # Input text box in the main panel
41
+ st.title(f"{model_choice} Model")
42
+ user_input = st.text_area("Enter text:")
43
 
44
  # Make prediction if user input is provided
45
  if user_input and model and tokenizer:
46
+ if model_choice == "text2shellcommands":
47
+ # For text2shellcommands model, generate shell commands
48
+ inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
49
+ with torch.no_grad():
50
+ outputs = model.generate(**inputs)
51
+ generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+ st.write(f"Generated Shell Command: {generated_command}")
53
 
54
+ else:
55
+ # For CyberAttackDetection and pentest_ai models, perform classification
56
+ inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
57
+ with torch.no_grad():
58
+ outputs = model(**inputs)
59
+ logits = outputs.logits
60
+ predicted_class = torch.argmax(logits, dim=-1).item()
61
+ st.write(f"Predicted Class: {predicted_class}")
62
+ st.write(f"Logits: {logits}")
63
 
64
+ else:
65
+ st.info("Please enter some text for prediction.")