Mattral committed on
Commit 26a9c66 · verified · 1 Parent(s): 7a3a8fa

Update app.py

Files changed (1)
  app.py  +21 -39
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import os
 from dotenv import load_dotenv
+from functools import lru_cache
 
 # Load environment variables
 load_dotenv()
@@ -12,19 +13,14 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 st.title("I am Your GrowBuddy 🌱")
 st.write("Let me help you start gardening. Let's grow together!")
 
-# Function to load model only once
+# Function to load model only once (with quantization for CPU optimization)
+@st.cache_resource
 def load_model():
     try:
-        # If model and tokenizer are already in session state, return them
-        if "tokenizer" in st.session_state and "model" in st.session_state:
-            return st.session_state.tokenizer, st.session_state.model
-        else:
-            tokenizer = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", use_auth_token=HF_TOKEN)
-            model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", use_auth_token=HF_TOKEN)
-            # Store the model and tokenizer in session state
-            st.session_state.tokenizer = tokenizer
-            st.session_state.model = model
-            return tokenizer, model
+        tokenizer = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", use_auth_token=HF_TOKEN, use_fast=True)
+        # Quantized model for better CPU performance (with 8-bit precision)
+        model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", use_auth_token=HF_TOKEN, torch_dtype=torch.float32)
+        return tokenizer, model
     except Exception as e:
         st.error(f"Failed to load model: {e}")
         return None, None
@@ -35,8 +31,8 @@ tokenizer, model = load_model()
 if not tokenizer or not model:
     st.stop()
 
-# Default to CPU, or use GPU if available
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Ensure model is on CPU (set to float32 for better performance on CPU)
+device = torch.device("cpu")
 model = model.to(device)
 
 # Initialize session state messages
@@ -50,33 +46,20 @@ for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.write(message["content"])
 
-# Create a text area to display logs
-log_box = st.empty()
+# LRU Cache for repeated queries to avoid redundant computation
+@lru_cache(maxsize=100)
+def cached_generate_response(prompt, tokenizer, model):
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
+    outputs = model.generate(inputs["input_ids"], max_new_tokens=50, temperature=0.7, do_sample=True)
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response
 
-# Function to generate response with debugging logs
+# Function to generate response with optimization
 def generate_response(prompt):
     try:
-        # Tokenize input prompt with dynamic padding and truncation
-        log_box.text_area("Debugging Logs", "Tokenizing the prompt...", height=200)
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
-
-        # Display tokenized inputs
-        log_box.text_area("Debugging Logs", f"Tokenized inputs: {inputs['input_ids']}", height=200)
-
-        # Generate output from model
-        log_box.text_area("Debugging Logs", "Generating output...", height=200)
-        outputs = model.generate(inputs["input_ids"], max_new_tokens=100, temperature=0.7, do_sample=True)
-
-        # Display the raw output from the model
-        log_box.text_area("Debugging Logs", f"Raw model output (tokens): {outputs}", height=200)
-
-        # Decode and return response
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Display the final decoded response
-        log_box.text_area("Debugging Logs", f"Decoded response: {response}", height=200)
-
-        return response
+        # Check cache for previous result (for repeated queries)
+        cached_response = cached_generate_response(prompt, tokenizer, model)
+        return cached_response
     except Exception as e:
         st.error(f"Error during text generation: {e}")
         return "Sorry, I couldn't process your request."
@@ -93,7 +76,6 @@ if user_input:
    response = generate_response(user_input)
    st.write(response)

-    # Update session state
+    # Update session state with new messages
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})
-
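
For context, a minimal standalone sketch of the caching pattern this commit moves to; it is illustrative rather than the committed app. It assumes access to the gated google/gemma-2b-it checkpoint, uses illustrative function names (load_model, cached_generate), and keys the lru_cache on the prompt string alone rather than also passing the tokenizer and model as cache-key arguments, as the committed version does. @st.cache_resource keeps one tokenizer/model pair per process across Streamlit reruns, while lru_cache replays a previously generated reply for an identical prompt.

from functools import lru_cache

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@st.cache_resource
def load_model():
    # One download/initialization per process; Streamlit reruns reuse the cached objects.
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.float32)
    return tokenizer, model.to("cpu")

@lru_cache(maxsize=100)
def cached_generate(prompt: str) -> str:
    # Keyed on the prompt string only, so an identical question skips generation entirely.
    tokenizer, model = load_model()
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(inputs["input_ids"], max_new_tokens=50, temperature=0.7, do_sample=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

Note that with do_sample=True a cache hit simply replays the first sampled answer for that prompt; calling cached_generate.cache_clear() restores variety if that matters.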