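"""Gradio demo: OCR + Text2SQL chat, Feline Grimace Scale cat-pain scoring,
and cat health-report explanation, backed by helpers from `api_calls`."""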
import codecs  # incremental UTF-8 decoding of streamed responses

import gradio as gr
import pandas as pd
from pathlib import Path
from time import sleep
from tqdm import tqdm

# Explicit imports instead of `from api_calls import *`; only these helpers are used.
from api_calls import (
    api_ocr,
    api_qa_waterfee,
    api_qa_normal,
    api_qa_cat_report,
    api_model_cat_pain_assessment,
)

ROOT_DIR = Path(__file__).resolve().parent


def disable_btn():
    # `gr.Button.update()` was removed in Gradio 4; `gr.update()` is version-safe.
    return gr.update(interactive=False)


def enable_btn():
    return gr.update(interactive=True)


def preview_uploaded_file(file_paths):
    # Show the first uploaded file in the preview pane, or clear it.
    return gr.update(value=file_paths[0] if file_paths else None)


def open_data_check(checked):
    # Toggle visibility of the extracted-data textbox.
    return gr.update(visible=checked)


def uploaded_file_process(file_path, ocr_model_choice):
    name = Path(file_path).stem  # filename without extension; the extension was unused
    print(f"Processing: {name}")
    ocr_extracted_data = api_ocr(
        image_filepath=file_path, model_provider=ocr_model_choice)
    return ocr_extracted_data


def reference_from_file(file_paths, ocr_model_choice="Gemini Pro Vision"):
    data_array = []
    for file_path in tqdm(file_paths):
        data = uploaded_file_process(file_path, ocr_model_choice=ocr_model_choice)
        data_array.append(data)
        sleep(1)  # brief pause between OCR calls (simple rate limiting)
    return data_array


def print_like_dislike(x: gr.LikeData):
    # Log which chatbot message was liked/disliked and its content.
    print(x.index, x.value, x.liked)
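

# Routes a chat query to the matching QA endpoint and streams the answer.
# Heuristic: with no uploaded files, or files whose names mention "大台北"
# (Greater Taipei, presumably the water-fee demo data), use api_qa_waterfee;
# otherwise fall back to api_qa_normal.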
def bot(query, history, data_array, file_paths, qa_prompt_tmpl, checkbox_replace):
    if data_array:
        params = {"query": query, "filtered_data": data_array}
    else:
        params = {"query": query}
    if checkbox_replace:
        params.update({"prompt_template": qa_prompt_tmpl})
    # Substring match per path: the original `"大台北" in file_paths` only matched
    # a path exactly equal to that string, which can never happen for full paths.
    if not file_paths or any("大台北" in str(p) for p in file_paths):
        func = api_qa_waterfee
    else:
        func = api_qa_normal
    response = func(**params)
    # Decode the byte stream incrementally: a fixed chunk_size can split a
    # multi-byte UTF-8 character across chunks, which a plain .decode() rejects.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    full_answer = ""
    for chunk in response.iter_content(chunk_size=32):
        if chunk:
            full_answer += decoder.decode(chunk)
            yield full_answer
    tail = decoder.decode(b"", final=True)  # flush any buffered bytes
    if tail:
        full_answer += tail
        yield full_answer
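

# Streams the cat-report explanation, mirroring the incremental UTF-8
# decoding used in `bot` above.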
def cat_report_explanation(data_array):
    response = api_qa_cat_report(data_array)
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    full_answer = ""
    for chunk in response.iter_content(chunk_size=32):
        if chunk:
            full_answer += decoder.decode(chunk)
            yield full_answer
    tail = decoder.decode(b"", final=True)  # flush any buffered bytes
    if tail:
        full_answer += tail
        yield full_answer
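

# Scores a cat image on the five Feline Grimace Scale aspects (ear position,
# orbital tightening, muzzle tension, whisker change, head position), each
# rated 0-2, then renders a horizontal bar plot plus the total score out of 10.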
def draw_cat_pain_assessment_result(user_input_image):
    if user_input_image:
        json_result = api_model_cat_pain_assessment(user_input_image)
        print(json_result)
        total_score = sum(json_result.values())
        df_result = pd.DataFrame(json_result, index=[0]).T.reset_index()
        df_result.columns = ["aspect", "score"]  # descriptive names instead of "a"/"b"
        return gr.BarPlot(
            df_result,
            x="aspect",
            y="score",
            x_title="Aspects",
            y_title="Score",
            title="Cat Pain Assessment",
            vertical=False,
            height=400,
            width=800,
            tooltip=["aspect", "score"],
            y_lim=[0, 2],
            scale=1,
        ), gr.HTML(
            '<h3>Total Score</h3>'
            f'<span style="font-size: 50px;">{total_score}</span>'
            '<span style="font-size: 40px;">/10</span>'
        ), gr.HTML(
            '<h3>Explanation</h3>'
            '<p>Ear position: 0-2</p>'
            '<p>Orbital tightening: 0-2</p>'
            '<p>Muzzle tension: 0-2</p>'
            '<p>Whiskers change: 0-2</p>'
            '<p>Head position: 0-2</p>'
        )
    else:
        # This handler feeds three outputs, so clear all three rather than one.
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)
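

# UI components & layout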
chatbot = gr.Chatbot(
    # Seed greeting: "I'm ESG AI Chat. How may I help you?"
    [(None, "我是 ESG AI Chat\n有什麼能為您服務的嗎?")],
    elem_id="chatbot",
    scale=1,
    height=700,
    bubble_full_width=False
)
css = """
#examples_file_to_ocr {color: green !important}
#center {text-align: center}
footer {visibility: hidden}
a {color: rgb(255, 206, 10) !important}
"""
with gr.Blocks(css=css, theme=gr.themes.Monochrome(neutral_hue="green")) as demo:
    gr.HTML("<h1>GlobalModelAI AI Product Test</h1><p>Made by `GlobalModelAI Abao`</p>", elem_id="center")
    with gr.Tab("OCR + Text2SQL"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("## OCR Processing", elem_id="center")
                ocr_model_choice = gr.Dropdown(label="Model", value="Gemini Pro Vision", choices=["GPT-4", "Gemini Pro Vision"])
                file_preview = gr.Image(type="filepath", image_mode="RGB", sources=None, label="File Preview")
                # Extensions need a leading dot in Gradio; "helc" looked like a typo for ".heic".
                file_upload = gr.File(label="Upload File", file_types=[".png", ".jpg", ".jpeg", ".heic"], file_count='multiple')
                checkbox_open_data_check = gr.Checkbox(label="Open Data Check")
                text_data_from_file_check = gr.Textbox(label="File Upload Status", interactive=False, visible=False)
                gr.Examples(
                    examples=[
                        [[f"{ROOT_DIR}/data/image_for_test/screenshot_for_test-esg_report_table.png"]],
                        # Multi-file examples: all paths belong in one inner list (one value per input).
                        [[f"{ROOT_DIR}/data/image_for_test/screenshot_for_test-esg_report_table2.png",
                          f"{ROOT_DIR}/data/image_for_test/screenshot_for_test-esg_report_table3.png"]],
                        [[f"{ROOT_DIR}/data/image_for_test/screenshot_for_test-medical_thesis_table.png",
                          f"{ROOT_DIR}/data/image_for_test/screenshot_for_test-medical_thesis_table2.jpg"]],
                    ],
                    inputs=file_upload,
                    outputs=text_data_from_file_check,
                    fn=reference_from_file,
                    cache_examples=True,
                    elem_id="examples_file_to_ocr"
                )
            with gr.Column():
                gr.Markdown("## Chat with your data", elem_id="center")
                with gr.Accordion("Revise Your Prompt", open=False):
                    checkbox_replace = gr.Checkbox(label="Replace with new prompt")
                    qa_prompt_tmpl = gr.Textbox(
                        label="希望用於本次問答的prompt",  # "Prompt to use for this Q&A round"
                        info="必須使用到的變數:{filtered_data}、{query}",  # "Required variables: {filtered_data}, {query}"
                        value="",
                        interactive=True,
                    )
                chat_interface = gr.ChatInterface(
                    fn=bot,
                    additional_inputs=[text_data_from_file_check, file_upload, qa_prompt_tmpl, checkbox_replace],
                    chatbot=chatbot,
                )
                chatbot.like(print_like_dislike, None, None)
    with gr.Tab("Cat Pain Assessment Model"):
        gr.Markdown("## Cat Pain Assessment Model", elem_id="center")
        with gr.Row():
            user_input_image = gr.Image(
                type="filepath", image_mode="RGB",
                sources=["upload", "webcam", "clipboard"],
                label="Upload a cat image")
            with gr.Column():
                cat_pain_assessment_barplot = gr.BarPlot(label="Cat Pain Assessment")
                cat_pain_assessment_score = gr.HTML(elem_id="center")
                cat_pain_assessment_explanation = gr.HTML()
        gr.Examples(
            examples=[
                [f"{ROOT_DIR}/data/cat_pain_detection/fgs_cat_examples/5f2afc_3c44de4afb8345a2a56828e3dd166f41~mv2.jpg"],
                [f"{ROOT_DIR}/data/cat_pain_detection/fgs_cat_examples/5f2afc_9d9838561cde41d3b2dc9ef079dc2303~mv2.jpg"],
                [f"{ROOT_DIR}/data/cat_pain_detection/fgs_cat_examples/5f2afc_da95c2a1a3294701a007d34ec02f62a5~mv2.jpg"],
            ],
            inputs=user_input_image,
            outputs=[cat_pain_assessment_barplot, cat_pain_assessment_score, cat_pain_assessment_explanation],
            fn=draw_cat_pain_assessment_result,
            cache_examples=True,
        )
    with gr.Tab("Cat Report Explanation"):
        gr.Markdown("## Cat Report Explanation", elem_id="center")
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Report Processing", elem_id="center")
                catrep_ocr_model_choice = gr.Dropdown(label="Model", value="Gemini Pro Vision", choices=["GPT-4", "Gemini Pro Vision"])
                catrep_file_preview = gr.Image(type="filepath", image_mode="RGB", sources=None, label="File Preview")
                catrep_file_upload = gr.File(label="Upload File", file_types=[".png", ".jpg", ".jpeg", ".heic"], file_count='multiple')
                catrep_button_generation_explanation = gr.Button("Start Explanation")
                catrep_checkbox_open_data_check = gr.Checkbox(label="Open Data Check")
                catrep_text_data_from_file_check = gr.Textbox(label="File Upload Status", interactive=False, visible=False)
                gr.Examples(
                    examples=[
                        [[f"{ROOT_DIR}/data/image_for_test/screenshot_for_test-cat_report_12.png"]]
                    ],
                    inputs=catrep_file_upload,
                    outputs=catrep_text_data_from_file_check,
                    fn=reference_from_file,
                    cache_examples=True,
                    elem_id="examples_file_to_ocr"
                )
            with gr.Column():
                gr.Markdown("### View Explanation", elem_id="center")
                catrep_textbox_explanation = gr.Textbox(
                    label="Explanation",
                    placeholder="Explanation will show here after you upload image & click the button",
                    interactive=False,
                )
    # Callbacks
    ## OCR + Text2SQL
    file_upload.upload(
        reference_from_file, [file_upload, ocr_model_choice], [text_data_from_file_check]
    )
    file_upload.change(
        preview_uploaded_file, [file_upload], [file_preview]
    )
    ocr_model_choice.change(
        reference_from_file, [file_upload, ocr_model_choice], [text_data_from_file_check]
    )
    checkbox_open_data_check.select(
        open_data_check, [checkbox_open_data_check], [text_data_from_file_check]
    )
    ## Cat Pain Assessment Model
    user_input_image.change(
        draw_cat_pain_assessment_result, [user_input_image],
        [cat_pain_assessment_barplot, cat_pain_assessment_score, cat_pain_assessment_explanation]
    )
    ## Cat Report Explanation
    catrep_file_upload.upload(
        reference_from_file, [catrep_file_upload, catrep_ocr_model_choice], [catrep_text_data_from_file_check]
    )
    catrep_file_upload.change(
        preview_uploaded_file, [catrep_file_upload], [catrep_file_preview]
    )
    catrep_ocr_model_choice.change(
        reference_from_file, [catrep_file_upload, catrep_ocr_model_choice], [catrep_text_data_from_file_check]
    )
    catrep_checkbox_open_data_check.select(
        open_data_check, [catrep_checkbox_open_data_check], [catrep_text_data_from_file_check]
    )
    catrep_button_generation_explanation.click(
        cat_report_explanation, [catrep_text_data_from_file_check], [catrep_textbox_explanation]
    )
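
# `queue()` enables the request queue that the generator-based (streaming)
# handlers above rely on.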
if __name__ == "__main__":
    demo.queue().launch(max_threads=10)