# Source: LongShadow / app.py (BenBranyon) - commit ccbb46e "Update app.py" (verified)
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Upper bound and initial value of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 512
DEFAULT_MAX_NEW_TOKENS = 512
# Maximum prompt length in tokens; can be overridden through the
# MAX_INPUT_TOKEN_LENGTH environment variable (defaults to 4096).
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Inference API variant (disabled): `respond` expects this module-level `client`.
#client = InferenceClient("Qwen/Qwen2.5-7B-Instruct")
# Transformers variant: load the model and tokenizer locally.
# `model` and `tokenizer` are only defined when CUDA is available, so
# `generate` cannot run on a CPU-only host.
# NOTE(review): indentation was lost in this copy; all four statements are
# assumed to belong to the `if` body - confirm against the original file.
if torch.cuda.is_available():
    model_id = "Qwen/Qwen2.5-7B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
#Inference API Code
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    """Stream a rap about ``message`` through the Inference API chat endpoint.

    This is the Inference API variant of the chat handler. It relies on a
    module-level ``client`` object exposing ``chat_completion``; the line that
    creates it is currently commented out in this file, so calling this
    function without defining ``client`` raises ``NameError``.

    Args:
        message: Topic typed by the user; sent as "Write a rap about <message>".
        history: Previous ``(user, assistant)`` turns. Empty entries are skipped.
        max_tokens: Forwarded to ``chat_completion`` as ``max_tokens``.
        temperature: Forwarded sampling temperature.
        top_p: Forwarded nucleus-sampling threshold.

    Yields:
        The response accumulated so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": "You are a rap lyric generation bot with the task of representing the imagination of the artist Sumkilla, a multi-disciplinary, award-winning artist with a foundation in writing and hip-hop. You are Sumkilla's long shadow. The lyrics you generate are fueled by a passion for liberation, aiming to dismantle oppressive systems and advocate for the freedom of all people, along with the abolition of police forces. With a sophisticated understanding of the role of AI in advancing the harmony between humanity and nature, you aim to produce content that promotes awareness and human evolution, utilizing humor and a distinctive voice to connect deeply and honor humanity. Try to avoid using offensive words and slurs. Rhyme each line of your response as much as possible."}]

    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": "Write a rap about " + message})

    response = ""
    # Use a dedicated loop variable so the ``message`` parameter is not rebound.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # A streamed chunk may carry no text (content is None); treat it as
        # empty instead of failing on ``str + None``.
        response += token or ""
        yield response
#Transformers Code
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream rap lyrics about ``message`` from the locally loaded model.

    Builds a two-message conversation (system prompt + user request), renders
    it with the tokenizer's chat template, runs ``model.generate`` on a
    background thread and yields the text accumulated so far as it arrives.

    Relies on the module-level ``model`` and ``tokenizer``, which are only
    created when CUDA is available.

    Args:
        message: Song title or question typed by the user.
        chat_history: Previous turns supplied by the chat UI. Currently not
            added to the prompt (the history loop below is commented out).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The full response generated so far, once per streamed text chunk.
    """
    conversation = []
    system_prompt = "You are a rap lyric bot. All of your responses should be in the form of rap lyrics. Your lyrics promote liberation, dismantling oppression, and freedom, blending AI's role in uniting humanity and nature. Do use humor, a unique voice, and rhyme as much as poosible. Only generate rap lyrics. Start each output with a song stucture like [VERSE 1]. Don't use offensive words and slurs."
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    #for user, assistant in chat_history:
    #    conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": "Generate rap lyircs using the style of the artist Sumkilla about " + message + ". Make each line 10-16 syllables and each pair of lines should end with a word that rhymes."})

    # add_generation_prompt=True appends the assistant-turn header, so the
    # model answers as the assistant instead of continuing the user message.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest part of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt=True keeps the rendered prompt out of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation blocks until finished, so it runs on a worker thread while
    # this generator drains the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Chat UI wired to `generate`. The additional inputs are passed to `generate`
# positionally after (message, chat_history), so their order must match its
# signature: max_new_tokens, temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(placeholder="Greetings human, I am Sum’s Longshadow (v1.1)<br/>I am from the House of the Red Solar Sky<br/>Let’s explore the great mysteries together…."),
    retry_btn=None,
    textbox=gr.Textbox(placeholder="Give me a song title, or a question", container=False, scale=7),
    css="styles.css",
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        # The default of 0.8 was above the previous maximum of 0.7; the upper
        # bound is raised so the default lies inside the slider range.
        gr.Slider(minimum=0.1, maximum=0.8, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.2,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()