Update app.py

app.py CHANGED
@@ -18,16 +18,15 @@ MODELS = {
         },
         "emoji": "🦁",
         "experimental": True,
-        "is_vision": False,
-        "system_prompt_env": "ATLAS_FLASH_1215",
+        "is_vision": False,
+        "system_prompt_env": "ATLAS_FLASH_1215",
     },
 }
 
 # Profile pictures
-USER_PFP = "user.png"
-AI_PFP = "ai_pfp.png"
+USER_PFP = "user.png"
+AI_PFP = "ai_pfp.png"
 
-# Set page config (must be called only once and before any other Streamlit commands)
 st.set_page_config(
     page_title="Atlas Model Inference",
     page_icon="🦁 ",
@@ -39,15 +38,12 @@ st.set_page_config(
     }
 )
 
-# Custom CSS for blue sliders and button
 st.markdown(
     """
     <style>
-    /* Blue slider */
     .stSlider > div > div > div > div {
         background-color: #1f78b4 !important;
     }
-    /* Blue button */
     .stButton > button {
         background-color: #1f78b4 !important;
         color: white !important;
@@ -69,7 +65,6 @@ class AtlasInferenceApp:
             st.session_state.chat_history = []
 
     def clear_memory(self):
-        """Optimize memory management for CPU inference"""
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
@@ -85,24 +80,22 @@ class AtlasInferenceApp:
 
         model_path = MODELS[model_key]["sizes"][model_size]
 
-        # Load Qwen-compatible tokenizer and model
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
         model = AutoModelForCausalLM.from_pretrained(
             model_path,
-            device_map="auto",
+            device_map="auto",
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
             trust_remote_code=True,
             low_cpu_mem_usage=True
         )
 
-        # Update session state
         st.session_state.current_model.update({
             "tokenizer": tokenizer,
             "model": model,
             "config": {
                 "name": f"{MODELS[model_key]['name']} {model_size}",
                 "path": model_path,
-                "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
+                "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
             }
         })
         return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
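
For reference, the load path above in isolation: a minimal standalone sketch, assuming the transformers and accelerate packages are installed (device_map="auto" needs the latter). The checkpoint path is a placeholder for illustration, not one of the Space's actual MODELS entries.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint for illustration

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",  # let accelerate place layers on GPU/CPU automatically
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,  # stream weights in to cut peak RAM during load
    )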
@@ -114,18 +107,10 @@ class AtlasInferenceApp:
             return "⚠️ Please select and load a model first"
 
         try:
-
-            if "config" not in st.session_state.current_model:
-                return "⚠️ Model configuration not found. Please load the model again."
-
-            system_prompt = st.session_state.current_model["config"].get("system_prompt", "Default system prompt")
+            system_prompt = st.session_state.current_model["config"]["system_prompt"]
             if not system_prompt:
-
-
-            # Debugging: Print the system prompt for verification
-            st.write(f"System Prompt: {system_prompt}")
+                return "⚠️ System prompt not found for the selected model."
 
-            # Add the system instruction to guide the model's behavior
             prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
 
             inputs = st.session_state.current_model["tokenizer"](
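
The rewritten lookup above assumes load_model already stored "system_prompt" in the session config. A minimal sketch of that contract, using the ATLAS_FLASH_1215 env-var name from the diff; because the loader supplies a fallback string, the `if not system_prompt` guard only fires when the variable is explicitly set to an empty value.

    import os

    # Loader side (mirrors the diff): resolve the prompt once, with a fallback.
    config = {"system_prompt": os.getenv("ATLAS_FLASH_1215", "Default system prompt")}

    # Respond side: direct indexing now, so a missing key raises KeyError
    # instead of silently substituting a default.
    system_prompt = config["system_prompt"]
    if not system_prompt:
        raise ValueError("System prompt not found")  # the app returns a warning string here

    message = "Hello, Atlas"  # placeholder user message
    prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"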
@@ -135,8 +120,6 @@ class AtlasInferenceApp:
                 truncation=True,
                 padding=True
             )
-
-            # Generate response without streaming
             with torch.no_grad():
                 output = st.session_state.current_model["model"].generate(
                     input_ids=inputs.input_ids,
@@ -151,7 +134,6 @@ class AtlasInferenceApp:
             )
             response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
 
-            # Remove the prompt from the response
             if prompt in response:
                 response = response.replace(prompt, "").strip()
 
@@ -195,19 +177,16 @@ class AtlasInferenceApp:
 
         st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
 
-        # Display chat history
        for message in st.session_state.chat_history:
             with st.chat_message(
                 message["role"],
                 avatar=USER_PFP if message["role"] == "user" else AI_PFP
             ):
                 st.markdown(message["content"])
-                if "image" in message:
-                    st.image(message["image"], caption="Uploaded Image",
+                if "image" in message and message["image"]:
+                    st.image(message["image"], caption="Uploaded Image", use_column_width=True)
 
-        # Input box for user messages
         if prompt := st.chat_input("Message Atlas..."):
-            # Allow image upload if the model supports vision
             uploaded_image = None
             if MODELS[model_key]["is_vision"]:
                 uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
@@ -216,7 +195,7 @@ class AtlasInferenceApp:
             with st.chat_message("user", avatar=USER_PFP):
                 st.markdown(prompt)
                 if uploaded_image:
-                    st.image(uploaded_image, caption="Uploaded Image",
+                    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
 
             with st.chat_message("assistant", avatar=AI_PFP):
                 with st.spinner("Generating response..."):
@@ -233,4 +212,4 @@ def run():
         st.error(f"⚠️ Application Error: {str(e)}")
 
 if __name__ == "__main__":
-    run()
+    run()
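
One behavior worth noting in the generation path: decoding output[0] returns the prompt followed by the completion, which is why respond() strips the prompt from the decoded text. A toy illustration with dummy strings:

    prompt = "SYSTEM\n\n### Instruction:\nHi\n\n### Response:"
    decoded = prompt + "\nHello! I'm Atlas."  # decode(output[0]) echoes the prompt first

    response = decoded.replace(prompt, "").strip() if prompt in decoded else decoded.strip()
    print(response)  # -> Hello! I'm Atlas.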