# Gradio Space: adaptive image-description practice for autistic children.
# (Removed scraped "Spaces: Running" page-header artifacts from a copy-paste.)
import gradio as gr
import io
import base64
import os
import json
import re
from PIL import Image
from huggingface_hub import InferenceClient
from google.generativeai import configure, GenerativeModel
from google.ai.generativelanguage import Content, Part

# Load API keys from environment variables.
inference_api_key = os.environ.get("HF_TOKEN")
google_api_key = os.environ.get("GOOGLE_API_KEY")

# Configure the Google Generative AI client once at import time.
configure(api_key=google_api_key)

# Module-level state for the most recently generated image.
# global_image_data_url holds a "data:image/png;base64,..." URL; None until
# the first image is generated.
global_image_data_url = None
global_image_prompt = None  # prompt used to create that image (kept for reference)
def update_difficulty_label(active_session):
    """Return a Markdown string showing the session's current difficulty.

    Args:
        active_session: Session state dict; falls back to "Very Simple"
            when no "difficulty" key is present.

    Returns:
        A Markdown-formatted label for the difficulty display widget.
    """
    return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"
def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan=""):
    """Generate an image-generation prompt using Google's Gemini model.

    Args:
        difficulty: Visual complexity level ("Very Simple" ... "Very Detailed").
        age: Child's age (string, as entered in the UI).
        autism_level: Autism support level ("Level 1" / "Level 2" / "Level 3").
        topic_focus: Subject the image must focus on (e.g. "animals").
        treatment_plan: Optional free-text treatment plan guiding the content.

    Returns:
        The prompt text produced by Gemini, stripped of surrounding whitespace.
    """
    query = f"""
Follow the instructions below to generate an image generation prompt for an educational image intended for autistic children.
Consider the following parameters:
- Difficulty: {difficulty}
- Age: {age}
- Autism Level: {autism_level}
- Topic Focus: {topic_focus}
- Treatment Plan: {treatment_plan}
Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
The image should specifically focus on the topic: "{topic_focus}".
Please generate a prompt that instructs the image generation engine to produce an image with:
1. Clarity and simplicity (minimalist backgrounds, clear subject)
2. Literal representation with defined borders and consistent style
3. Soft, muted colors and reduced visual complexity
4. Positive, calm scenes
5. Clear focus on the specified topic
Use descriptive and detailed language.
"""
    # Text-only prompt generation uses the lightweight flash-lite model.
    model = GenerativeModel('gemini-2.0-flash-lite')
    response = model.generate_content(query)
    return response.text.strip()
def generate_image_fn(selected_prompt, guidance_scale=7.5,
                      negative_prompt="ugly, blurry, poorly drawn hands, lewd, nude, deformed, missing limbs, missing eyes, missing arms, missing legs",
                      num_inference_steps=50):
    """Generate an image from *selected_prompt* via the HF Inference API.

    Side effects: stores the prompt in ``global_image_prompt`` and the PNG
    encoded as a base64 data URL in ``global_image_data_url`` so the chat
    evaluation can later attach the image to a Gemini request.

    Args:
        selected_prompt: Text prompt for the diffusion model.
        guidance_scale: Classifier-free guidance strength.
        negative_prompt: Concepts to steer the model away from.
        num_inference_steps: Number of diffusion steps.

    Returns:
        The generated PIL image.
    """
    global global_image_data_url, global_image_prompt
    global_image_prompt = selected_prompt

    image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
    image = image_client.text_to_image(
        selected_prompt,
        model="stabilityai/stable-diffusion-3.5-large-turbo",
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps
    )

    # Encode the PIL image as a PNG data URL for later multimodal evaluation.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    global_image_data_url = f"data:image/png;base64,{img_b64}"
    return image
def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
    """Generate a new image at the session's current difficulty and reset the chat.

    The previous active session (if it produced an image) is archived into
    the saved-sessions list before a fresh session dict is created.

    Args:
        age: Child's age (string from the UI).
        autism_level: "Level 1" / "Level 2" / "Level 3".
        topic_focus: Subject the new image should focus on.
        treatment_plan: Free-text plan guiding prompt generation.
        active_session: Current session state dict.
        saved_sessions: List of previously archived session dicts.

    Returns:
        Tuple of (PIL image, new active-session dict, updated saved-sessions list).
    """
    new_sessions = saved_sessions.copy()
    # Archive the old session only if it actually generated something.
    if active_session.get("prompt"):
        new_sessions.append(active_session)

    # Use the current difficulty from the active session (updated on advancement).
    current_difficulty = active_session.get("difficulty", "Very Simple")
    generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
    image = generate_image_fn(generated_prompt)

    new_active_session = {
        "prompt": generated_prompt,
        "image": global_image_data_url,
        "chat": [],
        "treatment_plan": treatment_plan,
        "topic_focus": topic_focus,
        "identified_details": [],
        "difficulty": current_difficulty,
        "autism_level": autism_level,
        "age": age
    }
    return image, new_active_session, new_sessions
def compare_details_chat_fn(user_details, treatment_plan, chat_history, identified_details):
    """Evaluate the child's description of the current image with Gemini Vision.

    Builds a multimodal request containing the last generated image (from
    ``global_image_data_url``) plus an instruction prompt that includes the
    conversation history and previously identified details, and asks the
    model to reply with a strict JSON evaluation.

    Args:
        user_details: The child's latest description of the image.
        treatment_plan: Currently unused by the prompt; kept for interface
            compatibility with callers.
        chat_history: List of (user_msg, bot_msg) tuples from prior turns.
        identified_details: Details the child has already identified.

    Returns:
        The raw model response text (expected to contain a JSON object), or
        an instruction string if no image has been generated yet.
    """
    if not global_image_data_url:
        return "Please generate an image first."

    history_text = ""
    if chat_history:
        history_text = "\n\n### Previous Conversation:\n"
        for idx, (user_msg, bot_msg) in enumerate(chat_history, 1):
            history_text += f"Turn {idx}:\nUser: {user_msg}\nTeacher: {bot_msg}\n"

    identified_details_text = ""
    if identified_details:
        identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)

    message_text = (
        f"{history_text}{identified_details_text}\n\n"
        f"Based on the image provided above, please evaluate the following description given by the child:\n"
        f"'{user_details}'\n\n"
        "You are a kind and encouraging teacher speaking to a child. Use simple, clear language. "
        "Praise the child's correct observations and provide a gentle hint if something is missing. "
        "Keep your feedback positive and easy to understand.\n\n"
        "Focus on these evaluation criteria:\n"
        "1. **Object Identification** – Did the child mention the main objects?\n"
        "2. **Color & Shape Accuracy** – Were the colors and shapes described correctly?\n"
        "3. **Clarity & Simplicity** – Was the description clear and easy to understand?\n"
        "4. **Overall Communication** – How well did the child communicate their thoughts?\n\n"
        "Note: As difficulty increases, the expected level of detail is higher. Evaluate accordingly.\n\n"
        "Return your evaluation strictly as a JSON object with the following keys:\n"
        "{\n"
        " \"scores\": {\n"
        " \"object_identification\": <number>,\n"
        " \"color_shape_accuracy\": <number>,\n"
        " \"clarity_simplicity\": <number>,\n"
        " \"overall_communication\": <number>\n"
        " },\n"
        " \"final_score\": <number>,\n"
        " \"feedback\": \"<string>\",\n"
        " \"hint\": \"<string>\",\n"
        " \"advance\": <boolean>\n"
        "}\n\n"
        "Do not include any additional text outside the JSON."
    )

    # Strip the "data:image/png;base64," prefix to get the raw base64 payload.
    base64_img = global_image_data_url.split(",")[1]

    vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')

    # The image must be passed via 'inline_data' (raw bytes + mime type).
    image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
    text_part = Part(text=message_text)
    multimodal_content = Content(parts=[image_part, text_part])

    response = vision_model.generate_content(multimodal_content)
    return response.text
def evaluate_scores(evaluation_text, current_difficulty):
    """Parse the model's JSON evaluation and decide whether the child advances.

    The passing threshold scales with difficulty:
    Very Simple: 70, Simple: 75, Moderate: 80, Detailed: 85, Very Detailed: 90.
    The child advances when ``final_score`` meets the threshold OR the model
    explicitly set ``advance`` to true.

    Args:
        evaluation_text: Raw model output expected to contain a JSON object.
        current_difficulty: The session's current difficulty level.

    Returns:
        Tuple of (feedback message for the child, resulting difficulty).
        On a parse error the difficulty is left unchanged.
    """
    try:
        # The model may wrap the JSON in extra text; extract the first
        # brace-delimited span.
        json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
        if json_match:
            evaluation = json.loads(json_match.group(0))
        else:
            raise ValueError("No JSON object found in the response.")

        final_score = evaluation.get("final_score", 0)
        hint = evaluation.get("hint", "Keep trying!")
        advance = evaluation.get("advance", False)

        difficulty_thresholds = {
            "Very Simple": 70,
            "Simple": 75,
            "Moderate": 80,
            "Detailed": 85,
            "Very Detailed": 90
        }
        current_threshold = difficulty_thresholds.get(current_difficulty, 70)

        # "Very Detailed" is the cap: advancing from it stays at the same level.
        difficulty_mapping = {
            "Very Simple": "Simple",
            "Simple": "Moderate",
            "Moderate": "Detailed",
            "Detailed": "Very Detailed",
            "Very Detailed": "Very Detailed"
        }

        if final_score >= current_threshold or advance:
            new_difficulty = difficulty_mapping.get(current_difficulty, current_difficulty)
            response_msg = (f"Great job! Your final score is {final_score}, which meets the target of {current_threshold}. "
                            f"You've advanced to {new_difficulty} difficulty.")
            return response_msg, new_difficulty
        else:
            # Fixed a mangled f-string that embedded stray newlines mid-sentence.
            response_msg = (f"Your final score is {final_score} (target: {current_threshold}). {hint} "
                            f"Please try again at the {current_difficulty} level.")
            return response_msg, current_difficulty
    except Exception as e:
        return f"Error processing evaluation output: {str(e)}", current_difficulty
def chat_respond(user_message, active_session, saved_sessions):
    """Process one chat turn: evaluate the description and maybe advance.

    The child's message is scored by the vision model. If the evaluation
    indicates advancement, the difficulty is bumped, a brand-new image and
    session are generated (resetting the chat), and the old session is
    archived. Otherwise the turn is appended to the current chat history.

    Args:
        user_message: The child's latest description.
        active_session: Current session state dict (mutated in place).
        saved_sessions: List of archived session dicts.

    Returns:
        Tuple of (cleared input text, chat history for the Chatbot widget,
        updated saved sessions, updated active session).
    """
    # Guard: nothing to evaluate until an image exists.
    if not active_session.get("image"):
        bot_message = "Please generate an image first."
        updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
        active_session["chat"] = updated_chat
        return "", updated_chat, saved_sessions, active_session

    chat_history = active_session.get("chat", [])
    identified_details = active_session.get("identified_details", [])
    raw_evaluation = compare_details_chat_fn(user_message, "", chat_history, identified_details)

    current_difficulty = active_session.get("difficulty", "Very Simple")
    evaluation_response, updated_difficulty = evaluate_scores(raw_evaluation, current_difficulty)
    bot_message = evaluation_response

    if updated_difficulty != current_difficulty:
        # The child advanced: bump the difficulty before generating the new
        # prompt so the fresh image reflects the new level. The winning
        # message itself is not carried into the new (reset) chat.
        active_session["difficulty"] = updated_difficulty
        age = active_session.get("age", "3")
        autism_level = active_session.get("autism_level", "Level 1")
        topic_focus = active_session.get("topic_focus", "")
        treatment_plan = active_session.get("treatment_plan", "")
        new_image, new_active_session, new_sessions = generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions)
        new_active_session["chat"].append(("System", f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."))
        active_session = new_active_session
        bot_message = f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."
        saved_sessions = new_sessions
    else:
        updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
        active_session["chat"] = updated_chat

    return "", active_session["chat"], saved_sessions, active_session
def update_sessions(saved_sessions, active_session):
    """Combine archived sessions with the active one for sidebar display.

    Args:
        saved_sessions: List of archived session dicts.
        active_session: Current session dict (may be None or empty).

    Returns:
        A new list containing all archived sessions, plus the active session
        when it has actually generated a prompt.
    """
    if active_session and active_session.get("prompt"):
        return saved_sessions + [active_session]
    return saved_sessions
##############################################
# Gradio Interface
##############################################
with gr.Blocks() as demo:
    # The active session starts at the easiest difficulty, "Very Simple".
    active_session = gr.State({
        "prompt": None,
        "image": None,
        "chat": [],
        "treatment_plan": "",
        "topic_focus": "",
        "identified_details": [],
        "difficulty": "Very Simple",
        "age": "3",
        "autism_level": "Level 1"
    })
    saved_sessions = gr.State([])

    with gr.Column():
        gr.Markdown("# Image Generation & Chat Inference")
        # Live label reflecting the session's current difficulty.
        difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")

        # ----- Image Generation Section -----
        with gr.Column():
            gr.Markdown("## Generate Image")
            gr.Markdown("Enter your age, select your autism level, specify a topic focus, and provide the treatment plan to generate an image based on the current difficulty level.")
            with gr.Row():
                age_input = gr.Textbox(label="Age", placeholder="Enter age...", value="3")
                autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
            topic_focus_input = gr.Textbox(
                label="Topic Focus",
                placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
                lines=1
            )
            treatment_plan_input = gr.Textbox(
                label="Treatment Plan",
                placeholder="Enter the treatment plan to guide the image generation...",
                lines=2
            )
            generate_btn = gr.Button("Generate Image")
            img_output = gr.Image(label="Generated Image")
            generate_btn.click(
                generate_image_and_reset_chat,
                inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
                outputs=[img_output, active_session, saved_sessions]
            )

        # ----- Chat Section -----
        with gr.Column():
            gr.Markdown("## Chat about the Image")
            gr.Markdown(
                "After generating an image, type details or descriptions about it. "
                "Your message, along with the generated image and conversation history, will be sent for evaluation."
            )
            chatbot = gr.Chatbot(label="Chat History")
            with gr.Row():
                chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
                send_btn = gr.Button("Send")
            # Both the Send button and pressing Enter submit the message.
            send_btn.click(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )
            chat_input.submit(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )

    # ----- Sidebar Section for Session Details -----
    with gr.Column(variant="sidebar"):
        gr.Markdown("## Saved Chat Sessions")
        gr.Markdown(
            "This sidebar automatically saves finished chat sessions. "
            "Each session includes the prompt used, the generated image (as a data URL), "
            "the topic focus, the treatment plan, the list of identified details, and the full chat history."
        )
        sessions_output = gr.JSON(label="Session Details", value={})

    # Keep the sidebar and difficulty label in sync with session state.
    active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
    active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
    saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)

# Launch the app (no share link; pass share=True to launch() for a public URL).
demo.launch()