import datetime import os import random import re from io import StringIO import gradio as gr import pandas as pd from huggingface_hub import upload_file from text_generation import Client from dialogues import DialogueTemplate HF_TOKEN = os.environ.get("HF_TOKEN", None) API_TOKEN = os.environ.get("API_TOKEN", None) model2endpoint = { "zephyr-7b-beta": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta", "mistral-7b-instruct-v0.2": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2", "mixtral-8x7b-instruct-v0.1": "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1", "gemma-7b-it": "https://api-inference.huggingface.co/models/google/gemma-7b-it", "llama-7b-chat": "https://api-inference.huggingface.co/models/meta-llama/Llama-2-7b-chat-hf" } model_names = list(model2endpoint.keys()) def randomize_seed_generator(): seed = random.randint(0, 1000000) return seed def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep): past = [] for data in chatbot: user_data, model_data = data if not user_data.startswith(user_name): user_data = user_name + user_data if not model_data.startswith(sep + assistant_name): model_data = sep + assistant_name + model_data past.append(user_data + model_data.rstrip() + sep) if not inputs.startswith(user_name): inputs = user_name + inputs total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip() return total_inputs def wrap_html_code(text): pattern = r"<.*?>" matches = re.findall(pattern, text) if len(matches) > 0: return f"```{text}```" else: return text def has_no_history(chatbot, history): return not chatbot and not history def generate( RETRY_FLAG, model_name, system_message, user_message, chatbot, history, temperature, top_k, top_p, max_new_tokens, repetition_penalty, # do_save=True, ): client = Client( model2endpoint[model_name], headers={"Authorization": f"Bearer {API_TOKEN}"}, timeout=60, ) # Don't return meaningless message when the input is empty if not user_message: print("Empty input") if not RETRY_FLAG: history.append(user_message) seed = 42 else: seed = randomize_seed_generator() past_messages = [] for data in chatbot: user_data, model_data = data past_messages.extend( [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}] ) if len(past_messages) < 1: dialogue_template = DialogueTemplate( system=system_message, messages=[{"role": "user", "content": user_message}] ) prompt = dialogue_template.get_inference_prompt() else: dialogue_template = DialogueTemplate( system=system_message, messages=past_messages + [{"role": "user", "content": user_message}] ) prompt = dialogue_template.get_inference_prompt() generate_kwargs = { "temperature": temperature, "top_k": top_k, "top_p": top_p, "max_new_tokens": max_new_tokens, } temperature = float(temperature) if temperature < 1e-2: temperature = 1e-2 top_p = float(top_p) generate_kwargs = dict( temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, truncate=4096, seed=seed, stop_sequences=["<|end|>"], ) stream = client.generate_stream( prompt, **generate_kwargs, ) output = "" for idx, response in enumerate(stream): if response.token.special: continue output += response.token.text if idx == 0: history.append(" " + output) else: history[-1] = output chat = [ (wrap_html_code(history[i].strip()), wrap_html_code(history[i + 1].strip())) for i in range(0, len(history) - 1, 2) ] # chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)] yield chat, history, user_message, "" return chat, history, user_message, "" examples = [ "What are the signs and symptoms of community acquired pneumonia (CAP)?", "What is the treatment for recurrent otitis media?" ] def clear_chat(): return [], [] def delete_last_turn(chat, history): if chat and history: chat.pop(-1) history.pop(-1) history.pop(-1) return chat, history def process_example(args): for [x, y] in generate(args): pass return [x, y] # Regenerate response def retry_last_answer( selected_model, system_message, user_message, chat, history, temperature, top_k, top_p, max_new_tokens, repetition_penalty, # do_save, ): if chat and history: # Removing the previous conversation from chat chat.pop(-1) # Removing bot response from the history history.pop(-1) # Setting up a flag to capture a retry RETRY_FLAG = True # Getting last message from user user_message = history[-1] yield from generate( RETRY_FLAG, selected_model, system_message, user_message, chat, history, temperature, top_k, top_p, max_new_tokens, repetition_penalty, # do_save, ) title = """