ruslanmv committed on
Commit 0047abe · verified · 1 Parent(s): 27ad4cf

Update app.py

Files changed (1)
  1. app.py +55 -39
app.py CHANGED
@@ -1,34 +1,50 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
-from transformers import AutoTokenizer
-from langchain.memory import ConversationBufferWindowMemory
-from langchain.schema import HumanMessage, AIMessage, SystemMessage
 
-# Initialize tokenizer and inference client
+from transformers import AutoTokenizer  # Import the tokenizer
+from langchain.memory import ConversationBufferMemory
+from langchain.schema import HumanMessage, AIMessage
+
+# Use the appropriate tokenizer for your model.
 tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
-MAX_CONTEXT_LENGTH = 4096
+# Define a maximum context length (tokens). Check your model's documentation!
+MAX_CONTEXT_LENGTH = 4096  # Example: Adjust this based on your model!
 
-# Load prompt from file
+# Read the default prompt from a file
 with open("prompt.txt", "r") as file:
     nvc_prompt_template = file.read()
 
-# Initialize LangChain Memory (buffer window to keep recent conversation)
-memory = ConversationBufferWindowMemory(k=10, return_messages=True)
+# Initialize LangChain Conversation Memory
+memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
 def count_tokens(text: str) -> int:
+    """Counts the number of tokens in a given string."""
     return len(tokenizer.encode(text))
 
-def truncate_history(messages, max_length):
+def truncate_memory(memory, system_message: str, max_length: int):
+    """
+    Truncates the conversation memory messages to fit within the maximum token limit.
+
+    Args:
+        memory: The LangChain conversation memory object.
+        system_message: The system message.
+        max_length: The maximum number of tokens allowed.
+
+    Returns:
+        A list of messages (as dicts with role and content) that fit within the token limit.
+    """
     truncated_messages = []
-    total_tokens = 0
-
-    for message in reversed(messages):
-        message_tokens = count_tokens(message.content)
-        if total_tokens + message_tokens <= max_length:
-            truncated_messages.insert(0, message)
-            total_tokens += message_tokens
+    system_tokens = count_tokens(system_message)
+    current_length = system_tokens
+
+    # Iterate backwards through the memory (newest to oldest)
+    for msg in reversed(memory.chat_memory.messages):
+        tokens = count_tokens(msg.content)
+        if current_length + tokens <= max_length:
+            role = "user" if isinstance(msg, HumanMessage) else "assistant"
+            truncated_messages.insert(0, {"role": role, "content": msg.content})
+            current_length += tokens
         else:
             break
 
@@ -36,38 +52,39 @@ def truncate_history(messages, max_length):
 
 def respond(
     message,
-    history,
+    history: list[tuple[str, str]],  # Required by Gradio but we now use LangChain memory
     system_message,
     max_tokens,
     temperature,
     top_p,
 ):
+    """
+    Responds to a user message while maintaining conversation history via LangChain memory.
+    It builds the prompt with a system message and the (truncated) conversation history,
+    streams the response from the client, and finally updates the memory with the new response.
+    """
+    # Use your prompt template as the system message.
     formatted_system_message = nvc_prompt_template
 
-    # Retrieve conversation history from LangChain memory
-    memory.save_context({"input": message}, {"output": ""})
-    chat_history = memory.load_memory_variables({})["history"]
-
-    # Truncate history to ensure it fits within context window
-    max_history_tokens = MAX_CONTEXT_LENGTH - max_tokens - count_tokens(formatted_system_message) - 100
-    truncated_chat_history = truncate_history(chat_history, max_history_tokens)
+    # Prepare and add the new user message (with your special tokens) to memory.
+    new_user_message = f"<|user|>\n{message}</s>"
+    memory.chat_memory.add_message(HumanMessage(content=new_user_message))
 
-    # Construct the messages for inference
-    messages = [SystemMessage(content=formatted_system_message)]
-    messages.extend(truncated_chat_history)
-    messages.append(HumanMessage(content=message))
+    # Truncate memory to ensure the context fits within the maximum token length (reserve space for generation).
+    truncated_history = truncate_memory(
+        memory, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100
+    )
+    # Ensure the current user message is present at the end.
+    if not truncated_history or truncated_history[-1]["content"] != new_user_message:
+        truncated_history.append({"role": "user", "content": new_user_message})
 
-    # Convert LangChain messages to the format required by HuggingFace client
-    formatted_messages = []
-    for msg in messages:
-        role = "system" if isinstance(msg, SystemMessage) else "user" if isinstance(msg, HumanMessage) else "assistant"
-        content = f"<|{role}|>\n{msg.content}</s>"
-        formatted_messages.append({"role": role, "content": content})
+    # Build the full message list: system prompt + conversation history.
+    messages = [{"role": "system", "content": formatted_system_message}] + truncated_history
 
     response = ""
     try:
         for chunk in client.chat_completion(
-            formatted_messages,
+            messages,
             max_tokens=max_tokens,
             stream=True,
             temperature=temperature,
@@ -76,14 +93,13 @@ def respond(
             token = chunk.choices[0].delta.content
             response += token
             yield response
-
-        # Save AI's response in LangChain memory
-        memory.chat_memory.add_ai_message(response)
-
     except Exception as e:
         print(f"An error occurred: {e}")
         yield "I'm sorry, I encountered an error. Please try again."
 
+    # Once the full response is generated, add it to the LangChain memory.
+    memory.chat_memory.add_message(AIMessage(content=f"<|assistant|>\n{response}</s>"))
+
 # --- Gradio Interface ---
 demo = gr.ChatInterface(
     respond,
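
For reference, below is a self-contained sketch of the token-budget truncation that the new truncate_memory() introduces: the stored messages are walked newest to oldest and only the most recent ones that fit the budget, together with the system prompt, are kept. A whitespace word count stands in for the Zephyr tokenizer, and the sample messages and the 12-token budget are illustrative assumptions, not values used by the Space.

# Sketch only (not part of the commit). A whitespace word count replaces the
# real tokenizer so the example runs without downloading the model.
from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage


def count_tokens(text: str) -> int:
    # Stand-in for len(tokenizer.encode(text)) used in app.py.
    return len(text.split())


def truncate_memory(memory, system_message: str, max_length: int):
    # Keep the most recent messages whose tokens, plus the system prompt, fit the budget.
    truncated, current_length = [], count_tokens(system_message)
    for msg in reversed(memory.chat_memory.messages):
        tokens = count_tokens(msg.content)
        if current_length + tokens > max_length:
            break
        role = "user" if isinstance(msg, HumanMessage) else "assistant"
        truncated.insert(0, {"role": role, "content": msg.content})
        current_length += tokens
    return truncated


memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
memory.chat_memory.add_message(HumanMessage(content="<|user|>\nHello</s>"))
memory.chat_memory.add_message(AIMessage(content="<|assistant|>\nHi, how can I help you today?</s>"))
memory.chat_memory.add_message(HumanMessage(content="<|user|>\nTell me about Nonviolent Communication.</s>"))

# With a tight budget only the most recent message survives; older turns are dropped first.
for m in truncate_memory(memory, "You are a compassionate listener.", max_length=12):
    print(m["role"], "->", m["content"])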