ruslanmv committed on
Commit 58bcb23 · verified · 1 Parent(s): f2b4cb5

Update app.py

Files changed (1)
  1. app.py +82 -53
app.py CHANGED
@@ -1,12 +1,15 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
-from transformers import AutoTokenizer # Import the tokenizer
+from transformers import AutoTokenizer
 
-# Import the tokenizer - No need to import twice, remove the second import
+# Load tokenizer and inference client
 tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-# Define a maximum context length (tokens). Check your model's documentation!
-MAX_CONTEXT_LENGTH = 4096 # Example: Adjust this based on your model!
+
+# Define a maximum context length (tokens); adjust based on your model
+MAX_CONTEXT_LENGTH = 4096
+
+# Default system prompt
 default_nvc_prompt_template = r"""<|system|>You are Roos, an NVC (Nonviolent Communication) Chatbot. Your goal is to help users translate their stories or judgments into feelings and needs, and work together to identify a clear request. Follow these steps:
 1. **Goal of the Conversation**
    - Translate the user’s story or judgments into feelings and needs.
@@ -72,87 +75,113 @@ default_nvc_prompt_template = r"""<|system|>You are Roos, an NVC (Nonviolent Com
    - “I sense some frustration. Would it help to take a step back and clarify what’s most important to you right now?”
 13. **Ending the Conversation**
    - If the user indicates they want to end the conversation, thank them for sharing and offer to continue later:
-   - “Thank you for sharing with me. If you’d like to continue this conversation later, I’m here to help.”</s>"""def count_tokens(text: str) -> int:
-    """Counts the number of tokens in a given string."""
-    return len(tokenizer.encode(text))def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
-    """Truncates the conversation history to fit within the maximum token limit.
+   - “Thank you for sharing with me. If you’d like to continue this conversation later, I’m here to help.”"""
+
+def count_tokens(text: str) -> int:
+    """Counts the number of tokens in a given string by encoding with the tokenizer."""
+    return len(tokenizer.encode(text))
+
+def truncate_history(history: list[tuple[str, str]], system_message: str, max_length: int) -> list[tuple[str, str]]:
+    """
+    Truncates the conversation history to fit within the maximum token limit.
 
     Args:
-        history: The conversation history (list of user/assistant tuples).
+        history: The conversation history (list of (user_msg, assistant_msg) tuples).
         system_message: The system message.
         max_length: The maximum number of tokens allowed.
 
     Returns:
-        The truncated history.
+        The truncated history as a list of (user_msg, assistant_msg) tuples.
     """
     truncated_history = []
    system_message_tokens = count_tokens(system_message)
    current_length = system_message_tokens
-    # Iterate backwards through the history (newest to oldest)
+
+    # Iterate backwards (from the newest to the oldest)
    for user_msg, assistant_msg in reversed(history):
        user_tokens = count_tokens(user_msg) if user_msg else 0
        assistant_tokens = count_tokens(assistant_msg) if assistant_msg else 0
        turn_tokens = user_tokens + assistant_tokens
+
        if current_length + turn_tokens <= max_length:
-            truncated_history.insert(0, (user_msg, assistant_msg)) # Add to the beginning
+            truncated_history.insert(0, (user_msg, assistant_msg))
            current_length += turn_tokens
        else:
-            break # Stop adding turns if we exceed the limit
-    return truncated_historydef respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message, # System message is now an argument
-    max_tokens,
-    temperature,
-    top_p,
-):
-    """Responds to a user message, maintaining conversation history, using special tokens and message list."""
-    if message.lower() == "clear memory": # Check for the clear memory command
-        return "", [] # Return empty message and empty history to reset the chat
-
-    formatted_system_message = system_message # Use the system_message argument
-    truncated_history = truncate_history(history, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100) # Reserve space for the new message and some generation
-
-    messages = [{"role": "system", "content": formatted_system_message}] # Start with system message as before
+            break
+
+    return truncated_history
+
+def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
+    """
+    Responds to a user message, maintaining conversation history.
+    Uses standard role-based messaging rather than explicit <|user|> tokens.
+    """
+
+    # Clear memory command
+    if message.lower() == "clear memory":
+        return "", []
+
+    # Truncate the history to fit within max token limit
+    truncated_history = truncate_history(
+        history,
+        system_message,
+        MAX_CONTEXT_LENGTH - max_tokens - 100  # Reserve space for the new message
+    )
+
+    # Prepare the messages list in a standard chat format
+    messages = [{"role": "system", "content": system_message}]
+
    for user_msg, assistant_msg in truncated_history:
        if user_msg:
-            messages.append({"role": "user", "content": user_msg}) # Format history user message - Removed extra tags
+            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
-            messages.append({"role": "assistant", "content": assistant_msg}) # Format history assistant message - Removed extra tags
-    messages.append({"role": "user", "content": message}) # Format current user message - Removed extra tags
+            messages.append({"role": "assistant", "content": assistant_msg})
+
+    # Add the new user message
+    messages.append({"role": "user", "content": message})
 
    response = ""
    try:
-        for chunk in client.chat_completion(
-            messages, # Send the messages list again, but with formatted content
-            max_tokens=max_tokens,
-            stream=True,
-            temperature=temperature,
-            top_p=top_p,
-        ):
-            token = chunk.choices[0].delta.content
-            response += token
-            # Post-processing to remove prefixes (example - add to your existing yield) - Solution 3 (Fallback)
-            processed_response = response.replace("User:", "").replace("Assistant:", "").replace("Roos:", "").lstrip()
-            yield processed_response
+        # Stream the response
+        for chunk in client.chat_completion(
+            messages,
+            max_tokens=max_tokens,
+            stream=True,
+            temperature=temperature,
+            top_p=top_p
+        ):
+            token = chunk.choices[0].delta.content
+            response += token
+            yield response
 
    except Exception as e:
-        print(f"An error occurred: {e}") # It's a good practice add a try-except block
-        yield "I'm sorry, I encountered an error. Please try again."
-
+        print(f"An error occurred: {e}")
+        yield "I'm sorry, I encountered an error. Please try again."
 
-# --- Gradio Interface ---
+# Build a Gradio chat interface
 demo = gr.ChatInterface(
-    respond,
+    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value=default_nvc_prompt_template,
            label="System message",
            visible=True,
-            lines=10, # Increased height for more space to read the prompt
+            lines=10,
+        ),
+        gr.Slider(
+            minimum=1,
+            maximum=2048,
+            value=512,
+            step=1,
+            label="Max new tokens",
+        ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=4.0,
+            value=0.7,
+            step=0.1,
+            label="Temperature",
        ),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
@@ -164,4 +193,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True)
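
Not part of the commit: the sketch below restates the history-truncation strategy that the new truncate_history implements (walk the history newest to oldest, keep whole turns while they fit a token budget, then restore chronological order). A whitespace word count stands in for the real tokenizer, and fake_count_tokens and truncate_history_sketch are hypothetical names, so the example runs without downloading the zephyr-7b-beta tokenizer.

# Standalone sketch, not from the commit; fake_count_tokens is a hypothetical
# stand-in for the tokenizer-based count_tokens used in app.py.
def fake_count_tokens(text: str) -> int:
    return len(text.split())  # crude word count, not real token counts

def truncate_history_sketch(history, system_message, max_length):
    truncated, current_length = [], fake_count_tokens(system_message)
    # Walk newest-to-oldest and keep whole turns while they fit the budget.
    for user_msg, assistant_msg in reversed(history):
        # `or ""` guards against None entries in a turn
        turn_tokens = fake_count_tokens(user_msg or "") + fake_count_tokens(assistant_msg or "")
        if current_length + turn_tokens > max_length:
            break  # stop at the first turn that no longer fits
        truncated.insert(0, (user_msg, assistant_msg))  # re-insert in chronological order
        current_length += turn_tokens
    return truncated

history = [
    ("hello there", "hi, how can I help?"),
    ("I feel ignored at work", "that sounds hard; what do you need right now?"),
]
print(truncate_history_sketch(history, "You are Roos.", max_length=20))
# -> [('I feel ignored at work', 'that sounds hard; what do you need right now?')]
#    only the most recent turn fits the 20-"token" budget

In the committed respond function the budget passed to truncate_history is MAX_CONTEXT_LENGTH - max_tokens - 100, which reserves room for the model's generated reply plus a small margin for the new user message.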