Spaces:
Running
Running
import base64 | |
import io | |
from functools import partial | |
import gradio as gr | |
import httpx | |
from const import CSS, FOOTER, HEADER, MODELS, PLACEHOLDER | |
from openai import OpenAI | |
from PIL import Image | |
from cycloud.auth import load_default_credentials | |
def get_headers(host: str) -> dict:
    """Build the HTTP headers used when calling a model endpoint.

    A fresh bearer token is read from the default cycloud credentials on
    every call.

    Args:
        host: Value to send in the ``Host`` header.

    Returns:
        A dict with ``Authorization``, ``Host``, ``Accept`` and
        ``Content-Type`` entries (JSON in both directions).
    """
    token = load_default_credentials().access_token
    headers = {"Authorization": f"Bearer {token}"}
    headers["Host"] = host
    headers["Accept"] = "application/json"
    headers["Content-Type"] = "application/json"
    return headers
def proxy(request: httpx.Request, model_info: dict) -> httpx.Request:
    """httpx "request" event hook that retargets a request at a model endpoint.

    The request's URL path is replaced wholesale by ``model_info["endpoint"]``
    and the auth/content headers from ``get_headers`` are merged in, with the
    ``Host`` header set to ``model_info["host"]`` minus any "https://".

    Args:
        request: Outgoing request; it is mutated in place.
        model_info: Mapping with at least "endpoint" and "host" keys.

    Returns:
        The same (mutated) request object.
    """
    target_path = model_info["endpoint"]
    request.url = request.url.copy_with(path=target_path)
    bare_host = model_info["host"].replace("https://", "")
    request.headers.update(get_headers(host=bare_host))
    return request
def encode_image_with_pillow(image_path: str) -> str:
    """Load an image, shrink it to fit within 384x384, and base64-encode it.

    ``Image.thumbnail`` resizes in place and never enlarges. The result is
    converted to RGB and re-encoded as JPEG before base64 encoding.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The JPEG bytes as a base64 ``str`` (suitable for a data URL).
    """
    with Image.open(image_path) as picture:
        picture.thumbnail((384, 384))
        jpeg_buffer = io.BytesIO()
        picture.convert("RGB").save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode("utf-8")
def _resolve_image_path(message, history):
    """Return the path of the image the conversation should be grounded on.

    Prefers the most recent file attached to the current message, otherwise
    the most recent file upload recorded in ``history``.

    Raises:
        gr.Error: If neither the message nor the history contains an image.
    """
    files = message["files"]
    if files:
        # ``files`` may be a single dict with a "path" key or a list. List
        # entries are normally plain paths; dicts with a "path" key are
        # accepted too. NOTE(review): format varies by Gradio version.
        latest = files if isinstance(files, dict) else files[-1]
        return latest["path"] if isinstance(latest, dict) else latest
    # File uploads appear in the user slot of history as a tuple whose
    # first element is the file path; the newest one wins.
    for turn in reversed(history):
        if isinstance(turn[0], tuple):
            return turn[0][0]
    # Without this, the caller previously hit an UnboundLocalError.
    raise gr.Error("Please upload an image to start the conversation.")


def _build_openai_messages(img_base64, text, history):
    """Convert Gradio tuple-format history plus the new text into OpenAI messages.

    The image is attached to the first user message, and the first text turn
    (or the current ``text`` if there is no earlier text turn) is merged into
    that same message so the image and its question travel together.
    """
    first_content = [
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{img_base64}",
            },
        },
    ]
    messages = [{"role": "user", "content": first_content}]
    for turn in history:
        human, assistant = turn[0], turn[1]
        if isinstance(human, tuple):
            # File-upload turn: the image is already attached above, and a
            # tuple is not valid message content.
            continue
        if len(first_content) == 1:
            first_content.append({"type": "text", "text": human})
        else:
            messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    if len(first_content) == 1:
        # No earlier text turn: keep the question in the image message
        # instead of sending two consecutive user messages.
        first_content.append({"type": "text", "text": text})
    else:
        messages.append({"role": "user", "content": text})
    return messages


def call_chat_api(message, history, model_name):
    """Stream a chat completion for a Gradio multimodal ChatInterface.

    Args:
        message: Current input, a dict with "text" and "files" keys.
        history: Gradio tuple-format history of (user, assistant) pairs;
            file uploads appear as a tuple in the user slot.
        model_name: Key into ``MODELS`` selecting the host and endpoint.

    Yields:
        The accumulated assistant reply after each streamed chunk.

    Raises:
        gr.Error: If no image is available in the message or the history.
    """
    image_path = _resolve_image_path(message, history)
    img_base64 = encode_image_with_pillow(image_path)
    messages = _build_openai_messages(img_base64, message["text"], history)

    model_info = MODELS[model_name]
    # NOTE(review): verify=False disables TLS certificate verification for
    # the model endpoint; confirm this is required (e.g. self-signed cert).
    # The client is a context manager so its connection pool is released
    # when the stream finishes or this generator is closed early; it used
    # to leak one pool per request.
    with httpx.Client(
        event_hooks={
            "request": [partial(proxy, model_info=model_info)],
        },
        verify=False,
    ) as http_client:
        client = OpenAI(
            api_key="",  # the proxy hook injects the real Authorization header
            base_url=model_info["host"],
            http_client=http_client,
        )
        stream = client.chat.completions.create(
            model=f"/data/cyberagent/{model_name}",
            messages=messages,
            temperature=0.2,
            top_p=1.0,
            max_tokens=1024,
            stream=True,
            extra_body={"repetition_penalty": 1.1},
        )
        reply = ""
        for chunk in stream:
            # Chunks with an empty choices list (e.g. usage-only) carry no text.
            if not chunk.choices:
                continue
            reply += chunk.choices[0].delta.content or ""
            yield reply
def run():
    """Build the Gradio chat UI and launch it."""
    chatbot = gr.Chatbot(
        elem_id="chatbot", placeholder=PLACEHOLDER, scale=1, height=700
    )
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    # Dropdown documents ``choices`` as a list, so pass one rather than a
    # dict_keys view; the first model is the default selection.
    model_names = list(MODELS)
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=model_names,
                value=model_names[0],
                label="Model",
            )
        gr.ChatInterface(
            fn=call_chat_api,
            stop_btn="Stop Generation",
            # Each example is [message]; the model selector keeps its value.
            # The prompt strings were mojibake (UTF-8 read as TIS-620) and
            # have been restored to the original Japanese.
            examples=[
                [
                    {
                        # "Please describe this image in detail."
                        "text": "この画像を詳しく説明してください。",
                        "files": ["./examples/cat.jpg"],
                    },
                ],
                [
                    {
                        # "Please tell me in detail what this dish tastes like."
                        "text": "この料理はどんな味がするか詳しく教えてください。",
                        "files": ["./examples/takoyaki.jpg"],
                    },
                ],
            ],
            multimodal=True,
            textbox=chat_input,
            chatbot=chatbot,
            additional_inputs=[model_selector],
        )
        gr.Markdown(FOOTER)
    demo.queue().launch(share=False)


if __name__ == "__main__":
    run()