File size: 16,294 Bytes
686ef17
82d4c23
993333f
 
8850bc5
 
993333f
 
8850bc5
 
993333f
 
 
8850bc5
5d96bcc
8850bc5
 
9b134bd
8850bc5
993333f
1f6ec43
73b4bb6
8850bc5
 
 
 
29561c3
8850bc5
29561c3
 
5d96bcc
8850bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d96bcc
29561c3
 
8850bc5
 
29561c3
8850bc5
 
29561c3
8850bc5
29561c3
8850bc5
 
 
993333f
8850bc5
 
993333f
0b4a56a
 
8850bc5
993333f
9bbeb75
8850bc5
5d96bcc
 
 
993333f
 
 
 
 
0b4a56a
993333f
686ef17
8850bc5
e4f9a72
8850bc5
 
e4f9a72
8f401f6
29561c3
 
8850bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29561c3
e4f9a72
8850bc5
993333f
8850bc5
993333f
1f6ec43
993333f
29561c3
8850bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b4a56a
29561c3
8850bc5
 
 
 
 
 
 
 
 
 
 
29561c3
8850bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
686ef17
44904f8
29561c3
8850bc5
 
 
29561c3
 
 
8850bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29561c3
8850bc5
 
29561c3
8850bc5
29561c3
 
 
8850bc5
29561c3
 
 
 
 
e44ce7e
8850bc5
f42860d
0cac8da
8850bc5
 
 
 
 
 
 
 
 
 
 
 
8f401f6
 
4330f40
b66a090
8850bc5
 
b66a090
 
 
 
8850bc5
b66a090
8850bc5
 
 
 
 
 
 
 
 
 
 
 
5d96bcc
 
29561c3
b66a090
 
 
8850bc5
29561c3
b66a090
 
 
 
 
 
 
8850bc5
b66a090
 
 
 
 
29561c3
 
44904f8
 
29561c3
 
 
44904f8
 
29561c3
b66a090
 
 
8f401f6
b66a090
29561c3
 
8850bc5
b66a090
 
29561c3
8850bc5
 
29561c3
8f401f6
8850bc5
993333f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import gradio as gr
import io
import base64
import os
import json
import re
from PIL import Image
from huggingface_hub import InferenceClient
from google.generativeai import configure, GenerativeModel
from google.ai.generativelanguage import Content, Part

# Load API keys from environment variables.
# NOTE(review): neither key is validated here — if HF_TOKEN or GOOGLE_API_KEY is
# unset, os.environ.get returns None and failures surface later at call time.
inference_api_key = os.environ.get("HF_TOKEN")
google_api_key = os.environ.get("GOOGLE_API_KEY")  # New Google API key

# Configure Google API (module-level side effect; runs once at import).
configure(api_key=google_api_key)

# Global variables to store the image data URL and prompt for the currently generated image.
# Both are written by generate_image_fn; the data URL is read by compare_details_chat_fn.
global_image_data_url = None
global_image_prompt = None  # Still stored if needed elsewhere

def update_difficulty_label(active_session):
    """Render the active session's difficulty as a Markdown label string."""
    difficulty = active_session.get('difficulty', 'Very Simple')
    return f"**Current Difficulty:** {difficulty}"

def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan=""):
    """
    Generate an image prompt using Google's Gemini model.

    Parameters:
        difficulty (str): Difficulty label, e.g. "Very Simple" .. "Very Detailed";
            interpolated verbatim into the instruction text.
        age (str): Child's age (kept as a string by the UI).
        autism_level (str): e.g. "Level 1".
        topic_focus (str): Subject the generated image should center on.
        treatment_plan (str): Optional free text to steer the prompt.

    Returns:
        str: The model's generated image-generation prompt, stripped of
        surrounding whitespace.

    NOTE(review): performs a network call; raises whatever the
    google-generativeai client raises on failure — no local error handling.
    """
    query = (
        f"""
        Follow the instructions below to generate an image generation prompt for an educational image intended for autistic children.
        Consider the following parameters:
          - Difficulty: {difficulty}
          - Age: {age}
          - Autism Level: {autism_level}
          - Topic Focus: {topic_focus}
          - Treatment Plan: {treatment_plan}

        Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.

        The image should specifically focus on the topic: "{topic_focus}".

        Please generate a prompt that instructs the image generation engine to produce an image with:
        1. Clarity and simplicity (minimalist backgrounds, clear subject)
        2. Literal representation with defined borders and consistent style
        3. Soft, muted colors and reduced visual complexity
        4. Positive, calm scenes
        5. Clear focus on the specified topic

        Use descriptive and detailed language.
        """
    )

    # Initialize the Gemini text model (a Flash-Lite variant, not Gemini Pro).
    model = GenerativeModel('gemini-2.0-flash-lite')

    # Generate content using the Gemini model
    response = model.generate_content(query)

    return response.text.strip()

def generate_image_fn(selected_prompt, guidance_scale=7.5,
                      negative_prompt="ugly, blurry, poorly drawn hands, lewd, nude, deformed, missing limbs, missing eyes, missing arms, missing legs",
                      num_inference_steps=50):
    """
    Generate an image from the prompt via the Hugging Face Inference API.

    Side effects: stores the prompt in global_image_prompt and a
    data:image/png;base64 URL of the result in global_image_data_url.

    Returns:
        PIL.Image.Image: the generated image.
    """
    global global_image_data_url, global_image_prompt
    global_image_prompt = selected_prompt

    client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
    generated = client.text_to_image(
        selected_prompt,
        model="stabilityai/stable-diffusion-3.5-large-turbo",
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps
    )

    # Serialize to PNG and keep a data URL so the chat evaluator can re-send it.
    png_buffer = io.BytesIO()
    generated.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    global_image_data_url = f"data:image/png;base64,{encoded}"

    return generated

def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
    """
    Archive the current session (if it has a prompt), generate a new image at
    the session's current difficulty, and return a fresh empty chat session.

    Returns:
        tuple: (PIL image, new active-session dict, updated saved-sessions list).
    """
    archived = list(saved_sessions)
    if active_session.get("prompt"):
        archived.append(active_session)

    # Carry over the current difficulty (it may have been advanced by chat_respond).
    difficulty = active_session.get("difficulty", "Very Simple")
    prompt_text = generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan)
    image = generate_image_fn(prompt_text)

    fresh_session = {
        "prompt": prompt_text,
        "image": global_image_data_url,  # set by generate_image_fn above
        "chat": [],
        "treatment_plan": treatment_plan,
        "topic_focus": topic_focus,
        "identified_details": [],
        "difficulty": difficulty,
        "autism_level": autism_level,
        "age": age,
    }
    return image, fresh_session, archived

def compare_details_chat_fn(user_details, treatment_plan, chat_history, identified_details):
    """
    Evaluate the child's description using Google's Gemini Vision model.

    Parameters:
        user_details (str): The child's description of the current image.
        treatment_plan (str): Accepted for interface compatibility but NOT
            referenced anywhere in the prompt below (callers pass "").
        chat_history (list[tuple[str, str]]): Prior (user, teacher) turns.
        identified_details (list[str]): Details already credited to the child.

    Returns:
        str: Raw model text, expected (but not guaranteed) to be a JSON object
        matching the schema embedded in the prompt; parsed by evaluate_scores.

    Depends on the module-global global_image_data_url set by generate_image_fn.
    """
    if not global_image_data_url:
        return "Please generate an image first."

    # Render prior turns so the model can evaluate incrementally.
    history_text = ""
    if chat_history:
        history_text = "\n\n### Previous Conversation:\n"
        for idx, (user_msg, bot_msg) in enumerate(chat_history, 1):
            history_text += f"Turn {idx}:\nUser: {user_msg}\nTeacher: {bot_msg}\n"

    identified_details_text = ""
    if identified_details:
        identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)

    # Instruction block: persona, rubric, and a strict JSON-only output schema.
    message_text = (
        f"{history_text}{identified_details_text}\n\n"
        f"Based on the image provided above, please evaluate the following description given by the child:\n"
        f"'{user_details}'\n\n"
        "You are a kind and encouraging teacher speaking to a child. Use simple, clear language. "
        "Praise the child's correct observations and provide a gentle hint if something is missing. "
        "Keep your feedback positive and easy to understand.\n\n"
        "Focus on these evaluation criteria:\n"
        "1. **Object Identification** – Did the child mention the main objects?\n"
        "2. **Color & Shape Accuracy** – Were the colors and shapes described correctly?\n"
        "3. **Clarity & Simplicity** – Was the description clear and easy to understand?\n"
        "4. **Overall Communication** – How well did the child communicate their thoughts?\n\n"
        "Note: As difficulty increases, the expected level of detail is higher. Evaluate accordingly.\n\n"
        "Return your evaluation strictly as a JSON object with the following keys:\n"
        "{\n"
        "  \"scores\": {\n"
        "    \"object_identification\": <number>,\n"
        "    \"color_shape_accuracy\": <number>,\n"
        "    \"clarity_simplicity\": <number>,\n"
        "    \"overall_communication\": <number>\n"
        "  },\n"
        "  \"final_score\": <number>,\n"
        "  \"feedback\": \"<string>\",\n"
        "  \"hint\": \"<string>\",\n"
        "  \"advance\": <boolean>\n"
        "}\n\n"
        "Do not include any additional text outside the JSON."
    )

    # Remove the data:image/png;base64, prefix to get just the base64 string
    base64_img = global_image_data_url.split(",")[1]

    # Create a Gemini Vision Pro model
    vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')

    # Create the content with image and text using the correct parameters
    # Use 'inline_data' instead of 'content' for the image part
    image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
    text_part = Part(text=message_text)
    multimodal_content = Content(parts=[image_part, text_part])

    # Generate evaluation using the vision model
    response = vision_model.generate_content(multimodal_content)

    return response.text

def evaluate_scores(evaluation_text, current_difficulty):
    """
    Parse the model's JSON evaluation and decide whether the child advances.

    Parameters:
        evaluation_text (str): Raw model output expected to contain a JSON
            object with "final_score", "hint" and "advance" keys.
        current_difficulty (str): One of "Very Simple", "Simple", "Moderate",
            "Detailed", "Very Detailed".

    Returns:
        tuple[str, str]: (feedback message, resulting difficulty). The passing
        threshold scales with difficulty: Very Simple 70, Simple 75,
        Moderate 80, Detailed 85, Very Detailed 90. On any parsing error the
        message describes the error and the difficulty is left unchanged.
    """
    difficulty_thresholds = {
        "Very Simple": 70,
        "Simple": 75,
        "Moderate": 80,
        "Detailed": 85,
        "Very Detailed": 90
    }
    # Next level for each difficulty; the top level maps to itself.
    difficulty_mapping = {
        "Very Simple": "Simple",
        "Simple": "Moderate",
        "Moderate": "Detailed",
        "Detailed": "Very Detailed",
        "Very Detailed": "Very Detailed"
    }
    try:
        # Greedy match grabs the outermost {...} span, tolerating prose or
        # code fences the model may wrap around the JSON object.
        json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
        if not json_match:
            raise ValueError("No JSON object found in the response.")
        evaluation = json.loads(json_match.group(0))

        final_score = evaluation.get("final_score", 0)
        hint = evaluation.get("hint", "Keep trying!")
        advance = evaluation.get("advance", False)
        current_threshold = difficulty_thresholds.get(current_difficulty, 70)

        if final_score >= current_threshold or advance:
            new_difficulty = difficulty_mapping.get(current_difficulty, current_difficulty)
            response_msg = (f"Great job! Your final score is {final_score}, which meets the target of {current_threshold}. "
                            f"You've advanced to {new_difficulty} difficulty.")
            return response_msg, new_difficulty

        # Fixed: the previous message contained stray space runs and embedded
        # newlines ("is   {score} (\n target: ...") — now cleanly formatted.
        response_msg = (f"Your final score is {final_score} (target: {current_threshold}). {hint} "
                        f"Please try again at the {current_difficulty} level.")
        return response_msg, current_difficulty
    except Exception as e:
        # Boundary handler: surface malformed model output as feedback text
        # rather than crashing the chat loop.
        return f"Error processing evaluation output: {str(e)}", current_difficulty

def chat_respond(user_message, active_session, saved_sessions):
    """
    Handle one chat turn: evaluate the child's description of the current
    image and, when the evaluation grants advancement, bump the difficulty
    and start a fresh session with a newly generated image.

    Returns:
        tuple: ("" to clear the input box, chat history, saved sessions,
        active session).
    """
    # Guard: there is nothing to describe until an image has been generated.
    if not active_session.get("image"):
        history = active_session.get("chat", []) + [(user_message, "Please generate an image first.")]
        active_session["chat"] = history
        return "", history, saved_sessions, active_session

    raw_evaluation = compare_details_chat_fn(
        user_message,
        "",
        active_session.get("chat", []),
        active_session.get("identified_details", []),
    )
    difficulty_before = active_session.get("difficulty", "Very Simple")
    feedback, difficulty_after = evaluate_scores(raw_evaluation, difficulty_before)

    if difficulty_after == difficulty_before:
        # No advancement: append this turn with the evaluator's feedback.
        history = active_session.get("chat", []) + [(user_message, feedback)]
        active_session["chat"] = history
    else:
        # Advanced: record the new difficulty, then replace the session with a
        # freshly generated one (the old session gets archived inside
        # generate_image_and_reset_chat).
        active_session["difficulty"] = difficulty_after
        _, fresh_session, archived = generate_image_and_reset_chat(
            active_session.get("age", "3"),
            active_session.get("autism_level", "Level 1"),
            active_session.get("topic_focus", ""),
            active_session.get("treatment_plan", ""),
            active_session,
            saved_sessions,
        )
        fresh_session["chat"].append(("System", f"You advanced to {difficulty_after} difficulty! A new image has been generated for you."))
        active_session = fresh_session
        saved_sessions = archived

    return "", active_session["chat"], saved_sessions, active_session

def update_sessions(saved_sessions, active_session):
    """Build the sidebar view: archived sessions plus the active one,
    including the active session only when it actually holds a prompt."""
    if not (active_session and active_session.get("prompt")):
        return saved_sessions
    return saved_sessions + [active_session]

##############################################
# Gradio Interface
##############################################
# Layout: difficulty label + image-generation controls + chat, with a sidebar
# that mirrors saved/active sessions as JSON.
with gr.Blocks() as demo:
    # The active session now starts with difficulty "Very Simple"
    active_session = gr.State({
        "prompt": None,
        "image": None,
        "chat": [],
        "treatment_plan": "",
        "topic_focus": "",
        "identified_details": [],
        "difficulty": "Very Simple",
        "age": "3",
        "autism_level": "Level 1"
    })
    saved_sessions = gr.State([])

    with gr.Column():
        gr.Markdown("# Image Generation & Chat Inference")
        # Display current difficulty label
        difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")

        # ----- Image Generation Section -----
        with gr.Column():
            gr.Markdown("## Generate Image")
            gr.Markdown("Enter your age, select your autism level, specify a topic focus, and provide the treatment plan to generate an image based on the current difficulty level.")
            with gr.Row():
                age_input = gr.Textbox(label="Age", placeholder="Enter age...", value="3")
                autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")

            topic_focus_input = gr.Textbox(
                label="Topic Focus",
                placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
                lines=1
            )

            treatment_plan_input = gr.Textbox(
                label="Treatment Plan",
                placeholder="Enter the treatment plan to guide the image generation...",
                lines=2
            )
            generate_btn = gr.Button("Generate Image")
            img_output = gr.Image(label="Generated Image")
            generate_btn.click(
                generate_image_and_reset_chat,
                inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
                outputs=[img_output, active_session, saved_sessions]
            )

        # ----- Chat Section -----
        with gr.Column():
            gr.Markdown("## Chat about the Image")
            gr.Markdown(
                "After generating an image, type details or descriptions about it. "
                "Your message, along with the generated image and conversation history, will be sent for evaluation."
            )
            chatbot = gr.Chatbot(label="Chat History")
            with gr.Row():
                chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
                send_btn = gr.Button("Send")
            # Button click and textbox Enter both route through chat_respond.
            # NOTE(review): chat_respond does not output to img_output, so after
            # an in-chat advancement the displayed image is not refreshed here.
            send_btn.click(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )
            chat_input.submit(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )

    # ----- Sidebar Section for Session Details -----
    with gr.Column(variant="sidebar"):
        gr.Markdown("## Saved Chat Sessions")
        gr.Markdown(
            "This sidebar automatically saves finished chat sessions. "
            "Each session includes the prompt used, the generated image (as a data URL), "
            "the topic focus, the treatment plan, the list of identified details, and the full chat history."
        )
        sessions_output = gr.JSON(label="Session Details", value={})
        active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
        # Update the current difficulty label when active_session changes.
        active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
        saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)

# Launch the app. (No share=True is passed, so public sharing is NOT enabled
# despite what an earlier comment claimed.)
demo.launch()