Spaces:
Running
Running
File size: 16,294 Bytes
686ef17 82d4c23 993333f 8850bc5 993333f 8850bc5 993333f 8850bc5 5d96bcc 8850bc5 9b134bd 8850bc5 993333f 1f6ec43 73b4bb6 8850bc5 29561c3 8850bc5 29561c3 5d96bcc 8850bc5 5d96bcc 29561c3 8850bc5 29561c3 8850bc5 29561c3 8850bc5 29561c3 8850bc5 993333f 8850bc5 993333f 0b4a56a 8850bc5 993333f 9bbeb75 8850bc5 5d96bcc 993333f 0b4a56a 993333f 686ef17 8850bc5 e4f9a72 8850bc5 e4f9a72 8f401f6 29561c3 8850bc5 29561c3 e4f9a72 8850bc5 993333f 8850bc5 993333f 1f6ec43 993333f 29561c3 8850bc5 0b4a56a 29561c3 8850bc5 29561c3 8850bc5 686ef17 44904f8 29561c3 8850bc5 29561c3 8850bc5 29561c3 8850bc5 29561c3 8850bc5 29561c3 8850bc5 29561c3 e44ce7e 8850bc5 f42860d 0cac8da 8850bc5 8f401f6 4330f40 b66a090 8850bc5 b66a090 8850bc5 b66a090 8850bc5 5d96bcc 29561c3 b66a090 8850bc5 29561c3 b66a090 8850bc5 b66a090 29561c3 44904f8 29561c3 44904f8 29561c3 b66a090 8f401f6 b66a090 29561c3 8850bc5 b66a090 29561c3 8850bc5 29561c3 8f401f6 8850bc5 993333f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 |
import gradio as gr
import io
import base64
import os
import json
import re
from PIL import Image
from huggingface_hub import InferenceClient
from google.generativeai import configure, GenerativeModel
from google.ai.generativelanguage import Content, Part
# Load API keys from environment variables.
inference_api_key = os.environ.get("HF_TOKEN")  # Hugging Face Inference API token
google_api_key = os.environ.get("GOOGLE_API_KEY")  # Google Generative AI key
# Configure the Google Generative AI client once at import time.
configure(api_key=google_api_key)
# Module-level cache of the most recently generated image (as a PNG data URL)
# and the prompt that produced it. Both are written by generate_image_fn and
# read by compare_details_chat_fn.
global_image_data_url = None
global_image_prompt = None  # Still stored if needed elsewhere
def update_difficulty_label(active_session):
    """Render the active session's difficulty as a Markdown label."""
    difficulty = active_session.get("difficulty", "Very Simple")
    return f"**Current Difficulty:** {difficulty}"
def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan="",
                                 model_name="gemini-2.0-flash-lite"):
    """
    Generate an image-generation prompt for an educational image using
    Google's Gemini model.

    Args:
        difficulty: Visual complexity level ("Very Simple" ... "Very Detailed").
        age: The child's age (interpolated into the query as-is).
        autism_level: Autism support level (e.g. "Level 1").
        topic_focus: The topic the image must center on.
        treatment_plan: Optional therapy context to steer the prompt.
        model_name: Gemini model id to use. Defaults to the model this app
            originally hard-coded, so existing callers are unaffected.

    Returns:
        The generated prompt text with surrounding whitespace stripped.
    """
    query = (
        f"""
Follow the instructions below to generate an image generation prompt for an educational image intended for autistic children.
Consider the following parameters:
- Difficulty: {difficulty}
- Age: {age}
- Autism Level: {autism_level}
- Topic Focus: {topic_focus}
- Treatment Plan: {treatment_plan}
Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
The image should specifically focus on the topic: "{topic_focus}".
Please generate a prompt that instructs the image generation engine to produce an image with:
1. Clarity and simplicity (minimalist backgrounds, clear subject)
2. Literal representation with defined borders and consistent style
3. Soft, muted colors and reduced visual complexity
4. Positive, calm scenes
5. Clear focus on the specified topic
Use descriptive and detailed language.
"""
    )
    # Model id is now a parameter rather than a hard-coded constant.
    model = GenerativeModel(model_name)
    response = model.generate_content(query)
    return response.text.strip()
def generate_image_fn(selected_prompt, guidance_scale=7.5,
                      negative_prompt="ugly, blurry, poorly drawn hands, lewd, nude, deformed, missing limbs, missing eyes, missing arms, missing legs",
                      num_inference_steps=50):
    """
    Generate an image from the prompt via the Hugging Face Inference API.

    Side effects: stores the prompt in global_image_prompt and the PNG as a
    base64 data URL in global_image_data_url.

    Returns:
        The generated PIL image.
    """
    global global_image_data_url, global_image_prompt
    global_image_prompt = selected_prompt

    client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
    generated = client.text_to_image(
        selected_prompt,
        model="stabilityai/stable-diffusion-3.5-large-turbo",
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps
    )

    # Serialize to PNG and cache as a data URL for the vision evaluator.
    png_buffer = io.BytesIO()
    generated.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    global_image_data_url = f"data:image/png;base64,{encoded}"
    return generated
def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
    """
    Archive the current session (if it produced an image), then generate a
    fresh image at the session's current difficulty and start an empty chat.

    Returns:
        (PIL image, new active-session dict, updated saved-sessions list)
    """
    archived = saved_sessions.copy()
    if active_session.get("prompt"):
        archived.append(active_session)

    # Read difficulty from the active session so advancement carries over.
    difficulty = active_session.get("difficulty", "Very Simple")
    prompt = generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan)
    image = generate_image_fn(prompt)

    fresh_session = {
        "prompt": prompt,
        "image": global_image_data_url,
        "chat": [],
        "treatment_plan": treatment_plan,
        "topic_focus": topic_focus,
        "identified_details": [],
        "difficulty": difficulty,
        "autism_level": autism_level,
        "age": age,
    }
    return image, fresh_session, archived
def compare_details_chat_fn(user_details, treatment_plan, chat_history, identified_details):
    """
    Evaluate the child's description of the current image using Google's
    Gemini vision model.

    Args:
        user_details: The child's latest free-text description.
        treatment_plan: Accepted for interface compatibility.
            NOTE(review): not referenced anywhere in the message below —
            confirm whether it should be included in the prompt.
        chat_history: List of (user_msg, bot_msg) tuples from earlier turns.
        identified_details: Details the child has already identified.

    Returns:
        The model's raw text response (instructed to be a JSON object), or a
        plain error string when no image has been generated yet.
    """
    # The image to grade is the module-level data URL written by
    # generate_image_fn; without it there is nothing to evaluate.
    if not global_image_data_url:
        return "Please generate an image first."
    # Fold prior conversation turns into the prompt for context.
    history_text = ""
    if chat_history:
        history_text = "\n\n### Previous Conversation:\n"
        for idx, (user_msg, bot_msg) in enumerate(chat_history, 1):
            history_text += f"Turn {idx}:\nUser: {user_msg}\nTeacher: {bot_msg}\n"
    # List already-identified details so the model does not re-credit them.
    identified_details_text = ""
    if identified_details:
        identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)
    message_text = (
        f"{history_text}{identified_details_text}\n\n"
        f"Based on the image provided above, please evaluate the following description given by the child:\n"
        f"'{user_details}'\n\n"
        "You are a kind and encouraging teacher speaking to a child. Use simple, clear language. "
        "Praise the child's correct observations and provide a gentle hint if something is missing. "
        "Keep your feedback positive and easy to understand.\n\n"
        "Focus on these evaluation criteria:\n"
        "1. **Object Identification** β Did the child mention the main objects?\n"
        "2. **Color & Shape Accuracy** β Were the colors and shapes described correctly?\n"
        "3. **Clarity & Simplicity** β Was the description clear and easy to understand?\n"
        "4. **Overall Communication** β How well did the child communicate their thoughts?\n\n"
        "Note: As difficulty increases, the expected level of detail is higher. Evaluate accordingly.\n\n"
        "Return your evaluation strictly as a JSON object with the following keys:\n"
        "{\n"
        " \"scores\": {\n"
        " \"object_identification\": <number>,\n"
        " \"color_shape_accuracy\": <number>,\n"
        " \"clarity_simplicity\": <number>,\n"
        " \"overall_communication\": <number>\n"
        " },\n"
        " \"final_score\": <number>,\n"
        " \"feedback\": \"<string>\",\n"
        " \"hint\": \"<string>\",\n"
        " \"advance\": <boolean>\n"
        "}\n\n"
        "Do not include any additional text outside the JSON."
    )
    # Strip the "data:image/png;base64," prefix to get just the base64 payload.
    base64_img = global_image_data_url.split(",")[1]
    # Vision-capable Gemini model for the multimodal evaluation.
    vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
    # Build multimodal content: the image part uses 'inline_data' (not
    # 'content'), paired with the textual evaluation instructions.
    image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
    text_part = Part(text=message_text)
    multimodal_content = Content(parts=[image_part, text_part])
    # Ask the vision model for the JSON evaluation; returned verbatim for
    # evaluate_scores to parse.
    response = vision_model.generate_content(multimodal_content)
    return response.text
def evaluate_scores(evaluation_text, current_difficulty):
    """
    Parse the model's JSON evaluation and decide whether the child advances.

    The passing threshold scales with difficulty:
    Very Simple 70, Simple 75, Moderate 80, Detailed 85, Very Detailed 90.

    Returns:
        (feedback message, resulting difficulty). On any parsing failure the
        difficulty is left unchanged and an error message is returned.
    """
    thresholds = {
        "Very Simple": 70,
        "Simple": 75,
        "Moderate": 80,
        "Detailed": 85,
        "Very Detailed": 90,
    }
    next_level = {
        "Very Simple": "Simple",
        "Simple": "Moderate",
        "Moderate": "Detailed",
        "Detailed": "Very Detailed",
        "Very Detailed": "Very Detailed",  # already at the top
    }
    try:
        # Extract the first-to-last-brace span in case the model wrapped the
        # JSON in extra prose.
        match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
        if match is None:
            raise ValueError("No JSON object found in the response.")
        evaluation = json.loads(match.group(0))

        final_score = evaluation.get("final_score", 0)
        hint = evaluation.get("hint", "Keep trying!")
        advance = evaluation.get("advance", False)
        threshold = thresholds.get(current_difficulty, 70)

        if advance or final_score >= threshold:
            promoted = next_level.get(current_difficulty, current_difficulty)
            message = (f"Great job! Your final score is {final_score}, which meets the target of {threshold}. "
                       f"You've advanced to {promoted} difficulty.")
            return message, promoted

        message = (f"Your final score is {final_score} (\n target: {threshold}). {hint} \n "
                   f"Please try again at the {current_difficulty} level.")
        return message, current_difficulty
    except Exception as e:
        return f"Error processing evaluation output: {str(e)}", current_difficulty
def chat_respond(user_message, active_session, saved_sessions):
    """
    Handle one chat turn: evaluate the child's description and, when the
    evaluation grants advancement, raise the difficulty and start a fresh
    image/chat session.

    Returns:
        ("", chat history, saved sessions, active session) — the empty string
        clears the input textbox.
    """
    # Guard: nothing to evaluate before an image exists.
    if not active_session.get("image"):
        reply = "Please generate an image first."
        history = active_session.get("chat", []) + [(user_message, reply)]
        active_session["chat"] = history
        return "", history, saved_sessions, active_session

    raw_evaluation = compare_details_chat_fn(
        user_message, "",
        active_session.get("chat", []),
        active_session.get("identified_details", []),
    )
    level_before = active_session.get("difficulty", "Very Simple")
    feedback, level_after = evaluate_scores(raw_evaluation, level_before)

    if level_after == level_before:
        # No advancement: record the turn and keep the current session.
        history = active_session.get("chat", []) + [(user_message, feedback)]
        active_session["chat"] = history
        return "", history, saved_sessions, active_session

    # Advancement: carry the new difficulty into a brand-new image and chat.
    active_session["difficulty"] = level_after
    _, fresh_session, archived_sessions = generate_image_and_reset_chat(
        active_session.get("age", "3"),
        active_session.get("autism_level", "Level 1"),
        active_session.get("topic_focus", ""),
        active_session.get("treatment_plan", ""),
        active_session,
        saved_sessions,
    )
    fresh_session["chat"].append(("System", f"You advanced to {level_after} difficulty! A new image has been generated for you."))
    return "", fresh_session["chat"], archived_sessions, fresh_session
def update_sessions(saved_sessions, active_session):
    """Return finished sessions plus the active one, when it has a prompt."""
    include_active = bool(active_session) and bool(active_session.get("prompt"))
    return saved_sessions + [active_session] if include_active else saved_sessions
##############################################
# Gradio Interface
##############################################
with gr.Blocks() as demo:
    # Per-user active session: one dict describing the in-progress round.
    # Difficulty starts at the easiest level, "Very Simple".
    active_session = gr.State({
        "prompt": None,
        "image": None,
        "chat": [],
        "treatment_plan": "",
        "topic_focus": "",
        "identified_details": [],
        "difficulty": "Very Simple",
        "age": "3",
        "autism_level": "Level 1"
    })
    # Completed sessions, displayed in the sidebar JSON panel.
    saved_sessions = gr.State([])
    with gr.Column():
        gr.Markdown("# Image Generation & Chat Inference")
        # Live Markdown label reflecting the session's current difficulty.
        difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
        # ----- Image Generation Section -----
        with gr.Column():
            gr.Markdown("## Generate Image")
            gr.Markdown("Enter your age, select your autism level, specify a topic focus, and provide the treatment plan to generate an image based on the current difficulty level.")
            with gr.Row():
                age_input = gr.Textbox(label="Age", placeholder="Enter age...", value="3")
                autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
            topic_focus_input = gr.Textbox(
                label="Topic Focus",
                placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
                lines=1
            )
            treatment_plan_input = gr.Textbox(
                label="Treatment Plan",
                placeholder="Enter the treatment plan to guide the image generation...",
                lines=2
            )
            generate_btn = gr.Button("Generate Image")
            img_output = gr.Image(label="Generated Image")
            # Generating an image archives the old session and resets the chat.
            generate_btn.click(
                generate_image_and_reset_chat,
                inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
                outputs=[img_output, active_session, saved_sessions]
            )
        # ----- Chat Section -----
        with gr.Column():
            gr.Markdown("## Chat about the Image")
            gr.Markdown(
                "After generating an image, type details or descriptions about it. "
                "Your message, along with the generated image and conversation history, will be sent for evaluation."
            )
            chatbot = gr.Chatbot(label="Chat History")
            with gr.Row():
                chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
                send_btn = gr.Button("Send")
            # Both the Send button and pressing Enter submit the message.
            send_btn.click(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )
            chat_input.submit(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )
        # ----- Sidebar Section for Session Details -----
        with gr.Column(variant="sidebar"):
            gr.Markdown("## Saved Chat Sessions")
            gr.Markdown(
                "This sidebar automatically saves finished chat sessions. "
                "Each session includes the prompt used, the generated image (as a data URL), "
                "the topic focus, the treatment plan, the list of identified details, and the full chat history."
            )
            sessions_output = gr.JSON(label="Session Details", value={})
    # Keep the sidebar JSON in sync whenever either state changes.
    active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
    # Update the current difficulty label when active_session changes.
    active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
    saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
# Launch the Gradio app (no share/public-link option is passed here).
demo.launch()
|