# app.py (author: Zaherrr)
# Last commit: "Update app.py" (cc1ecab, verified)
import gradio as gr
import json
import os
from PIL import Image
from database_operations import Neo4jDatabase
from graph_visualization import visualize_graph
from utils import extract_label_prefix, strip_keys, format_json, validate_json
from models.gemini_image_to_json import fetch_gemini_response
from models.openai_image_to_json import openaiprocess_image_to_json
from any_to_image import pdf_to_images, process_image
# Initialize Neo4j database.
# Connection settings may be overridden with the NEO4J_URI, NEO4J_USER and
# NEO4J_PASSWORD environment variables so credentials need not live in source;
# the defaults are the previous hard-coded local-development values.
db = Neo4jDatabase(
    os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "password123"),
)
def dump_to_neo4j_with_confirmation(json_content, file_path, history, previous_states):
    """Write the edited graph JSON to Neo4j, or request confirmation to overwrite.

    The label prefix is derived from ``file_path`` via ``extract_label_prefix``.
    If ``db.check_existing_graph`` reports a graph with that prefix, nothing is
    written; a confirmation prompt and the prefix are returned instead so the UI
    can ask the user to confirm (the actual overwrite is done by
    ``confirm_overwrite``).

    Args:
        json_content: JSON text expected to contain top-level ``nodes`` and
            ``edges`` keys once passed through ``strip_keys``.
        file_path: Path of the uploaded image; used to derive the label prefix.
        history: Newline-separated action log ("" when empty).
        previous_states: Dict mapping label prefix -> list of prior graph
            snapshots, consumed by ``revert_last_action``. Mutated in place.

    Returns:
        ``(message, history, previous_states, pending_label_prefix)`` where the
        last element is the prefix awaiting overwrite confirmation, or ``None``
        when no confirmation is needed.
    """
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, None
    try:
        json_data = json.loads(json_content)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a missing (None) editor value.
        return "Invalid JSON data. Please check your input.", history, previous_states, None

    label_prefix = extract_label_prefix(file_path)
    if db.check_existing_graph(label_prefix):
        # Defer the write until the user confirms; confirm_overwrite snapshots
        # the existing graph before replacing it.
        message = (
            f"A graph with label prefix '{label_prefix}' already exists in the database. "
            "Do you want to overwrite it?"
        )
        return message, history, previous_states, label_prefix

    json_data = strip_keys(json_data)
    try:
        nodes, edges = json_data["nodes"], json_data["edges"]
    except (KeyError, TypeError):
        return (
            "Invalid graph JSON: expected top-level 'nodes' and 'edges' keys.",
            history,
            previous_states,
            None,
        )

    db.dump_to_neo4j(nodes, edges, label_prefix)
    result = f"Data successfully dumped into the database with label prefix '{label_prefix}'."
    new_history = f"{history}\n[NEW ENTRY] {result}" if history else f"[NEW ENTRY] {result}"
    # An empty snapshot list marks this prefix as newly added, so
    # revert_last_action deletes the graph rather than restoring a snapshot.
    previous_states[label_prefix] = []
    return result, new_history, previous_states, None
def confirm_overwrite(confirmation, gradio_state, json_content, file_path, history, previous_states):
    """Overwrite an existing graph after the user typed 'yes'.

    The current graph for the label prefix (from ``db.get_graph_data``) is pushed
    onto ``previous_states[label_prefix]`` (at most the 3 most recent snapshots
    are kept) so ``revert_last_action`` can restore it, then the graph is
    deleted and replaced with the nodes/edges from ``json_content``.

    Args:
        confirmation: Text typed by the user; only 'yes' (case-insensitive,
            surrounding whitespace ignored) proceeds.
        gradio_state: Unused; kept for interface compatibility with the UI wiring.
        json_content: JSON text with ``nodes`` and ``edges`` (after ``strip_keys``).
        file_path: Path of the uploaded image; used to derive the label prefix.
        history: Newline-separated action log ("" when empty).
        previous_states: Dict of label prefix -> snapshot list. Mutated in place.

    Returns:
        ``(message, history, previous_states, "")``; the trailing "" clears the
        confirmation textbox.
    """
    if (confirmation or "").strip().lower() != "yes":
        return "Operation cancelled. The existing graph was not overwritten.", history, previous_states, ""
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, ""

    # Validate the new payload BEFORE touching the database. Previously the
    # existing graph was snapshotted and deleted first, so malformed JSON
    # destroyed it and left a bogus snapshot behind.
    try:
        json_data = json.loads(json_content)
    except (json.JSONDecodeError, TypeError):
        return "Invalid JSON data. Please check your input.", history, previous_states, ""
    json_data = strip_keys(json_data)
    try:
        nodes, edges = json_data["nodes"], json_data["edges"]
    except (KeyError, TypeError):
        return (
            "Invalid graph JSON: expected top-level 'nodes' and 'edges' keys.",
            history,
            previous_states,
            "",
        )

    label_prefix = extract_label_prefix(file_path)
    snapshots = previous_states.setdefault(label_prefix, [])
    snapshots.append(db.get_graph_data(label_prefix))
    # Bound memory: keep only the three most recent snapshots per prefix.
    previous_states[label_prefix] = snapshots[-3:]

    db.delete_graph(label_prefix)
    db.dump_to_neo4j(nodes, edges, label_prefix)
    result = f"Data successfully overwritten in the database with label prefix '{label_prefix}'."
    new_history = f"{history}\n[OVERWRITE] {result}" if history else f"[OVERWRITE] {result}"
    return result, new_history, previous_states, ""
def revert_last_action(history, previous_states):
    """Undo the most recent database write for the prefix named in the history.

    The label prefix is taken from the first single-quoted substring of the last
    history line. If ``previous_states[prefix]`` holds snapshots, the graph is
    replaced with the most recent one (which is then popped). If it holds an
    empty list, the graph was newly added, so it is deleted and the prefix is
    removed from ``previous_states``.

    Args:
        history: Newline-separated action log ("" or None when empty).
        previous_states: Dict of label prefix -> snapshot list, where each
            snapshot is indexed by ``'nodes'`` and ``'edges'``. Mutated in place.

    Returns:
        ``(message, history, previous_states)``.
    """
    if not history:
        return "No actions to revert.", history, previous_states

    # splitlines() tolerates a trailing newline, unlike split('\n').
    last_action = history.splitlines()[-1]
    parts = last_action.split("'")
    if len(parts) < 3:
        # No quoted label prefix on the last line: nothing we can safely revert.
        return "Unable to revert the last action.", history, previous_states
    label_prefix = parts[1]

    snapshots = previous_states.get(label_prefix)
    if snapshots is None:
        return "Unable to revert the last action.", history, previous_states

    db.delete_graph(label_prefix)
    if snapshots:
        latest = snapshots[-1]
        db.dump_to_neo4j(latest["nodes"], latest["edges"], label_prefix)
        # Pop only after the restore succeeded so a failed dump keeps the snapshot.
        snapshots.pop()
        new_history = history + f"\n[REVERT] Reverted overwrite of graph with label prefix '{label_prefix}'"
    else:
        del previous_states[label_prefix]
        new_history = history + f"\n[REVERT] Deleted newly added graph with label prefix '{label_prefix}'"
    return f"Reverted last action: {last_action}", new_history, previous_states
def update_graph_from_edited_json(json_content, physics_enabled):
try:
json_data = json.loads(json_content)
json_data = strip_keys(json_data)
validate_json(json_data)
return visualize_graph(json_data, physics_enabled), ""
except json.JSONDecodeError as e:
return None, f"Invalid JSON format: {str(e)}"
except ValueError as e:
return None, f"Invalid graph structure: {str(e)}"
except Exception as e:
return None, f"An unexpected error occurred: {str(e)}"
def fetch_kg(image_file_path, model_choice_state):
    """Ask the selected model to turn the mind-map image into graph JSON.

    Args:
        image_file_path: Path to the image to send; falsy means nothing uploaded.
        model_choice_state: Either 'Gemini' or 'OpenAI'.

    Returns:
        ``(formatted_json_text, "")`` on success, or ``("", error_message)``
        when no image is given, the model choice is unknown, or the model's
        reply is not valid JSON.
    """
    if not image_file_path:
        return "", "No image uploaded or invalid file"
    # Previously an unknown choice left kg_json_text unbound (UnboundLocalError).
    if model_choice_state not in ("Gemini", "OpenAI"):
        return "", f"Unsupported model choice: {model_choice_state}"

    # Close the underlying file handle once the model call has returned.
    with Image.open(image_file_path) as mind_map_image:
        if model_choice_state == "Gemini":
            print("model choice is gemini")
            kg_json_text = fetch_gemini_response(mind_map_image)
        else:
            print("model choice is openai")
            kg_json_text = openaiprocess_image_to_json(mind_map_image)

    try:
        json_data = json.loads(kg_json_text)
    except (json.JSONDecodeError, TypeError) as err:
        # Surface a bad model reply in the UI's error box instead of crashing.
        return "", f"Model returned invalid JSON: {err}"
    return format_json(json_data), ""
def input_file_handler(file_path):
    """Turn an uploaded file into an image path for the image viewer.

    Delegates to ``process_image`` (from ``any_to_image``), which is expected
    to return an ``(image_path, error)`` pair.

    Returns:
        ``(image_path, error)`` from ``process_image``, or
        ``("", "No image uploaded or invalid file")`` when ``file_path`` is falsy.
    """
    if not file_path:
        return "", "No image uploaded or invalid file"
    processed_path, error_message = process_image(file_path)
    return processed_path, error_message
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Image to Knowledge Graph Transformation")
    with gr.Row():
        file_input = gr.File(
            label="Upload File",
            file_count="single",
            type="filepath",
            file_types=[".pdf", ".png", ".jpeg", ".jpg", ".heic"],
        )
        image_file = gr.Image(label="Input Image", type="filepath", visible=False)
        json_editor = gr.Textbox(label="Edit JSON", lines=15, placeholder="JSON data will appear here after image upload")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                CCW_rotate_button = gr.Button('Rotate Image Counter-Clockwise')
                CW_rotate_button = gr.Button('Rotate Image Clockwise')
        with gr.Column():
            model_call = gr.Button('Transform Image into KG representation', scale=2)
            with gr.Row():
                physics_button = gr.Checkbox(value=True, label="Enable Graph Physics")
                model_choice = gr.Radio(label="Select Model", choices=["OpenAI", "Gemini"], value="Gemini", interactive=True)
    graph_output = gr.HTML(label="Graph Output")
    error_output = gr.Textbox(label="Error Messages", interactive=False)
    update_button = gr.Button("Update Graph")
    dump_button = gr.Button("Dump to Neo4j")
    revert_button = gr.Button("Revert Last Action")
    history_block = gr.Textbox(label="History", placeholder="Graphs pushed to the Database", interactive=False, lines=5, max_lines=50)
    history_state = gr.State("")
    previous_states = gr.State({})
    # Label prefix awaiting overwrite confirmation (None when nothing is pending).
    # Replaces the throw-away gr.State() objects previously created inline in
    # the event wiring, which never connected the dump and confirm steps.
    pending_overwrite = gr.State(None)
    confirmation_output = gr.Textbox(label="Confirmation Message", visible=False, interactive=False)
    confirmation_input = gr.Textbox(label="Type 'yes' to confirm overwrite", visible=False, interactive=True)
    confirm_button = gr.Button("Confirm Overwrite", visible=False)
    #-----------------------------------------
    # Added 2 examples for this deployment only
    # NOTE(review): this disabled snippet calls update_image_visibility, which
    # is not defined in this file.
    # examples_list = ["image_examples/image1.png", "image_examples/image2.png"]
    # # same full chain of events as the file.upload() below
    # def process_input(file):
    #     # First, call input_file_handler
    #     processed_file, error = input_file_handler(file)
    #     # Then, update image visibility
    #     visible_image, hidden_file = update_image_visibility(processed_file)
    #     return processed_file, error, visible_image, hidden_file
    # example_component = gr.Examples(examples_list, inputs=file_input, fn=process_input, outputs=[image_file, error_output, image_file, file_input])
    #-------------------------------------------

    # Upload: convert the file to an image, then swap the uploader for the viewer.
    file_input.upload(
        fn=input_file_handler,
        inputs=[file_input],
        outputs=[image_file, error_output],
    ).then(
        lambda path: (
            gr.Image(value=path, visible=True),
            gr.File(visible=False),
        ),
        inputs=[image_file],
        outputs=[image_file, file_input],
    )

    # Clearing the image brings the uploader back.
    image_file.clear(
        lambda file_value, image_value: (
            gr.File(visible=True),
            gr.Image(visible=False),
        ),
        inputs=[file_input, image_file],
        outputs=[file_input, image_file],
    )

    def rotate_image_to_left(image_path):
        """Rotate the image file at ``image_path`` 90 degrees clockwise, in place.

        Despite the name, PIL's ``rotate`` takes counter-clockwise degrees, so
        ``rotate(-90)`` turns the image clockwise; this handler is wired to the
        "Rotate Image Clockwise" button. Returns ``image_path`` unchanged.
        """
        if image_path:
            # Close the source handle before overwriting the same file.
            with Image.open(image_path) as image:
                rotated = image.rotate(-90, expand=True)
            rotated.save(image_path)
        return image_path

    CW_rotate_button.click(
        fn=rotate_image_to_left,
        inputs=[image_file],
        outputs=[image_file],
    )

    def rotate_image_to_right(image_path):
        """Rotate the image file at ``image_path`` 90 degrees counter-clockwise, in place.

        Despite the name, ``rotate(90)`` turns the image counter-clockwise; this
        handler is wired to the "Rotate Image Counter-Clockwise" button.
        Returns ``image_path`` unchanged.
        """
        if image_path:
            # Close the source handle before overwriting the same file.
            with Image.open(image_path) as image:
                rotated = image.rotate(90, expand=True)
            rotated.save(image_path)
        return image_path

    CCW_rotate_button.click(
        fn=rotate_image_to_right,
        inputs=[image_file],
        outputs=[image_file],
    )

    # Dump: write directly, or stage a pending overwrite. The confirm controls
    # are shown only when an overwrite is actually pending; previously they
    # appeared after every dump, even a fresh insert.
    dump_button.click(
        dump_to_neo4j_with_confirmation,
        inputs=[json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, pending_overwrite],
    ).then(
        lambda message, pending: (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(visible=pending is not None),
            gr.Button(visible=pending is not None),
        ),
        inputs=[confirmation_output, pending_overwrite],
        outputs=[confirmation_output, confirmation_input, confirm_button],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    # Confirm (button or Enter): overwrite, then hide the confirm controls and
    # clear the pending prefix.
    gr.on(
        triggers=[confirm_button.click, confirmation_input.submit],
        fn=confirm_overwrite,
        inputs=[confirmation_input, pending_overwrite, json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, confirmation_input],
    ).then(
        lambda message: (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(value='', visible=False),
            gr.Button(visible=False),
            None,
        ),
        inputs=[confirmation_output],
        outputs=[confirmation_output, confirmation_input, confirm_button, pending_overwrite],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    revert_button.click(
        revert_last_action,
        inputs=[history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states],
    ).then(
        lambda message: gr.Textbox(value=message, visible=True),
        inputs=[confirmation_output],
        outputs=[confirmation_output],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    # Re-render the graph when requested or when physics is toggled.
    update_button.click(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output],
    )
    physics_button.change(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output],
    )

    model_call.click(
        fn=fetch_kg,
        inputs=[image_file, model_choice],
        outputs=[json_editor, error_output],
    )

if __name__ == "__main__":
    demo.launch()