import os

import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer  # Import the tokenizer
from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage

# Model used both for local token counting (context budgeting) and for inference,
# kept in one place so the two can never drift apart.
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"

# Load HF token from environment variables.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set")

# Use the tokenizer that matches the model so local token counts approximate
# what the inference endpoint will see.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Instantiate the client with the new inference mechanism.
client = InferenceClient(
    provider="hf-inference",
    api_key=HF_TOKEN,
)

# Maximum context length (tokens). Adjust this based on your model's requirements.
MAX_CONTEXT_LENGTH = 4096

# Extra headroom subtracted from the prompt budget in addition to max_tokens.
# count_tokens() only counts message *content*, so this presumably absorbs
# chat-template overhead (role markers etc.) — tune if requests overflow.
RESERVED_TOKENS = 100

# Read the default system prompt from a file.
with open("prompt.txt", "r", encoding="utf-8") as file:
    nvc_prompt_template = file.read()

# Initialize LangChain conversation memory.
# NOTE(review): this is a single module-level object, so every Gradio session
# shares one conversation, and clearing the chat in the UI does not reset it
# (respond() ignores Gradio's `history`). Consider per-session state if the
# app serves more than one user.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)


def count_tokens(text: str) -> int:
    """Return the number of tokens `tokenizer` produces for `text`."""
    return len(tokenizer.encode(text))


def truncate_memory(memory, system_message: str, max_length: int) -> list[dict[str, str]]:
    """
    Select the most recent memory messages that fit in a token budget.

    Messages are walked newest-to-oldest and kept while the running total
    (starting from the system message's token count) stays within
    `max_length`. Selection stops at the first message that does not fit, so
    the result is always a contiguous suffix of the conversation.

    Args:
        memory: The LangChain conversation memory object.
        system_message: The system message; its tokens count against the budget.
        max_length: The maximum number of tokens allowed.

    Returns:
        A list of {"role", "content"} dicts in chronological order. Role is
        "user" for HumanMessage and "assistant" for everything else.
    """
    kept: list[dict[str, str]] = []
    current_length = count_tokens(system_message)

    for msg in reversed(memory.chat_memory.messages):
        tokens = count_tokens(msg.content)
        if current_length + tokens > max_length:
            break
        role = "user" if isinstance(msg, HumanMessage) else "assistant"
        kept.append({"role": role, "content": msg.content})
        current_length += tokens

    # Collected newest-first; restore chronological order in one O(n) pass
    # instead of repeated insert(0, ...).
    kept.reverse()
    return kept


def respond(
    message,
    history: list[tuple[str, str]],  # Required by Gradio; conversation state lives in LangChain memory.
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Stream a reply to `message`, using LangChain memory as conversation history.

    The prompt is the system message, then as much recent history as fits in
    the context budget, then the new user message. The model's reply is
    streamed back as progressively longer strings.

    Both the user message and the reply are recorded in memory only after the
    stream completes successfully. A failed request therefore never leaves an
    orphaned or empty turn behind.

    Args:
        message: The new user message.
        history: Gradio's chat history (unused; see `memory`).
        system_message: System prompt from the UI textbox. Falls back to the
            prompt loaded from prompt.txt when empty.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated response text after each streamed token, or a single
        apology message if the request fails.
    """
    # Honor the (editable) system-message textbox; fall back to the file prompt.
    formatted_system_message = system_message or nvc_prompt_template

    # Budget for history: leave room for generation, template overhead and the
    # new user message itself (which is not in memory yet).
    history_budget = (
        MAX_CONTEXT_LENGTH - max_tokens - RESERVED_TOKENS - count_tokens(message)
    )
    truncated_history = truncate_memory(memory, formatted_system_message, history_budget)

    # Roles are conveyed via the "role" field; the server applies the model's
    # chat template, so no hand-written role tokens are embedded in content.
    messages = [
        {"role": "system", "content": formatted_system_message},
        *truncated_history,
        {"role": "user", "content": message},
    ]

    response = ""
    try:
        stream = client.chat.completions.create(
            model=MODEL_ID,
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        for chunk in stream:
            # Some chunks carry no choices, or a delta whose content is None
            # (e.g. the final chunk with finish_reason); skip them instead of
            # failing on `str += None`.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta.content
            if token:
                response += token
                yield response
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."
    else:
        # Record the completed exchange only on success so memory keeps
        # strictly alternating, well-formed user/assistant turns.
        memory.chat_memory.add_message(HumanMessage(content=message))
        memory.chat_memory.add_message(AIMessage(content=response))


# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()