# app.py — Hugging Face Space (ZeroGPU) for THUDM/LongWriter-llama3.1-8b
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
MODEL = "THUDM/LongWriter-llama3.1-8b"
TITLE = "<h1><center>LongWriter-llama3.1-8b</center></h1>"
PLACEHOLDER = """
<center>
<p>Hi! I'm LongWriter, capable of generating 10,000+ words. How can I assist you today?</p>
</center>
"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
model = model.eval()
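# On a ZeroGPU Space ("Running on Zero"), @spaces.GPU() allocates a GPU for the
# duration of each decorated call.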
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.5,
    max_new_tokens: int = 32768,
    top_p: float = 1.0,
    top_k: int = 50,
):
    print(f'message: {message}')
    print(f'history: {history}')
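    # Flatten the conversation into a single prompt: the system prompt inside
    # <<SYS>> tags, then each turn wrapped as [INST]user[/INST]assistant.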
full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
for prompt, answer in history:
full_prompt += f"[INST]{prompt}[/INST]{answer}"
full_prompt += f"[INST]{message}[/INST]"
    inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)
    context_length = inputs.input_ids.shape[-1]  # prompt length in tokens
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
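    # TextIteratorStreamer decodes tokens as generate() produces them in the
    # worker thread below; skip_prompt=True keeps the echoed input out of the output.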
    generate_kwargs = dict(
        inputs=inputs.input_ids,
        attention_mask=inputs.attention_mask,  # silence the missing-attention-mask warning
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,  # fall back to greedy decoding at temperature 0, which sampling rejects
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        streamer=streamer,
    )
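    # generate() blocks, so run it in a background thread and stream partial
    # output back to the UI as it accumulates.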
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
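# Gradio UI: a ChatInterface backed by stream_chat, with sampling parameters
# exposed under a collapsible "Parameters" accordion.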
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful assistant capable of generating long-form content.",
                label="System Prompt",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.5,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=32768,
                step=1024,
                value=32768,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="Top p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=50,
                label="Top k",
                render=False,
            ),
        ],
        examples=[
            ["Write a 5000-word comprehensive guide on machine learning for beginners."],
            ["Create a detailed 3000-word business plan for a sustainable energy startup."],
            ["Compose a 2000-word short story set in a futuristic underwater city."],
            ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()
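# To run this demo locally (assumed setup, not guaranteed by the Space itself):
#   pip install torch transformers accelerate gradio spaces
#   python app.py
# A GPU with at least ~16 GB of memory is assumed for the 8B weights in
# bfloat16, plus headroom for the long-context KV cache.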