Update app.py
app.py
CHANGED
@@ -1,34 +1,50 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
-from transformers import AutoTokenizer
-from langchain.memory import
-from langchain.schema import HumanMessage, AIMessage
+from transformers import AutoTokenizer  # Import the tokenizer
+from langchain.memory import ConversationBufferMemory
+from langchain.schema import HumanMessage, AIMessage

-#
+# Use the appropriate tokenizer for your model.
 tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

-
+# Define a maximum context length (tokens). Check your model's documentation!
+MAX_CONTEXT_LENGTH = 4096  # Example: Adjust this based on your model!

-#
+# Read the default prompt from a file
 with open("prompt.txt", "r") as file:
     nvc_prompt_template = file.read()

-# Initialize LangChain Memory
-memory =
+# Initialize LangChain Conversation Memory
+memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

 def count_tokens(text: str) -> int:
+    """Counts the number of tokens in a given string."""
     return len(tokenizer.encode(text))

-def truncate_history(messages, max_length):
+def truncate_memory(memory, system_message: str, max_length: int):
+    """
+    Truncates the conversation memory messages to fit within the maximum token limit.
+
+    Args:
+        memory: The LangChain conversation memory object.
+        system_message: The system message.
+        max_length: The maximum number of tokens allowed.
+
+    Returns:
+        A list of messages (as dicts with role and content) that fit within the token limit.
+    """
     truncated_messages = []
-
-
-
-
-
-
-
+    system_tokens = count_tokens(system_message)
+    current_length = system_tokens
+
+    # Iterate backwards through the memory (newest to oldest)
+    for msg in reversed(memory.chat_memory.messages):
+        tokens = count_tokens(msg.content)
+        if current_length + tokens <= max_length:
+            role = "user" if isinstance(msg, HumanMessage) else "assistant"
+            truncated_messages.insert(0, {"role": role, "content": msg.content})
+            current_length += tokens
         else:
             break

@@ -36,38 +52,39 @@ def truncate_history(messages, max_length):

 def respond(
     message,
-    history,
+    history: list[tuple[str, str]],  # Required by Gradio but we now use LangChain memory
     system_message,
     max_tokens,
     temperature,
     top_p,
 ):
+    """
+    Responds to a user message while maintaining conversation history via LangChain memory.
+    It builds the prompt with a system message and the (truncated) conversation history,
+    streams the response from the client, and finally updates the memory with the new response.
+    """
+    # Use your prompt template as the system message.
     formatted_system_message = nvc_prompt_template

-    #
-
-
-
-    # Truncate history to ensure it fits within context window
-    max_history_tokens = MAX_CONTEXT_LENGTH - max_tokens - count_tokens(formatted_system_message) - 100
-    truncated_chat_history = truncate_history(chat_history, max_history_tokens)
+    # Prepare and add the new user message (with your special tokens) to memory.
+    new_user_message = f"<|user|>\n{message}</s>"
+    memory.chat_memory.add_message(HumanMessage(content=new_user_message))

-    #
-
-
-
+    # Truncate memory to ensure the context fits within the maximum token length (reserve space for generation).
+    truncated_history = truncate_memory(
+        memory, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100
+    )
+    # Ensure the current user message is present at the end.
+    if not truncated_history or truncated_history[-1]["content"] != new_user_message:
+        truncated_history.append({"role": "user", "content": new_user_message})

-    #
-
-    for msg in messages:
-        role = "system" if isinstance(msg, SystemMessage) else "user" if isinstance(msg, HumanMessage) else "assistant"
-        content = f"<|{role}|>\n{msg.content}</s>"
-        formatted_messages.append({"role": role, "content": content})
+    # Build the full message list: system prompt + conversation history.
+    messages = [{"role": "system", "content": formatted_system_message}] + truncated_history

     response = ""
     try:
         for chunk in client.chat_completion(
-
+            messages,
             max_tokens=max_tokens,
             stream=True,
             temperature=temperature,
@@ -76,14 +93,13 @@ def respond(
             token = chunk.choices[0].delta.content
             response += token
             yield response
-
-        # Save AI's response in LangChain memory
-        memory.chat_memory.add_ai_message(response)
-
     except Exception as e:
         print(f"An error occurred: {e}")
         yield "I'm sorry, I encountered an error. Please try again."

+    # Once the full response is generated, add it to the LangChain memory.
+    memory.chat_memory.add_message(AIMessage(content=f"<|assistant|>\n{response}</s>"))
+
 # --- Gradio Interface ---
 demo = gr.ChatInterface(
     respond,
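Below is a minimal, self-contained sketch (not part of the commit) of the memory and token-budget pattern the new code relies on: messages are stored in a ConversationBufferMemory, and the history is walked newest-to-oldest, keeping only what fits inside MAX_CONTEXT_LENGTH - max_tokens - 100, the same budget truncate_memory() receives above. The concrete values here (max_tokens=512, the sample system prompt, and the seeded messages) are illustrative assumptions, not values taken from app.py.

from transformers import AutoTokenizer
from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage

# Same tokenizer and memory setup as in the diff above.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

def count_tokens(text: str) -> int:
    """Count tokens with the model's own tokenizer, as app.py does."""
    return len(tokenizer.encode(text))

# Seed the buffer the way respond() does, using the model's chat markers.
memory.chat_memory.add_message(HumanMessage(content="<|user|>\nHello!</s>"))
memory.chat_memory.add_message(AIMessage(content="<|assistant|>\nHi! How can I help?</s>"))

# Illustrative numbers: reserve room for generation plus a 100-token margin,
# mirroring MAX_CONTEXT_LENGTH - max_tokens - 100 in the diff.
MAX_CONTEXT_LENGTH = 4096
max_tokens = 512
system_prompt = "You are a helpful assistant."  # stands in for prompt.txt
budget = MAX_CONTEXT_LENGTH - max_tokens - 100

# Walk the history newest-to-oldest and keep only what fits, like truncate_memory().
kept, used = [], count_tokens(system_prompt)
for msg in reversed(memory.chat_memory.messages):
    tokens = count_tokens(msg.content)
    if used + tokens > budget:
        break
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    kept.insert(0, {"role": role, "content": msg.content})
    used += tokens

print(kept)  # oldest-first list of {"role", "content"} dicts within the budget

Because the walk breaks at the first message that no longer fits, the oldest turns are dropped first while the most recent exchange is preserved.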