import gradio as gr
import json
import os
from PIL import Image
from database_operations import Neo4jDatabase
from graph_visualization import visualize_graph
from utils import extract_label_prefix, strip_keys, format_json, validate_json
from models.gemini_image_to_json import fetch_gemini_response
from models.openai_image_to_json import openaiprocess_image_to_json
from any_to_image import pdf_to_images, process_image

# Initialize Neo4j database. Connection settings may be overridden through the
# environment so credentials need not live in source; the defaults are the
# previously hard-coded local development values.
db = Neo4jDatabase(
    os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "password123"),
)

# Maximum number of undo entries kept per label prefix.
_MAX_UNDO_STATES = 3


def _append_history(history, tag, message):
    """Return *history* with a new ``[tag] message`` line appended."""
    entry = f"[{tag}] {message}"
    return f"{history}\n{entry}" if history else entry


def _label_prefix_from_history_line(line):
    """Extract the quoted label prefix from a history line, or return None.

    Every history message written by this module contains
    ``label prefix '<prefix>'`` (optionally followed by a full stop). The LAST
    quote is used as the closing delimiter so prefixes containing apostrophes
    are not truncated.
    """
    _, marker, tail = line.partition("label prefix '")
    if not marker:
        return None
    prefix, closing, _ = tail.rpartition("'")
    return prefix if closing else None


def _parse_graph_json(json_content):
    """Parse editor text into a graph dict ready to be written to the database.

    The parsed value goes through ``strip_keys`` and ``validate_json`` (the
    same steps applied before visualising a graph) and must be a dict holding
    ``nodes`` and ``edges``.

    Raises:
        json.JSONDecodeError: if *json_content* is not valid JSON.
        ValueError: if the graph structure is rejected.
    """
    json_data = strip_keys(json.loads(json_content))
    validate_json(json_data)
    if not isinstance(json_data, dict) or "nodes" not in json_data or "edges" not in json_data:
        raise ValueError("graph JSON must contain 'nodes' and 'edges'")
    return json_data


def dump_to_neo4j_with_confirmation(json_content, file_path, history, previous_states):
    """Write the edited graph to Neo4j, or ask for confirmation if one exists.

    Args:
        json_content: Graph JSON text from the editor.
        file_path: Path of the current input image; ``extract_label_prefix``
            turns it into the label prefix that names the graph.
        history: Newline-separated log of database actions.
        previous_states: Per-session undo stacks keyed by label prefix. A
            ``None`` entry marks a graph created in this session (undoing it
            deletes the graph); other entries are ``db.get_graph_data``
            snapshots taken before an overwrite.

    Returns:
        ``(message, history, previous_states, pending_label_prefix)`` where
        ``pending_label_prefix`` is the prefix awaiting overwrite confirmation,
        or ``None`` when no confirmation is needed.
    """
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, None
    try:
        json_data = _parse_graph_json(json_content)
    except json.JSONDecodeError:
        return "Invalid JSON data. Please check your input.", history, previous_states, None
    except ValueError as e:
        return f"Invalid graph structure: {e}", history, previous_states, None

    label_prefix = extract_label_prefix(file_path)
    if db.check_existing_graph(label_prefix):
        # Nothing is written yet: the UI shows a prompt and confirm_overwrite()
        # performs the overwrite once the user agrees.
        return (
            f"A graph with label prefix '{label_prefix}' already exists in the database. Do you want to overwrite it?",
            history,
            previous_states,
            label_prefix,
        )

    db.dump_to_neo4j(json_data["nodes"], json_data["edges"], label_prefix)
    result = f"Data successfully dumped into the database with label prefix '{label_prefix}'."
    # The None marker tells revert_last_action this graph was created here and
    # may be deleted when undone.
    previous_states[label_prefix] = [None]
    return result, _append_history(history, "NEW ENTRY", result), previous_states, None


def confirm_overwrite(confirmation, gradio_state, json_content, file_path, history, previous_states):
    """Overwrite an existing graph once the user has typed 'yes'.

    The editor JSON is parsed and validated BEFORE anything is deleted, so
    invalid input can no longer wipe the existing graph. The current graph is
    snapshotted onto the prefix's undo stack (only the newest
    ``_MAX_UNDO_STATES`` entries are kept) and then replaced.

    Args:
        confirmation: Text typed by the user; only 'yes' (case-insensitive,
            surrounding whitespace ignored) proceeds.
        gradio_state: Unused; kept for interface compatibility.
        json_content, file_path, history, previous_states: As for
            ``dump_to_neo4j_with_confirmation``.

    Returns:
        ``(message, history, previous_states, "")``; the trailing empty
        string clears the confirmation textbox.
    """
    if (confirmation or "").strip().lower() != "yes":
        return "Operation cancelled. The existing graph was not overwritten.", history, previous_states, ""
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, ""
    try:
        json_data = _parse_graph_json(json_content)
    except json.JSONDecodeError:
        return "Invalid JSON data. Please check your input.", history, previous_states, ""
    except ValueError as e:
        return f"Invalid graph structure: {e}", history, previous_states, ""

    label_prefix = extract_label_prefix(file_path)
    undo_stack = previous_states.setdefault(label_prefix, [])
    undo_stack.append(db.get_graph_data(label_prefix))
    # Bound per-session memory: drop all but the newest snapshots.
    del undo_stack[:-_MAX_UNDO_STATES]

    db.delete_graph(label_prefix)
    db.dump_to_neo4j(json_data["nodes"], json_data["edges"], label_prefix)
    result = f"Data successfully overwritten in the database with label prefix '{label_prefix}'."
    return result, _append_history(history, "OVERWRITE", result), previous_states, ""


def revert_last_action(history, previous_states):
    """Undo the most recent database action recorded in *history*.

    The label prefix is read from the last history line and the newest entry
    of that prefix's undo stack is popped. A snapshot is restored in place of
    the current graph; the ``None`` marker left by a new dump means the graph
    is deleted instead. When the stack is missing or empty nothing is touched,
    so a graph that existed before this session is never deleted.

    Returns:
        ``(message, history, previous_states)``.
    """
    if not history:
        return "No actions to revert.", history, previous_states

    last_action = history.splitlines()[-1]
    label_prefix = _label_prefix_from_history_line(last_action)
    undo_stack = previous_states.get(label_prefix) if label_prefix is not None else None
    if not undo_stack:
        return "Unable to revert the last action.", history, previous_states

    previous_state = undo_stack.pop()
    db.delete_graph(label_prefix)
    if previous_state is None:
        del previous_states[label_prefix]
        entry = f"Deleted newly added graph with label prefix '{label_prefix}'"
    else:
        db.dump_to_neo4j(previous_state['nodes'], previous_state['edges'], label_prefix)
        entry = f"Reverted overwrite of graph with label prefix '{label_prefix}'"
    return f"Reverted last action: {last_action}", _append_history(history, "REVERT", entry), previous_states


def update_graph_from_edited_json(json_content, physics_enabled):
    """Render the editor JSON as a graph.

    Returns:
        ``(html, error_message)``: the ``visualize_graph`` output and an empty
        string on success, or ``None`` and a description of the failure.
    """
    try:
        json_data = json.loads(json_content)
        json_data = strip_keys(json_data)
        validate_json(json_data)
        return visualize_graph(json_data, physics_enabled), ""
    except json.JSONDecodeError as e:
        return None, f"Invalid JSON format: {str(e)}"
    except ValueError as e:
        return None, f"Invalid graph structure: {str(e)}"
    except Exception as e:
        # UI boundary: surface anything unexpected instead of crashing the event.
        return None, f"An unexpected error occurred: {str(e)}"


def fetch_kg(image_file_path, model_choice_state):
    """Ask the selected model to turn the image into knowledge-graph JSON.

    Returns:
        ``(json_text, error_message)``. On success ``json_text`` is the
        ``format_json``-formatted model output. If the model reply is not
        valid JSON, the raw reply (when it is a string) is returned so the
        user can fix it in the editor, along with an error message.
    """
    if not image_file_path:
        return "", "No image uploaded or invalid file"

    model_backends = {
        'Gemini': fetch_gemini_response,
        'OpenAI': openaiprocess_image_to_json,
    }
    model_fn = model_backends.get(model_choice_state)
    if model_fn is None:
        return "", f"Unknown model choice: {model_choice_state!r}"

    print(f'model choice is {model_choice_state.lower()}')
    # Close the file handle once the model call has consumed the image.
    with Image.open(image_file_path) as mind_map_image:
        kg_json_text = model_fn(mind_map_image)

    try:
        json_data = json.loads(kg_json_text)
    except (json.JSONDecodeError, TypeError) as e:
        raw = kg_json_text if isinstance(kg_json_text, str) else ""
        return raw, f"Model response is not valid JSON: {e}"
    return format_json(json_data), ""


def input_file_handler(file_path):
    """Convert an uploaded file into an image path via ``process_image``.

    Returns:
        ``(image_path, error)`` from ``process_image``, or ``""`` and an
        error message when no file was given.
    """
    if file_path:
        image_path, error = process_image(file_path)
        return image_path, error
    return "", "No image uploaded or invalid file"


def _rotate_image(image_path, degrees):
    """Rotate the image file at *image_path* in place by *degrees*.

    Positive angles rotate counter-clockwise (PIL convention); ``expand=True``
    enlarges the canvas so nothing is cropped. Returns *image_path* so the
    Gradio image component reloads the rewritten file; a falsy path is passed
    through untouched.
    """
    if image_path:
        with Image.open(image_path) as image:
            rotated = image.rotate(degrees, expand=True)
        rotated.save(image_path)
    return image_path


def rotate_image_clockwise(image_path):
    """Rotate the image file 90 degrees clockwise, in place."""
    return _rotate_image(image_path, -90)


def rotate_image_counter_clockwise(image_path):
    """Rotate the image file 90 degrees counter-clockwise, in place."""
    return _rotate_image(image_path, 90)


# Backward-compatible aliases. The old names had left/right swapped:
# rotate_image_to_left rotated clockwise and rotate_image_to_right
# counter-clockwise.
rotate_image_to_left = rotate_image_clockwise
rotate_image_to_right = rotate_image_counter_clockwise


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Image to Knowledge Graph Transformation")
    with gr.Row():
        file_input = gr.File(
            label="Upload File",
            file_count="single",
            type="filepath",
            file_types=[".pdf", ".png", ".jpeg", ".jpg", ".heic"],
        )
        image_file = gr.Image(label="Input Image", type="filepath", visible=False)
        json_editor = gr.Textbox(
            label="Edit JSON",
            lines=15,
            placeholder="JSON data will appear here after image upload",
        )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                CCW_rotate_button = gr.Button('Rotate Image Counter-Clockwise')
                CW_rotate_button = gr.Button('Rotate Image Clockwise')
        with gr.Column():
            model_call = gr.Button('Transform Image into KG representation', scale=2)
    with gr.Row():
        physics_button = gr.Checkbox(value=True, label="Enable Graph Physics")
        model_choice = gr.Radio(
            label="Select Model",
            choices=["OpenAI", "Gemini"],
            value="Gemini",
            interactive=True,
        )
    graph_output = gr.HTML(label="Graph Output")
    error_output = gr.Textbox(label="Error Messages", interactive=False)
    update_button = gr.Button("Update Graph")
    dump_button = gr.Button("Dump to Neo4j")
    revert_button = gr.Button("Revert Last Action")
    history_block = gr.Textbox(
        label="History",
        placeholder="Graphs pushed to the Database",
        interactive=False,
        lines=5,
        max_lines=50,
    )
    history_state = gr.State("")
    previous_states = gr.State({})
    # Label prefix awaiting overwrite confirmation; None when nothing is pending.
    # A single shared component is needed: separate inline gr.State() objects
    # are distinct components and never see each other's values.
    pending_overwrite = gr.State(None)
    confirmation_output = gr.Textbox(label="Confirmation Message", visible=False, interactive=False)
    confirmation_input = gr.Textbox(label="Type 'yes' to confirm overwrite", visible=False, interactive=True)
    confirm_button = gr.Button("Confirm Overwrite", visible=False)

    #-----------------------------------------
    # Added 2 examples for this deployment only
    # examples_list = ["image_examples/image1.png", "image_examples/image2.png"]
    # # same full chain of events as the file.upload() below
    # def process_input(file):
    #     # First, call input_file_handler
    #     processed_file, error = input_file_handler(file)
    #     # Then, update image visibility
    #     visible_image, hidden_file = update_image_visibility(processed_file)
    #     return processed_file, error, visible_image, hidden_file
    # example_component = gr.Examples(examples_list, inputs=file_input, fn=process_input, outputs=[image_file, error_output, image_file, file_input])
    #-------------------------------------------

    # After upload, swap the file picker for an image preview.
    file_input.upload(
        fn=input_file_handler,
        inputs=[file_input],
        outputs=[image_file, error_output],
    ).then(
        lambda image_path: (
            gr.Image(value=image_path, visible=True),
            gr.File(visible=False),
        ),
        inputs=[image_file],
        outputs=[image_file, file_input],
    )

    # Clearing the preview brings the file picker back.
    image_file.clear(
        lambda: (
            gr.File(visible=True),
            gr.Image(visible=False),
        ),
        outputs=[file_input, image_file],
    )

    CW_rotate_button.click(
        fn=rotate_image_clockwise,
        inputs=[image_file],
        outputs=[image_file],
    )
    CCW_rotate_button.click(
        fn=rotate_image_counter_clockwise,
        inputs=[image_file],
        outputs=[image_file],
    )

    # Show the overwrite prompt only when the dump is waiting for confirmation.
    dump_button.click(
        dump_to_neo4j_with_confirmation,
        inputs=[json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, pending_overwrite],
    ).then(
        lambda message, label_prefix: (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(visible=label_prefix is not None),
            gr.Button(visible=label_prefix is not None),
        ),
        inputs=[confirmation_output, pending_overwrite],
        outputs=[confirmation_output, confirmation_input, confirm_button],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    gr.on(
        triggers=[confirm_button.click, confirmation_input.submit],
        fn=confirm_overwrite,
        inputs=[confirmation_input, pending_overwrite, json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, confirmation_input],
    ).then(
        lambda message: (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(value='', visible=False),
            gr.Button(visible=False),
            None,  # the pending overwrite has been answered
        ),
        inputs=[confirmation_output],
        outputs=[confirmation_output, confirmation_input, confirm_button, pending_overwrite],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    revert_button.click(
        revert_last_action,
        inputs=[history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states],
    ).then(
        lambda message: gr.Textbox(value=message, visible=True),
        inputs=[confirmation_output],
        outputs=[confirmation_output],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    update_button.click(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output],
    )
    physics_button.change(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output],
    )

    model_call.click(
        fn=fetch_kg,
        inputs=[image_file, model_choice],
        outputs=[json_editor, error_output],
    )


if __name__ == "__main__":
    demo.launch()