import os
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer  # Tokenizer used to count tokens for context truncation.
from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage

# Load HF token from environment variables.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set")

# Use the appropriate tokenizer for your model.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")

# Instantiate the client with the new inference mechanism.
client = InferenceClient(
    provider="hf-inference",
    api_key=HF_TOKEN
)

# Define a maximum context length (tokens). Adjust this based on your model's requirements.
MAX_CONTEXT_LENGTH = 4096

# Read the default prompt from a file.
with open("prompt.txt", "r") as file:
    nvc_prompt_template = file.read()

# Initialize LangChain Conversation Memory.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
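# Note: this memory object lives at module scope, so every chat session served by
# this app shares (and appends to) the same conversation history.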

def count_tokens(text: str) -> int:
    """Counts the number of tokens in a given string."""
    return len(tokenizer.encode(text))

def truncate_memory(memory, system_message: str, max_length: int):
    """
    Truncates the conversation memory messages to fit within the maximum token limit.
    
    Args:
        memory: The LangChain conversation memory object.
        system_message: The system message.
        max_length: The maximum number of tokens allowed.
        
    Returns:
        A list of messages (as dicts with role and content) that fit within the token limit.
    """
    truncated_messages = []
    system_tokens = count_tokens(system_message)
    current_length = system_tokens

    # Iterate backwards through the memory (newest to oldest).
    for msg in reversed(memory.chat_memory.messages):
        tokens = count_tokens(msg.content)
        if current_length + tokens <= max_length:
            role = "user" if isinstance(msg, HumanMessage) else "assistant"
            truncated_messages.insert(0, {"role": role, "content": msg.content})
            current_length += tokens
        else:
            break

    return truncated_messages
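
# Example: with a small token budget and a long history, only the newest turns whose
# token counts (plus the system prompt) still fit are kept, returned oldest-to-newest
# as {"role": ..., "content": ...} dicts ready to send to the chat API.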

def respond(
    message,
    history: list[tuple[str, str]],  # Required by Gradio but we now use LangChain memory.
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Responds to a user message while maintaining conversation history via LangChain memory.
    It builds the prompt with a system message and the (truncated) conversation history,
    streams the response from the client, and finally updates the memory with the new response.
    """
    # Use the system message from the UI, falling back to the prompt template.
    formatted_system_message = system_message or nvc_prompt_template

    # Prepare and add the new user message (with your special tokens) to memory.
    new_user_message = f"<|user|>\n{message}</s>"
    memory.chat_memory.add_message(HumanMessage(content=new_user_message))

    # Truncate memory to ensure the context fits within the maximum token length (reserve space for generation).
    truncated_history = truncate_memory(
        memory, formatted_system_message, MAX_CONTEXT_LENGTH - max_tokens - 100
    )
    # Ensure the current user message is present at the end.
    if not truncated_history or truncated_history[-1]["content"] != new_user_message:
        truncated_history.append({"role": "user", "content": new_user_message})

    # Build the full message list: system prompt + conversation history.
    messages = [{"role": "system", "content": formatted_system_message}] + truncated_history

    response = ""
    try:
        stream = client.chat.completions.create(
            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        for chunk in stream:
            token = chunk.choices[0].delta.content or ""  # Some chunks carry no content.
            response += token
            yield response
    except Exception as e:
        print(f"An error occurred: {e}")
        yield "I'm sorry, I encountered an error. Please try again."

    # Once the full response is generated, add it to the LangChain memory
    # (skip if generation failed and nothing was produced).
    if response:
        memory.chat_memory.add_message(AIMessage(content=f"<|assistant|>\n{response}</s>"))

# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=nvc_prompt_template, label="System message", visible=True),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
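
# Example local run (assumes an HF access token and a prompt.txt next to this file):
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   python app.py
# Gradio serves the chat UI at http://127.0.0.1:7860 by default.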