|
from openai import OpenAI |
|
import gradio as gr |
|
import os |
|
import json |
|
import html |
|
import random |
|
import datetime |
|
|
|
# Featherless exposes an OpenAI-compatible endpoint. The app refuses to start
# without a key instead of failing later on the first chat request.
api_key = os.environ.get('FEATHERLESS_API_KEY')
if api_key is None or api_key == '':
    raise RuntimeError("Cannot start without required API key. Please register for one at https://featherless.ai")

client = OpenAI(api_key=api_key, base_url="https://api.featherless.ai/v1")
|
|
|
# model-cache.json holds a mapping of model class -> iterable of model ids.
with open('./model-cache.json', 'r') as f_model_cache:
    model_cache = json.load(f_model_cache)

# Inverse lookup model_id -> model_class. If an id were listed under several
# classes, the class encountered last wins.
model_class_from_model_id = dict(
    (model_id, model_class)
    for model_class, model_ids in model_cache.items()
    for model_id in model_ids
)
|
|
|
# Model classes known to this Space. True = its models are offered in the
# dropdown, False = hidden. Classes absent from this table are treated as
# blacklisted by build_model_choices().
model_class_filter = {
    **dict.fromkeys(
        (
            "mistral-v02-7b-std-lc",
            "llama3-8b-8k",
            "llama31-8b-16k",
            "llama2-solar-10b7-4k",
            "mistral-nemo-12b-lc",
            "llama2-13b-4k",
            "llama3-15b-8k",
        ),
        True,
    ),
    **dict.fromkeys(
        (
            "qwen2-32b-lc",
            "llama3-70b-8k",
            "llama31-70b-16k",
            "qwen2-72b-lc",
            "mixtral-8x22b-lc",
            "llama3-405b-lc",
        ),
        False,
    ),
}
|
|
|
|
|
# Hub ids of models the app refers to explicitly.
REFLECTION = "mattshumer/Reflection-Llama-3.1-70B"
QWEN25_72B = "Qwen/Qwen2.5-72B"
NEMOTRON = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"

# Models appended to the dropdown individually, regardless of model_class_filter.
bigger_whitelisted_models = [QWEN25_72B, NEMOTRON]
|
|
|
# Pin the class label for these ids, whether or not the cache lists them.
# NOTE(review): REFLECTION is not in bigger_whitelisted_models, so it never
# reaches the dropdown and model_in_list() rejects it; confirm that is intended.
model_class_from_model_id.update({
    REFLECTION: 'llama31-70b-16k',
    NEMOTRON: 'llama31-70b-16k',
})
|
def build_model_choices():
    """Build the ``(label, model_id)`` pairs shown in the model dropdown.

    Every model of each class whitelisted in ``model_class_filter`` is included,
    followed by the individually whitelisted ``bigger_whitelisted_models``.
    Classes present in ``model_cache`` but missing from ``model_class_filter``
    are skipped with a printed warning.

    Returns:
        list[tuple[str, str]]: pairs of ``"<model_id> (<model_class>)"`` and
        ``model_id``.

    Raises:
        KeyError: if a model in ``bigger_whitelisted_models`` has no entry in
        ``model_class_from_model_id``.
    """
    all_choices = []
    for model_class, model_ids in model_cache.items():
        if model_class not in model_class_filter:
            print(f"Warning: new model class {model_class}. Treating as blacklisted")
            continue
        if not model_class_filter[model_class]:
            continue
        all_choices.extend((f"{model_id} ({model_class})", model_id) for model_id in model_ids)

    # Same "<id> (<class>)" label format as the class-based entries above, so
    # the dropdown reads consistently.
    all_choices.extend(
        (f"{model_id} ({model_class_from_model_id[model_id]})", model_id)
        for model_id in bigger_whitelisted_models
    )
    return all_choices
|
# Dropdown entries as (label, model_id) pairs, computed once at import time.
# model_in_list() checks incoming model ids against this list.
model_choices = build_model_choices()
|
def model_in_list(model):
    """Return True if ``model`` is the id of one of the ``model_choices`` entries.

    Only the id half of each ``(label, model_id)`` pair is compared; labels are ignored.
    """
    return any(choice_id == model for _label, choice_id in model_choices)
|
|
|
|
|
# Per-day RNG: seeded with the key plus today's date, so any pick made from it
# stays the same for a calendar day. RANDOM_SEED overrides the default key.
key = os.environ.get('RANDOM_SEED', 'kcOtfNHA+e')
o = random.Random(f"{key}-{datetime.date.today().strftime('%Y-%m-%d')}")

# The dropdown's initial value is currently pinned to NEMOTRON. Setting
# pinned_initial_model to None restores the per-day random pick. The random
# pick is no longer computed only to be discarded, and it is guarded against
# an empty model_choices list.
pinned_initial_model = NEMOTRON
if pinned_initial_model is not None:
    initial_model = pinned_initial_model
elif model_choices:
    initial_model = o.choice(model_choices)[1]
else:
    initial_model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt that respond() prepends when the Reflection model is selected.
# It asks for reasoning in <thinking>, the answer in <output>, and
# self-corrections in <reflection> tags.
REFLECTION_SYSTEM_PROMPT = (
    "You are a world-class AI system, capable of complex reasoning and reflection. "
    "Reason through the query inside <thinking> tags, and then provide your final "
    "response inside <output> tags. If you detect that you made a mistake in your "
    "reasoning at any point, correct yourself inside <reflection> tags."
)
|
|
|
def _to_openai_messages(message, history, model):
    """Convert tuple-style chat history plus the new user message to OpenAI messages.

    A system prompt is prepended when ``model`` is REFLECTION. Assistant turns
    are the strings ``respond`` previously yielded, which are HTML-escaped. They
    are unescaped here so the model sees its own earlier replies verbatim
    instead of ``&lt;``-style entities. ``html.unescape(html.escape(s)) == s``
    for any string, so this is lossless. Turn halves that are ``None`` are skipped
    rather than sent as empty content.

    NOTE(review): assumes ChatInterface passes the last yielded string back
    unchanged as the assistant half of each history pair — confirm for the
    Gradio version in use.
    """
    messages = []
    if model == REFLECTION:
        messages.append({"role": "system", "content": REFLECTION_SYSTEM_PROMPT})
    for human, assistant in history:
        if human is not None:
            messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": html.unescape(assistant)})
    messages.append({"role": "user", "content": message})
    return messages


def _extra_headers(request):
    """Return the attribution headers sent with each completion request.

    ``HTTP-Referer`` is only included when the incoming request actually carried
    a Referer header. Header values must be strings, and the HTTP client
    rejects ``None``.
    """
    headers = {'X-Title': "HF's missing inference widget"}
    referer = request.headers.get('referer') if request is not None else None
    if referer:
        headers['HTTP-Referer'] = referer
    return headers


def respond(message, history, model, request: gr.Request):
    """Stream a reply to ``message`` from the selected Featherless model.

    Args:
        message: the new user message.
        history: prior turns as ``(user, assistant)`` pairs.
        model: model id chosen in the dropdown; must be one of ``model_choices``.
        request: the incoming Gradio request, used for the Referer header.

    Yields:
        The accumulated, HTML-escaped reply after each streamed content chunk.

    Raises:
        RuntimeError: if ``model`` is not offered by this Space.
    """
    if not model_in_list(model):
        raise RuntimeError(f"{model} is not supported in this hf space. Visit https://featherless.ai to see and use the complete model catalogue")

    response = client.chat.completions.create(
        model=model,
        messages=_to_openai_messages(message, history, model),
        temperature=1.0,
        stream=True,
        max_tokens=2000,
        extra_headers=_extra_headers(request),
    )

    partial_message = ""
    for chunk in response:
        # Some streamed chunks (e.g. usage-only ones) may carry no choices at all.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content is None:
            continue
        # html.escape maps each character independently, so escaping chunk by
        # chunk gives the same result as escaping the full reply at once.
        partial_message += html.escape(content)
        yield partial_message
|
|
|
def _read_text(path):
    """Return the UTF-8 text content of ``path``, closing the file afterwards."""
    with open(path, encoding='utf-8') as handle:
        return handle.read()


# Inline SVG markup. `logo` is embedded in the page header. `logo_small` is
# loaded but not referenced in the code shown here.
logo = _read_text('./logo.svg')
logo_small = _read_text('./logo-small.svg')
title_text = "HuggingFace's missing inference widget"
css = """
.logo-mark { fill: #ffe184; }

/* from https://github.com/gradio-app/gradio/issues/4001
 * necessary as putting ChatInterface in gr.Blocks changes behaviour
 */

.row {
    display: flex;
    justify-content: center;
}

.footer p {
    width: 450px;
}

.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto;}
"""
|
|
|
# The first positional parameter of gr.Blocks is `theme`, so the title must be
# passed by keyword. Otherwise the page title is never set and Gradio tries to
# load a theme with that name.
with gr.Blocks(title=title_text, css=css) as demo:
    gr.HTML(f"""
        <div class="header">
            <h1 class="row">HuggingFace's missing inference widget</h1>
            <h3 class="row">powered by</h3>
            <div class="row">
                <a href="https://featherless.ai">
                    {logo}
                </a>
            </div>
        </div>
    """)

    with gr.Row():
        # Reuse the module-level list instead of rebuilding it, which would
        # also print any "new model class" warnings a second time.
        model_selector = gr.Dropdown(
            label="Select your Model",
            choices=model_choices,
            value=initial_model,
            scale=4,
        )
        # fn=None: the click runs only the js callback in the browser.
        gr.Button(value="Visit Model Card ↗️", scale=1).click(
            fn=None,
            inputs=[model_selector],
            js="(model_selection) => { window.open(`https://featherless.ai/models/${model_selection}/readme`, '_blank') }",
        )

    gr.ChatInterface(
        respond,
        additional_inputs=[model_selector],
        # NOTE(review): `head` may only take effect when a ChatInterface is
        # launched on its own rather than nested in gr.Blocks; confirm.
        head="""
            <script>console.log("Hello from gradio!")</script>
        """,
        concurrency_limit=5,
    )

    gr.HTML("""
        <div class="footer">
            <div class="row">
                If you enjoyed this space,
                check out <a href="https://featherless.ai">featherless.ai</a>,
                and follow us <a href="https://x.com/FeatherlessAI">on twitter</a>!
            </div>
            <!-- <div class="row">If you enjoyed this space,</div>
            <div class="row">check out <a href="https://featherless.ai">featherless.ai</a>,</div>
            <div class="row">and follow us <a href="https://x.com/FeatherlessAI">on twitter</a>!</div> -->
        </div>
    """)
|
|
|
|
|
|
|
|
|
|
|
# Start the Gradio server for the UI assembled above. This runs at module level,
# so importing this file also launches the server.
# NOTE(review): consider an `if __name__ == "__main__":` guard if the module is ever imported.
demo.launch()
|
|