Daemontatox committed
Commit 83f478c · verified · 1 Parent(s): ee31cc9

Update app.py

Files changed (1)
app.py: +40 -47
app.py CHANGED
@@ -1,5 +1,4 @@
 import torch
-import spaces
 import gradio as gr
 from threading import Thread
 from transformers import (
@@ -20,13 +19,8 @@ You speak with a playful and patient tone, using simple, child-friendly language
 Your responses are short, sweet, and filled with kindness, designed to nurture curiosity and inspire learning.
 Remember, you’re here to make every interaction magical—without using emojis.
 Keep your answers short and friendly.
-
-
-
-
 """
 
-
 CSS = """
 .gr-chatbot { min-height: 500px; border-radius: 15px; }
 .special-tag { color: #2ecc71; font-weight: 600; }
@@ -38,6 +32,7 @@ class StopOnTokens(StoppingCriteria):
         return input_ids[0][-1] == tokenizer.eos_token_id
 
 def initialize_model():
+    # (Optional) Enable 4-bit quantization by uncommenting the quantization_config if desired.
     quantization_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_compute_dtype=torch.bfloat16,
@@ -51,6 +46,7 @@ def initialize_model():
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         device_map="cuda",
+        # If you want to enable 4-bit quantization, uncomment the following line:
         # quantization_config=quantization_config,
         torch_dtype=torch.bfloat16,
         trust_remote_code=True
@@ -59,30 +55,29 @@ def initialize_model():
     return model, tokenizer
 
 def format_response(text):
-    return text.replace("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n') \
-               .replace("[Plan]", '\n<strong class="special-tag">[Plan]</strong>\n') \
-               .replace("[Conclude]", '\n<strong class="special-tag">[Conclude]</strong>\n') \
-               .replace("[Reason]", '\n<strong class="special-tag">[Reason]</strong>\n') \
-               .replace("[Verify]", '\n<strong class="special-tag">[Verify]</strong>\n')
-@spaces.GPU()
-def generate_response(message, chat_history, system_prompt, temperature, max_tokens):
-    # Create conversation history for model
-    conversation = [{"role": "system", "content": system_prompt}]
-    for user_msg, bot_msg in chat_history:
-        conversation.extend([
-            {"role": "user", "content": user_msg},
-            {"role": "assistant", "content": bot_msg}
-        ])
-    conversation.append({"role": "user", "content": message})
-
-    # Tokenize input
+    # Apply formatting to special tokens if needed
+    return (text.replace("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n')
+                .replace("[Plan]", '\n<strong class="special-tag">[Plan]</strong>\n')
+                .replace("[Conclude]", '\n<strong class="special-tag">[Conclude]</strong>\n')
+                .replace("[Reason]", '\n<strong class="special-tag">[Reason]</strong>\n')
+                .replace("[Verify]", '\n<strong class="special-tag">[Verify]</strong>\n'))
+
+# Gradio streams generator functions natively, so no decorator is needed here.
+def generate_response(message, system_prompt, temperature, max_tokens):
+    # Create a minimal conversation with only the system prompt and the user's message.
+    conversation = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": message}
+    ]
+
+    # Tokenize input using the chat template provided by the tokenizer
     input_ids = tokenizer.apply_chat_template(
         conversation,
         add_generation_prompt=True,
         return_tensors="pt"
     ).to(model.device)
 
-    # Setup streaming
+    # Set up the streamer to yield tokens as they are generated.
     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
@@ -92,49 +87,47 @@ def generate_response(message, chat_history, system_prompt, temperature, max_tokens):
         stopping_criteria=StoppingCriteriaList([StopOnTokens()])
     )
 
-    # Start generation thread
+    # Start generation in a separate thread.
     Thread(target=model.generate, kwargs=generate_kwargs).start()
 
-    # Initialize response buffer
-    partial_message = ""
-    new_history = chat_history + [(message, "")]
-
-    # Stream response
+    answer = ""
+    # Stream and yield intermediate results with a cursor symbol.
     for new_token in streamer:
-        partial_message += new_token
-        formatted = format_response(partial_message)
-        new_history[-1] = (message, formatted + "▌")
-        yield new_history
-
-    # Final update without cursor
-    new_history[-1] = (message, format_response(partial_message))
-    yield new_history
+        answer += new_token
+        yield format_response(answer) + "▌"
+    # Yield the final answer without the cursor.
+    yield format_response(answer)
 
+# Initialize the model and tokenizer
 model, tokenizer = initialize_model()
 
 with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
     gr.Markdown("""
-    <h1 align="center">🧸🧸🧸 Immy Ai Teddy</h1>
-    <p align="center">hi there buddy</p>
+    <h1 align="center">🧸 Immy AI Teddy</h1>
+    <p align="center">Hi there, buddy!</p>
     """)
 
-    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
+    # User input: the question for Immy.
     msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
 
     with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.6, label="Creativity")
        max_tokens = gr.Slider(128, 1024, value=1024, label="Max Response Length")
-
-    clear = gr.Button("Clear History")
 
+    # Output: only the model's answer is displayed.
+    answer_output = gr.Markdown(label="Answer")
+
+    clear = gr.Button("Clear")
+
+    # When the user submits a question, only the model's answer is generated.
     msg.submit(
         generate_response,
-        [msg, chatbot, system_prompt, temperature, max_tokens],
-        [chatbot],
+        inputs=[msg, system_prompt, temperature, max_tokens],
+        outputs=answer_output,
         show_progress=True
     )
-    clear.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch()
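For reference, the rewritten `generate_response` relies on the standard `transformers` streaming pattern: `model.generate` runs on a background thread while a `TextIteratorStreamer` hands decoded tokens back to a plain Python generator, which Gradio can bind directly to an event handler. A minimal, self-contained sketch of that pattern, using a small instruct model as a placeholder (the Space's actual `MODEL_ID` is defined outside the hunks shown here):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model for illustration only; app.py loads its own MODEL_ID.
MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

def stream_reply(message: str):
    # Build a single-turn conversation and render it with the chat template.
    conversation = [{"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so it runs on its own thread; the streamer is consumed here.
    Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=64),
    ).start()

    answer = ""
    for token in streamer:
        answer += token
        yield answer  # each yield re-renders the bound output component

if __name__ == "__main__":
    for partial in stream_reply("Why is the sky blue?"):
        print(partial)
```

Because `stream_reply` is an ordinary generator, `msg.submit(...)` can bind a function like it directly, and the queued Gradio app updates the output on every yield, which is exactly what the `"▌"` cursor trick in the diff exploits.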
 
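The diff keeps the `BitsAndBytesConfig` but leaves it commented out at load time, so the model loads in full bfloat16. A sketch of the enabled 4-bit path, assuming a CUDA device with `bitsandbytes` installed; only `load_in_4bit` and `bnb_4bit_compute_dtype` appear in the hunks above, so the remaining fields here are common choices, not necessarily what app.py sets:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

MODEL_ID = "your-org/your-model"  # placeholder; app.py defines its own MODEL_ID

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4 bits
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bfloat16
    bnb_4bit_quant_type="nf4",              # assumed: NF4 weight format
    bnb_4bit_use_double_quant=True,         # assumed: also quantize the quant constants
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",
    quantization_config=quantization_config,  # the line the diff leaves commented out
    trust_remote_code=True,
)
```

Cutting weight storage to 4 bits roughly quarters the GPU memory needed for the weights at a modest quality cost, which is why the option is kept around even though this commit does not enable it.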