import os
# import threading
import streamlit as st
from itertools import tee
from model import InferenceBuilder
# from chain import ChainBuilder

# DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
# DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
# remove these secrets from the container
# VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
# VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")

# if DATABRICKS_HOST is None:
#     raise ValueError("DATABRICKS_HOST environment variable must be set")
# if DATABRICKS_TOKEN is None:
#     raise ValueError("DATABRICKS_TOKEN environment variable must be set")
MODEL_AVATAR_URL = "./iphone_robot.png"

MAX_CHAT_TURNS = 10  # limit this for preliminary testing
MSG_MAX_TURNS_EXCEEDED = f"Sorry! The CyberSolve LinAlg playground is limited to {MAX_CHAT_TURNS} turns in a single history. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
EXAMPLE_PROMPTS = [
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a SQL statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
    "Tell me about maximum out-of-pocket costs in healthcare.",
]
TITLE = "CyberSolve LinAlg 1.2"
DESCRIPTION = """Welcome to the CyberSolve LinAlg 1.2 demo! \n
**Overview and Usage**: This 🤗 Space demos the abilities of the CyberSolve LinAlg 1.2 text-to-text language model,
which is augmented with additional organization-specific knowledge. In particular, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
terms like **EDW**, **HCERA**, **NRHA**, and **thousands more**. (Ask the assistant if you don't know what any of these terms mean!) On the left is a sidebar of **Examples**;
click any of these examples to issue the corresponding query to the AI. \n
**Feedback**: Feedback is welcomed, encouraged, and invaluable! To give feedback on one of the model's responses, click the **Give Feedback on Last Response** button just below
the user input bar. This lets you provide either positive or negative feedback on the model's most recent response. A **Feedback Form** will appear above the model's title.
Please be sure to select either 👍 or 👎 before adding notes about your choice. Be as brief or as detailed as you like! Note that you are making a difference: this
feedback allows us to later improve the model for your usage through a training technique known as reinforcement learning from human feedback. \n
Please send any additional, larger feedback, ideas, or issues to: **[email protected]**. Happy inference!"""
GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."

# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1  # test this number
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)

QUEUE_SIZE = 20  # number of places in the global request queue; tune as needed
# if QUEUE_SIZE_ENV is not None:
#     QUEUE_SIZE = int(QUEUE_SIZE_ENV)
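# A possible way to wire up the commented-out env overrides above (a sketch only; the
# TOKEN_CHUNK_SIZE / QUEUE_SIZE env var names are assumptions, not existing settings):
# TOKEN_CHUNK_SIZE_ENV = os.environ.get("TOKEN_CHUNK_SIZE")
# QUEUE_SIZE_ENV = os.environ.get("QUEUE_SIZE")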
# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()

st.set_page_config(layout="wide")
st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains")  # TODO add a Vanderbilt related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")

# use this to format later
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
if "messages" not in st.session_state: | |
st.session_state["messages"] = [] | |
if "feedback" not in st.session_state: | |
st.session_state["feedback"] = [None] | |
def clear_chat_history(): | |
st.session_state["messages"] = [] | |
st.button('Clear Chat', on_click=clear_chat_history) | |
# build the inference components outside the working body so they're only instantiated once -
# simply pass the chat history in for chat completion
builder = InferenceBuilder()
tokenizer = builder.load_tokenizer()
model = builder.load_model()
def last_role_is_user():
    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"

def get_last_question():
    return st.session_state["messages"][-1]["content"]

def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]

def get_stream_warning_error(stream):
    error = None
    warning = None
    for chunk in stream:
        if chunk["error"] is not None:
            error = chunk["error"]
        if chunk["warning"] is not None:
            warning = chunk["warning"]
    return warning, error

# # @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
# def chain_call(history):
#     input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
#     chat_completion = chain.stream(input)
#     return chat_completion
def model_inference(messages):
    # input_ids = tokenizer(get_last_question(), return_tensors="pt").input_ids.to("cuda")  # tokenize the input and put it on the GPU
    input_ids = tokenizer(get_last_question(), return_tensors="pt").input_ids  # move to GPU eventually
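    # When a GPU is available, the inputs (and model) could be moved over first - a sketch
    # only, assuming torch is importable and the model fits on the device:
    # import torch
    # if torch.cuda.is_available():
    #     input_ids = input_ids.to("cuda")
    #     model.to("cuda")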
    outputs = model.generate(input_ids)
    for chunk in tokenizer.decode(outputs[0], skip_special_tokens=True):
        yield chunk  # yield each chunk of the predicted string character by character
def write_response():
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None
    return response, stream_warning, stream_error
def chat_completion(messages):
    if (len(messages) - 1) // 2 >= MAX_CHAT_TURNS:
        yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
        return

    chat_completion = None
    error = None
    # *** TODO add code for implementing a global queue with a bounded semaphore?
    # wait to be in queue
    # with global_semaphore:
    #     try:
    #         chat_completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e
    # chat_completion = chain_call(history_dbrx_format)
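    # One possible shape for that queue, reusing the commented-out get_global_semaphore()
    # near the top of the file (a sketch under that assumption, not the final design; if
    # enabled, this would replace the direct call below):
    # with global_semaphore:
    #     try:
    #         chat_completion = model_inference(messages)
    #     except Exception as e:
    #         error = e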
    chat_completion = model_inference(messages)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return

    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        if chunk is not None:
            chunk_counter += 1
            partial_message += chunk
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
            # if chunk.choices[0].finish_reason == "length":
            #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
    yield {"content": partial_message, "error": None, "warning": max_token_warning}
# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = [None, None, None]
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user", avatar="🧑‍💻"):
                st.markdown(user_input)
            # stream = chat_completion(st.session_state["messages"])
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()

        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
def feedback():
    with st.form("feedback_form"):
        st.title("Feedback Form")
        st.markdown("Please select either 👍 or 👎 before providing a reason for your review of the most recent response. Don't forget to click submit!")
        rating = st.feedback()
        feedback = st.text_input("Please detail your feedback: ")
        # implement a method for writing these responses to storage!
        submitted = st.form_submit_button("Submit Feedback")
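# Sketch of one way to handle the storage TODO above: append each submission to a local
# JSONL file. The helper name, file path, and record fields are assumptions, and it is
# not yet wired into the form.
def save_feedback(rating, notes, path="./feedback_log.jsonl"):
    import json
    with open(path, "a") as f:
        f.write(json.dumps({"rating": rating, "notes": notes}) + "\n")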
main = st.container()
with main:
    if st.session_state["feedback"][-1] is not None:  # TODO clean this up in a fn?
        st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "🧑‍💻"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                if message["error"] is not None:
                    st.error(message["error"], icon="🚨")
                if message["warning"] is not None:
                    st.warning(message["warning"], icon="⚠️")
    if prompt := st.chat_input("Type a message!", max_chars=5000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iPhone users
    gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
    if gave_feedback:  # TODO clean up the conditions here with a function
        st.session_state["feedback"].append("given")
    else:
        st.session_state["feedback"].append(None)
with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)