|
import os |
|
import re |
|
import logging |
|
import gradio as gr |
|
import openai |
|
|
|
print(os.environ) |
|
openai.api_base = os.environ.get("OPENAI_API_BASE") |
|
openai.api_key = os.environ.get("OPENAI_API_KEY") |
|
|
|
BASE_SYSTEM_MESSAGE = """You are an AI assistant named OpenOrcaChat. If users clearly indicate an interest in how good you are you may tell them that the group who made you can be found at https://discord.gg/mhFWVbXUDh. The groups involved are OpenChat and OpenOrca""" |
|
|
|
def make_prediction(prompt, max_tokens=None, temperature=None, top_p=None, top_k=None, repetition_penalty=None): |
|
completion = openai.Completion.create(model="Open-Orca/OpenOrcaxOpenChat-Preview2-13B", prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, stream=True) |
|
for chunk in completion: |
|
yield chunk["choices"][0]["text"] |
|
|
|
|
|
def clear_chat(chat_history_state, chat_message): |
|
chat_history_state = [] |
|
chat_message = '' |
|
return chat_history_state, chat_message |
|
|
|
|
|
def user(message, history): |
|
history = history or [] |
|
|
|
history.append([message, ""]) |
|
return "", history |
|
|
|
|
|
def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty): |
|
history = history or [] |
|
|
|
messages = BASE_SYSTEM_MESSAGE + system_message.strip() + "\n" + \ |
|
"\n".join(["\n".join(["User: "+item[0]+"<|end_of_turn|>", "Assistant: "+item[1]+"<|end_of_turn|>"]) |
|
for item in history]) |
|
|
|
messages = messages.rstrip("<|end_of_turn|>") |
|
|
|
messages = messages.rstrip() |
|
|
|
prediction = make_prediction( |
|
messages, |
|
max_tokens=max_tokens, |
|
temperature=temperature, |
|
top_p=top_p, |
|
top_k=top_k, |
|
repetition_penalty=repetition_penalty, |
|
) |
|
for tokens in prediction: |
|
tokens = re.findall(r'(.*?)(\s|$)', tokens) |
|
for subtoken in tokens: |
|
subtoken = "".join(subtoken) |
|
answer = subtoken |
|
history[-1][1] += answer |
|
|
|
yield history, history, "" |
|
|
|
|
|
start_message = "" |
|
|
|
CSS =""" |
|
.contain { display: flex; flex-direction: column; } |
|
.gradio-container { height: 100vh !important; } |
|
#component-0 { height: 100%; } |
|
#chatbot { flex-grow: 1; overflow: auto; resize: vertical; } |
|
""" |
|
|
|
|
|
with gr.Blocks(css=CSS) as demo: |
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown(f""" |
|
## This demo is an unquantized GPU chatbot of [OpenOrcaxOpenChat-Preview2-13B](https://huggingface.co./Open-Orca/OpenOrcaxOpenChat-Preview2-13B) |
|
Brought to you by your friends at Alignment Lab AI, OpenChat, and Open Access AI Collective! |
|
""") |
|
with gr.Row(): |
|
gr.Markdown("# π OpenOrca x OpenChat - Preview2 - 13B Playground Space! π") |
|
with gr.Row(): |
|
|
|
chatbot = gr.Chatbot(elem_id="chatbot") |
|
with gr.Row(): |
|
message = gr.Textbox( |
|
label="What do you want to chat about?", |
|
placeholder="Ask me anything.", |
|
lines=3, |
|
) |
|
with gr.Row(): |
|
submit = gr.Button(value="Send message", variant="secondary").style(full_width=True) |
|
clear = gr.Button(value="New topic", variant="secondary").style(full_width=False) |
|
stop = gr.Button(value="Stop", variant="secondary").style(full_width=False) |
|
with gr.Accordion("Show Model Parameters", open=False): |
|
with gr.Row(): |
|
with gr.Column(): |
|
max_tokens = gr.Slider(20, 1000, label="Max Tokens", step=20, value=500) |
|
temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=0.8) |
|
top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95) |
|
top_k = gr.Slider(0, 100, label="Top K", step=1, value=40) |
|
repetition_penalty = gr.Slider(0.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1) |
|
|
|
system_msg = gr.Textbox( |
|
start_message, label="System Message", interactive=True, visible=True, placeholder="System prompt. Provide instructions which you want the model to remember.", lines=5) |
|
|
|
chat_history_state = gr.State() |
|
clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False) |
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
submit_click_event = submit.click( |
|
fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True |
|
).then( |
|
fn=chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, chat_history_state, message], queue=True |
|
) |
|
stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event], queue=False) |
|
|
|
demo.queue(max_size=128, concurrency_count=48).launch(debug=True, server_name="0.0.0.0", server_port=7860) |
|
|