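"""Gradio chat app backed by the Hugging Face Inference API.

Streams responses from DeepSeek-R1-Distill-Qwen-32B, uses the system prompt read from
prompt.txt, and keeps the conversation history in LangChain memory, truncating older
turns so the full prompt stays within the model's context window.
"""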
import os
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer # Import the tokenizer
from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage

# Load HF token from environment variables.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set")

# Use the appropriate tokenizer for your model.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")

# Instantiate the Inference API client (hf-inference provider).
client = InferenceClient(
    provider="hf-inference",
    api_key=HF_TOKEN,
)

# Define a maximum context length (tokens). Adjust this based on your model's requirements.
MAX_CONTEXT_LENGTH = 4096

# Read the default prompt from a file.
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

# Initialize LangChain Conversation Memory.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)


def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))


def truncate_memory(memory, system_message: str, max_length: int):
    """
    Truncates the conversation memory messages to fit within the maximum token limit.

    Args:
        memory: The LangChain conversation memory object.
        system_message: The system message.
        max_length: The maximum number of tokens allowed.

    Returns:
        A list of messages (as dicts with role and content) that fit within the token limit.
    """
    truncated_messages = []
    system_tokens = count_tokens(system_message)
    current_length = system_tokens
    # Iterate backwards through the memory (newest to oldest).
    for msg in reversed(memory.chat_memory.messages):
        tokens = count_tokens(msg.content)
        if current_length + tokens <= max_length:
            role = "user" if isinstance(msg, HumanMessage) else "assistant"
            truncated_messages.insert(0, {"role": role, "content": msg.content})
            current_length += tokens
        else:
            break
    return truncated_messages
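
# Illustration (hypothetical numbers): with MAX_CONTEXT_LENGTH = 4096, max_tokens = 512
# and a 100-token safety margin, respond() below calls
#   truncate_memory(memory, system_prompt, 4096 - 512 - 100)
# so only the newest messages that still fit alongside the system prompt are kept.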


def respond(
    message,
    history: list[tuple[str, str]],  # Required by Gradio but we now use LangChain memory.
    system_message,
    max_tokens,
    temperature,
    top_p,
):
"""
Responds to a user message while maintaining conversation history via LangChain memory.
It builds the prompt with a system message and the (truncated) conversation history,
streams the response from the client, and finally updates the memory with the new response.
"""
# Use your prompt template as the system message.
formatted_system_message = nvc_prompt_template
# Prepare and add the new user message (with your special tokens) to memory.
new_user_message = f"<|user|>\n{message}</s>"
memory.chat_memory.add_message(HumanMessage(content=new_user_message))
# Truncate memory to ensure the context fits within the maximum token length (reserve space for generation).
truncated_history = truncate_memory(
memory, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100
)
# Ensure the current user message is present at the end.
if not truncated_history or truncated_history[-1]["content"] != new_user_message:
truncated_history.append({"role": "user", "content": new_user_message})
# Build the full message list: system prompt + conversation history.
messages = [{"role": "system", "content": formatted_system_message}] + truncated_history
response = ""
try:
stream = client.chat.completions.create(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
messages=messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
)
for chunk in stream:
token = chunk.choices[0].delta.content
response += token
yield response
except Exception as e:
print(f"An error occurred: {e}")
yield "I'm sorry, I encountered an error. Please try again."
# Once the full response is generated, add it to the LangChain memory.
memory.chat_memory.add_message(AIMessage(content=f"<|assistant|>\n{response}</s>"))


# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
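
# Launch the Gradio app when this script is executed directly (e.g. `python app.py`,
# assuming this file is the Space's entry point).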
if __name__ == "__main__":
    demo.launch()