Spestly committed on
Commit 4b0103d · verified · 1 Parent(s): ef62490

Update app.py

Files changed (1):
app.py +12 -33
app.py CHANGED
@@ -18,16 +18,15 @@ MODELS = {
         },
         "emoji": "🦁",
         "experimental": True,
-        "is_vision": False,  # Enable vision support for this model
-        "system_prompt_env": "ATLAS_FLASH_1215",  # Environment variable for system prompt
+        "is_vision": False,
+        "system_prompt_env": "ATLAS_FLASH_1215",
     },
 }
 
 # Profile pictures
-USER_PFP = "user.png"  # Hugging Face user avatar
-AI_PFP = "ai_pfp.png"  # Replace with the path to your AI's image or a URL
+USER_PFP = "user.png"
+AI_PFP = "ai_pfp.png"
 
-# Set page config (must be called only once and before any other Streamlit commands)
 st.set_page_config(
     page_title="Atlas Model Inference",
     page_icon="🦁 ",
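Note on the `system_prompt_env` key kept above: it names an environment variable rather than holding the prompt text itself, so the actual prompt can be configured per deployment. A minimal sketch of that indirection (the `resolve_system_prompt` helper and the trimmed-down `MODELS` dict are illustrative, not part of app.py):

import os

# Stand-in for the full MODELS dict in app.py; only the relevant key is shown.
MODELS = {"atlas-flash": {"system_prompt_env": "ATLAS_FLASH_1215"}}

def resolve_system_prompt(model_key: str) -> str:
    env_name = MODELS[model_key]["system_prompt_env"]
    # Mirrors the os.getenv(..., "Default system prompt") call later in this diff.
    return os.getenv(env_name, "Default system prompt")

print(resolve_system_prompt("atlas-flash"))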
@@ -39,15 +38,12 @@ st.set_page_config(
     }
 )
 
-# Custom CSS for blue sliders and button
 st.markdown(
     """
     <style>
-    /* Blue slider */
     .stSlider > div > div > div > div {
         background-color: #1f78b4 !important;
     }
-    /* Blue button */
     .stButton > button {
         background-color: #1f78b4 !important;
         color: white !important;
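The `<style>` block above only takes effect because `st.markdown` is asked to render raw HTML; the closing arguments of the call sit outside this hunk, but the pattern presumably looks like this sketch:

import streamlit as st

# Inject custom CSS into a Streamlit page. unsafe_allow_html=True is required
# for st.markdown to render raw HTML/CSS instead of escaping it.
st.markdown(
    """
    <style>
    .stButton > button { background-color: #1f78b4 !important; }
    </style>
    """,
    unsafe_allow_html=True,
)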
@@ -69,7 +65,6 @@ class AtlasInferenceApp:
             st.session_state.chat_history = []
 
     def clear_memory(self):
-        """Optimize memory management for CPU inference"""
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
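Only the docstring is dropped from `clear_memory`; the behavior is unchanged. As a standalone sketch of the pattern:

import gc
import torch

def clear_memory() -> None:
    # Return cached GPU blocks to the CUDA allocator (a no-op on CPU-only hosts),
    # then run Python's garbage collector to drop unreferenced tensors.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()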
@@ -85,24 +80,22 @@ class AtlasInferenceApp:
 
             model_path = MODELS[model_key]["sizes"][model_size]
 
-            # Load Qwen-compatible tokenizer and model
             tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
-                device_map="auto",  # Use GPU if available
+                device_map="auto",
                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                 trust_remote_code=True,
                 low_cpu_mem_usage=True
             )
 
-            # Update session state
             st.session_state.current_model.update({
                 "tokenizer": tokenizer,
                 "model": model,
                 "config": {
                     "name": f"{MODELS[model_key]['name']} {model_size}",
                     "path": model_path,
-                    "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),  # Load system prompt from env
+                    "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
                 }
             })
             return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
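For context, the loading pattern in this hunk as a self-contained sketch ("Qwen/Qwen2.5-0.5B" is a placeholder checkpoint, not an Atlas path; device_map="auto" additionally requires the accelerate package to be installed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Qwen/Qwen2.5-0.5B"  # placeholder; app.py resolves this from MODELS
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",  # place weights on available GPU(s), else fall back to CPU
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
    low_cpu_mem_usage=True,  # avoid materializing a second full copy of the weights in RAM
)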
@@ -114,18 +107,10 @@ class AtlasInferenceApp:
             return "⚠️ Please select and load a model first"
 
         try:
-            # Debugging: Check if config and system_prompt exist
-            if "config" not in st.session_state.current_model:
-                return "⚠️ Model configuration not found. Please load the model again."
-
-            system_prompt = st.session_state.current_model["config"].get("system_prompt", "Default system prompt")
+            system_prompt = st.session_state.current_model["config"]["system_prompt"]
             if not system_prompt:
-                system_prompt = "You are Atlas. You are developed by Spestly"  # Fallback if system_prompt is None
-
-            # Debugging: Print the system prompt for verification
-            st.write(f"System Prompt: {system_prompt}")
+                return "⚠️ System prompt not found for the selected model."
 
-            # Add the system instruction to guide the model's behavior
             prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
 
             inputs = st.session_state.current_model["tokenizer"](
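The prompt string follows an Alpaca-style instruction template. Assembled with illustrative values (the system prompt below reuses the old fallback string removed in this hunk), it looks like:

system_prompt = "You are Atlas. You are developed by Spestly"  # illustrative value
message = "What is the capital of France?"                     # illustrative value
prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
print(prompt)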
@@ -135,8 +120,6 @@ class AtlasInferenceApp:
                 truncation=True,
                 padding=True
             )
-
-            # Generate response without streaming
             with torch.no_grad():
                 output = st.session_state.current_model["model"].generate(
                     input_ids=inputs.input_ids,
@@ -151,7 +134,6 @@ class AtlasInferenceApp:
             )
             response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
 
-            # Remove the prompt from the response
             if prompt in response:
                 response = response.replace(prompt, "").strip()
 
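The string replace above strips the echoed prompt from the decoded output. A common alternative (a sketch, not what app.py does) is to decode only the newly generated token ids, which avoids any substring matching:

def strip_prompt(tokenizer, input_ids, output_ids):
    # Decode only the tokens produced after the prompt, so the echoed prompt
    # never appears in the response text in the first place.
    new_tokens = output_ids[input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# e.g. response = strip_prompt(tokenizer, inputs.input_ids[0], output[0])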
@@ -195,19 +177,16 @@ class AtlasInferenceApp:
 
         st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
 
-        # Display chat history
         for message in st.session_state.chat_history:
             with st.chat_message(
                 message["role"],
                 avatar=USER_PFP if message["role"] == "user" else AI_PFP
             ):
                 st.markdown(message["content"])
-                if "image" in message:
-                    st.image(message["image"], caption="Uploaded Image", use_container_width=True)  # Updated parameter
+                if "image" in message and message["image"]:
+                    st.image(message["image"], caption="Uploaded Image", use_column_width=True)
 
-        # Input box for user messages
         if prompt := st.chat_input("Message Atlas..."):
-            # Allow image upload if the model supports vision
             uploaded_image = None
             if MODELS[model_key]["is_vision"]:
                 uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
@@ -216,7 +195,7 @@ class AtlasInferenceApp:
             with st.chat_message("user", avatar=USER_PFP):
                 st.markdown(prompt)
                 if uploaded_image:
-                    st.image(uploaded_image, caption="Uploaded Image", use_container_width=True)  # Updated parameter
+                    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
 
             with st.chat_message("assistant", avatar=AI_PFP):
                 with st.spinner("Generating response..."):
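(Both `st.image` calls in this commit move from `use_container_width` back to `use_column_width`; the latter is the older Streamlit argument, so this presumably restores compatibility with the Streamlit version pinned for this Space.)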
@@ -233,4 +212,4 @@ def run():
         st.error(f"⚠️ Application Error: {str(e)}")
 
 if __name__ == "__main__":
-    run()
+    run()
 