s0uL141 commited on
Commit
77f3279
·
verified ·
1 Parent(s): 98abd16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -31
app.py CHANGED
@@ -1,50 +1,49 @@
1
  import streamlit as st
2
- import os
3
- from langchain_groq import ChatGroq
4
- from dotenv import load_dotenv
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
-
7
- # Load environment variables from .env file
8
- load_dotenv()
9
- groq_api_key = os.environ.get("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")
10
-
11
- # Initialize the Hugging Face model
12
- huggingface_model_name = "s0uL141/fine_tuned_science_gemma2b-it" # Your Hugging Face model name
13
- tokenizer = AutoTokenizer.from_pretrained(huggingface_model_name)
14
- huggingface_model = AutoModelForCausalLM.from_pretrained(huggingface_model_name)
15
-
16
- # Initialize the ChatGroq model with the Hugging Face model name for inference
17
- llm = ChatGroq(
18
- temperature=0,
19
- model_name=huggingface_model_name, # Use the Hugging Face model name here
20
- api_key=groq_api_key
21
- )
22
-
23
- # Function to generate text using Hugging Face model (optional if you're using ChatGroq directly)
24
- def generate_response_huggingface(prompt, max_length=500):
25
  inputs = tokenizer(prompt, return_tensors="pt")
26
- output = huggingface_model.generate(inputs.input_ids, max_length=max_length)
 
 
27
  return tokenizer.decode(output[0], skip_special_tokens=True)
28
 
29
  # Streamlit App
30
  def main():
31
- st.title("Cybersecurity Q&A with ChatGroq and Hugging Face Model")
32
- st.write("This app generates responses to your cybersecurity questions using ChatGroq with a Hugging Face model.")
 
33
 
34
- user_input = st.text_area("Enter your cybersecurity-related question below:", height=200)
 
35
 
 
36
  if st.button("Generate Response"):
 
37
  if user_input.strip() == "":
38
- st.write("Please enter a valid question.")
39
  else:
40
  with st.spinner("Generating response..."):
41
- # Invoke the ChatGroq model with the user's input
42
- response = llm.invoke(user_input)
43
-
44
  # Display the generated response
45
  st.write("### Model Response:")
46
  st.write(response)
47
 
48
  # Entry point to run the app
49
  if __name__ == "__main__":
50
- main()
 
1
  import streamlit as st
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
+ # Cache the model loading to avoid reloading it on every interaction
6
+ @st.cache_resource
7
+ def load_model():
8
+ model_name = "s0uL141/Cyber_gemma2_2B_it" # Replace with your Hugging Face repo or local path
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
11
+ return tokenizer, model
12
+
13
+ # Load the model and tokenizer
14
+ tokenizer, model = load_model()
15
+
16
+ # Function to generate text based on the user prompt
17
+ def generate_response(prompt, max_length=50):
18
+ # Tokenize input prompt
 
 
 
19
  inputs = tokenizer(prompt, return_tensors="pt")
20
+ # Generate response using the model
21
+ output = model.generate(inputs.input_ids, max_length=max_length, num_return_sequences=1)
22
+ # Decode the response and return
23
  return tokenizer.decode(output[0], skip_special_tokens=True)
24
 
25
  # Streamlit App
26
  def main():
27
+ # Set up the title and description for the app
28
+ st.title("Fine-Tuned Cyber Gemma 2b-it Model")
29
+ st.write("This app generates responses based on your input using a fine-tuned version of the Gemma 2b-it model.")
30
 
31
+ # Text input area for the user to provide a prompt
32
+ user_input = st.text_area("Enter your prompt here:", height=200)
33
 
34
+ # Button to trigger text generation
35
  if st.button("Generate Response"):
36
+ # Check if user input is provided
37
  if user_input.strip() == "":
38
+ st.write("Please enter a valid prompt.")
39
  else:
40
  with st.spinner("Generating response..."):
41
+ # Generate response using the model
42
+ response = generate_response(user_input)
 
43
  # Display the generated response
44
  st.write("### Model Response:")
45
  st.write(response)
46
 
47
  # Entry point to run the app
48
  if __name__ == "__main__":
49
+ main()