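"""Gradio Space serving a chat demo for a 4-bit quantized Grok-1 (eastwind/grok-1-hf-4bit)."""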
import spaces
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, LlamaTokenizerFast
import torch
import os
import gradio as gr
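# Cap the CUDA caching allocator's split size to reduce memory fragmentation
# while loading a very large model.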
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:120'
model_id = "eastwind/grok-1-hf-4bit"
tokenizer_id = "Xenova/grok-1-tokenizer"
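# Load weights in 4-bit with nested (double) quantization; matmuls run in bfloat16.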
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
DESCRIPTION = """
# Welcome to Tonic's Grok-1
"""
tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="cuda",
    trust_remote_code=True,
)
def format_prompt(user_message, system_message="You are Grok-1, an AI language model created by Tonic-AI. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and follow ethical guidelines and promote positive behavior.\n\n"):
    """Prepend the system message to the user message as a plain-text prompt."""
    return f"{system_message}{user_message}"
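# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call.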
@spaces.GPU
def predict(message, system_message, max_new_tokens=600, temperature=1.2, top_p=0.9, top_k=40, do_sample=True):
formatted_prompt = format_prompt(message, system_message)
input_ids = tokenizer.encode(formatted_prompt, return_tensors='pt')
input_ids = input_ids.to(model.device)
    response_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        no_repeat_ngram_size=9,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=do_sample,  # when False, the sampling knobs above are ignored (greedy decoding)
    )
    # Decode only the newly generated tokens, skipping the prompt.
    response = tokenizer.decode(response_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    truncate_str = "<|im_end|>"
    if truncate_str in response:
        response = response.split(truncate_str)[0]
    # gr.Chatbot expects (user_message, bot_response) pairs.
    return [(message, response)]
with gr.Blocks() as demo:
gr.Markdown(DESCRIPTION)
with gr.Group():
textbox = gr.Textbox(placeholder='Your Message Here', label='Your Message', lines=2)
        system_prompt = gr.Textbox(placeholder='Provide a System Prompt In The First Person', label='System Prompt', lines=2, value="You are Grok-1, an AI language model created by Tonic-AI. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.")
with gr.Group():
chatbot = gr.Chatbot(label='Grok-1🀯')
with gr.Group():
submit_button = gr.Button('Submit', variant='primary')
with gr.Accordion(label='Advanced options', open=False):
        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=8192, step=1, value=4056)
temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=1.2)
top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=40)
        do_sample_checkbox = gr.Checkbox(label='Enable sampling (uncheck for faster greedy decoding)', value=True)
submit_button.click(
fn=predict,
inputs=[textbox, system_prompt, max_new_tokens, temperature, top_p, top_k, do_sample_checkbox],
outputs=chatbot
)
demo.launch()