import spaces
import torch
print(torch.__version__)
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
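# Auto-guidance (Karras et al., 2024, "Guiding a Diffusion Model with a Bad
# Version of Itself") replaces classifier-free guidance's unconditional branch
# with a smaller, weaker version of the model. Here the 360M SmolLM's logits
# are pushed away from the 135M SmolLM's logits at every decoding step.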
MODEL_BIG = "HuggingFaceTB/SmolLM-360M-Instruct"
MODEL_SMALL = "HuggingFaceTB/SmolLM-135M-Instruct"
TITLE = "<h1><center>Auto-Guidance Playground</center></h1>"
SUB_TITLE = """<center>Auto-guidance is a technique introduced by NVIDIA for text-conditioned image models. This Space tests the concept with SmolLM.</center>"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
END_MESSAGE = """
\n
**The conversation has reached its end. Please press "Clear" to start a new conversation.**
"""
tokenizer = AutoTokenizer.from_pretrained(MODEL_SMALL)
model_big = AutoModelForCausalLM.from_pretrained(
MODEL_BIG,
torch_dtype=torch.bfloat16,
device_map="auto")
model_small = AutoModelForCausalLM.from_pretrained(
MODEL_SMALL,
torch_dtype=torch.bfloat16,
device_map="auto")
if model_big.device.type == "cuda":
    model_big = torch.compile(model_big)
if model_small.device.type == "cuda":
    model_small = torch.compile(model_small)
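# Decoding is implemented manually, token by token, because the stock `generate()`
# loop samples from a single model's logits and cannot blend two models per step.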
@spaces.GPU
@torch.no_grad()
def stream_chat(
message: str,
history: list,
temperature: float = 0.3,
max_new_tokens: int = 1024,
top_p: float = 1.0,
top_k: int = 20,
penalty: float = 1.2,
guidance_scale: float = 1.5,
):
print(f'message: {message}')
print(f'history: {history}')
conversation = []
for prompt, answer in history:
conversation.extend([
{"role": "user", "content": prompt},
{"role": "assistant", "content": answer},
])
conversation.append({"role": "user", "content": message})
    inputs = tokenizer.apply_chat_template(
        conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model_big.device)
generated_tokens = []
current_input = inputs
cache_small = None
cache_big = None
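    # Incremental decoding: after the first step, only the newest token is fed to
    # each model while its KV cache carries the rest of the context.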
for _ in range(max_new_tokens):
outputs_small = model_small(current_input, use_cache=True, past_key_values=cache_small)
outputs_big = model_big(current_input, use_cache=True, past_key_values=cache_big)
logits_small = outputs_small.logits[:, -1, :]
logits_big = outputs_big.logits[:, -1, :]
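        # Auto-guidance extrapolation: with scale w this equals
        # w * logits_big + (1 - w) * logits_small, so w = 1 disables guidance and
        # w > 1 pushes the big model away from the small model's predictions.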
        interpolated_logits = logits_big + (guidance_scale - 1) * (logits_big - logits_small)
        # Apply the repetition penalty to already-generated tokens, then temperature.
        if penalty > 0 and penalty != 1.0 and generated_tokens:
            prev = torch.tensor(generated_tokens, device=interpolated_logits.device)
            scores = interpolated_logits[0, prev]
            interpolated_logits[0, prev] = torch.where(scores > 0, scores / penalty, scores * penalty)
        interpolated_logits = interpolated_logits / max(temperature, 1e-5)
        if top_k > 0:
            interpolated_logits = top_k_filtering(interpolated_logits, top_k=top_k)
        if top_p < 1.0:
            interpolated_logits = top_p_filtering(interpolated_logits, top_p=top_p)
        next_token = torch.multinomial(torch.softmax(interpolated_logits, dim=-1), num_samples=1)
if next_token.item() == tokenizer.eos_token_id:
break
generated_tokens.append(next_token.item())
current_input = next_token
# Update the cache with the latest past_key_values
cache_small = outputs_small.past_key_values
cache_big = outputs_big.past_key_values
partial_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
yield partial_output
print(f'response: {partial_output}')
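# Keep only the top_k highest-scoring logits; everything else is masked to -inf
# and therefore gets zero probability after the softmax.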
def top_k_filtering(logits, top_k=0, filter_value=-float('Inf')):
top_k = min(top_k, logits.size(-1))
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
return logits
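# Nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative
# probability exceeds top_p and mask the rest to -inf.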
def top_p_filtering(logits, top_p=0.0, filter_value=-float('Inf')):
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original (unsorted) vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
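# Note: both filters modify `logits` in place; that is safe here because the
# interpolated logits tensor is rebuilt from fresh model outputs every step.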
chatbot = gr.Chatbot(height=600)
with gr.Blocks(css=CSS, theme="soft") as demo:
gr.HTML(TITLE)
gr.HTML(SUB_TITLE)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.3,
label="Temperature",
render=False,
),
gr.Slider(
minimum=128,
maximum=8192,
step=1,
value=1024,
label="Max new tokens",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
label="top_p",
render=False,
),
gr.Slider(
minimum=1,
maximum=20,
step=1,
value=20,
label="top_k",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition penalty",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.5,
label="Auto-Guidance Scale",
render=False,
),
],
examples=[
["Hello there, can you suggest few places to visit in UAE?"],
["What UAE is known for?"],
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()