Update app.py
app.py CHANGED
@@ -6,9 +6,11 @@ from huggingface_hub import login
 import re
 import os
 
+# Load Hugging Face token
 HF_TOKEN = os.getenv("HF_TOKEN")
 login(token=HF_TOKEN)
 
+# Define models
 MODELS = {
     "athena-1": {
         "name": "🦁 Atlas-Flash",
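A note on the token handling this hunk documents: `os.getenv` returns `None` when `HF_TOKEN` is unset, and calling `login()` with a `None` token will not authenticate as intended. A guarded variant (a sketch, not part of this commit) makes that failure mode explicit:

import os
from huggingface_hub import login

# Sketch: only log in when the HF_TOKEN secret is actually present
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    raise RuntimeError("HF_TOKEN is not set; add it as a Space secret")

In a Space, `HF_TOKEN` would normally be configured as a repository secret so it is exposed to the app as an environment variable.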
@@ -20,9 +22,9 @@ MODELS = {
     },
 }
 
-
-USER_PFP = "user.png"
-AI_PFP = "ai_pfp.png"
+# Profile pictures
+USER_PFP = "user.png"  # Hugging Face user avatar
+AI_PFP = "ai_pfp.png"  # Replace with the path to your AI's image or a URL
 
 class AtlasInferenceApp:
     def __init__(self):
@@ -59,17 +61,17 @@ class AtlasInferenceApp:
 
             model_path = MODELS[model_key]["sizes"][model_size]
 
-
+            # Load Qwen-compatible tokenizer and model
             tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
-                device_map="
-                torch_dtype=torch.float32,
+                device_map="auto",  # Use GPU if available
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                 trust_remote_code=True,
                 low_cpu_mem_usage=True
             )
 
-
+            # Update session state
             st.session_state.current_model.update({
                 "tokenizer": tokenizer,
                 "model": model,
@@ -87,7 +89,7 @@ class AtlasInferenceApp:
             return "⚠️ Please select and load a model first"
 
         try:
-
+            # Add a system instruction to guide the model's behavior
            system_instruction = "You are Atlas, a helpful AI assistant trained to help the user. You are a Deepseek R1 fine-tune."
            prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"
 
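Concretely, the f-string in this hunk renders every request in an Alpaca-style layout, with <user message> standing in for the chat input:

You are Atlas, a helpful AI assistant trained to help the user. You are a Deepseek R1 fine-tune.

### Instruction:
<user message>

### Response: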
@@ -99,8 +101,8 @@ class AtlasInferenceApp:
                 padding=True
             )
 
-
-            response_container = st.empty()
+            # Generate response with streaming
+            response_container = st.empty()  # Placeholder for streaming text
             full_response = ""
             with torch.no_grad():
                 for chunk in st.session_state.current_model["model"].generate(
@@ -113,13 +115,27 @@ class AtlasInferenceApp:
                     do_sample=True,
                     pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                     eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
-                    streamer=None,  # Use a custom streamer for real-time updates
                 ):
-
-
-
-
-
+                    # Decode the chunk and update the response
+                    try:
+                        chunk_text = st.session_state.current_model["tokenizer"].decode(chunk, skip_special_tokens=True)
+
+                        # Remove the prompt from the response
+                        if prompt in chunk_text:
+                            chunk_text = chunk_text.replace(prompt, "").strip()
+
+                        full_response += chunk_text
+                        response_container.markdown(full_response)
+                    except Exception as decode_error:
+                        st.error(f"⚠️ Token Decoding Error: {str(decode_error)}")
+                        break
+
+                    # Stop if the response is too long or incomplete
+                    if len(full_response) >= max_tokens * 4:  # Approximate token-to-character ratio
+                        st.warning("⚠️ Response truncated due to length limit.")
+                        break
+
+                    return full_response.strip()  # Return the cleaned response
         except Exception as e:
             return f"⚠️ Generation Error: {str(e)}"
         finally:
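Two caveats about the new loop. Iterating over the result of `generate()` walks the returned batch of finished sequences, not incremental token chunks, so the UI still only updates after generation completes; and `return full_response.strip()` sits inside the `for` body, so the method exits after the first sequence. The removed `streamer=None  # Use a custom streamer for real-time updates` comment names the usual fix: run `generate()` in a background thread with a `TextIteratorStreamer` and consume decoded text as it arrives. A minimal sketch of that pattern (the `stream_reply` name and parameters are illustrative, not part of this app):

from threading import Thread

from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, prompt, max_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True drops the echoed prompt; decoded text arrives piecewise
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until done, so run it in a background thread
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=max_tokens, do_sample=True, streamer=streamer),
    )
    thread.start()
    text = ""
    for piece in streamer:
        text += piece
        yield text  # e.g. response_container.markdown(text) in the Streamlit loop
    thread.join()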
@@ -159,6 +175,7 @@ class AtlasInferenceApp:
 
         st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
 
+        # Display chat history
         for message in st.session_state.chat_history:
             with st.chat_message(
                 message["role"],
@@ -166,6 +183,7 @@ class AtlasInferenceApp:
             ):
                 st.markdown(message["content"])
 
+        # Input box for user messages
         if prompt := st.chat_input("Message Atlas..."):
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            with st.chat_message("user", avatar=USER_PFP):
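For reference, these last two hunks annotate the standard Streamlit chat skeleton: replay the `st.session_state` history on every rerun, then append and render new turns as `st.chat_input` produces them. A minimal self-contained sketch, with the avatar arguments omitted:

import streamlit as st

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Replay prior turns so the transcript survives Streamlit reruns
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# The walrus form runs the block only when the user submits a message
if prompt := st.chat_input("Message Atlas..."):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)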