Spaces:
Sleeping
Sleeping
import gradio as gr | |
import json | |
import os | |
from PIL import Image | |
from database_operations import Neo4jDatabase | |
from graph_visualization import visualize_graph | |
from utils import extract_label_prefix, strip_keys, format_json, validate_json | |
from models.gemini_image_to_json import fetch_gemini_response | |
from models.openai_image_to_json import openaiprocess_image_to_json | |
from any_to_image import pdf_to_images, process_image | |
# Initialize Neo4j database.
# Connection settings can be overridden through environment variables so that
# credentials do not have to live in source control; the defaults are the values
# that were previously hard-coded here.
db = Neo4jDatabase(
    os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "password123"),
)
def dump_to_neo4j_with_confirmation(json_content, file_path, history, previous_states):
    """Write the edited graph JSON to Neo4j, or ask for confirmation first.

    The label prefix is derived from the uploaded file's path. If a graph with
    that prefix already exists, nothing is written; instead a confirmation
    question is returned together with the prefix so the UI can ask the user
    whether to overwrite it (handled by ``confirm_overwrite``).

    Args:
        json_content: JSON text from the editor; after ``strip_keys`` it must
            contain ``nodes`` and ``edges`` entries.
        file_path: Path of the uploaded image, used to derive the label prefix.
        history: Newline-separated log of database actions ("" when empty).
        previous_states: Dict mapping label prefix -> list of graph snapshots
            taken before overwrites; mutated in place.

    Returns:
        A 4-tuple ``(message, history, previous_states, pending_prefix)`` where
        ``pending_prefix`` is the label prefix awaiting overwrite confirmation,
        or ``None`` when no confirmation is needed.
    """
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, None
    try:
        json_data = json.loads(json_content)
    except json.JSONDecodeError:
        return "Invalid JSON data. Please check your input.", history, previous_states, None

    # Validate the graph structure up front so the user is never asked to
    # confirm an overwrite with data that cannot be written anyway.
    json_data = strip_keys(json_data)
    try:
        nodes, edges = json_data['nodes'], json_data['edges']
    except (KeyError, TypeError):
        return ("Invalid graph structure: the JSON must contain 'nodes' and 'edges'.",
                history, previous_states, None)

    label_prefix = extract_label_prefix(file_path)
    if db.check_existing_graph(label_prefix):
        # Defer the write; the returned prefix signals that confirmation is pending.
        return (f"A graph with label prefix '{label_prefix}' already exists in the database. "
                "Do you want to overwrite it?", history, previous_states, label_prefix)

    db.dump_to_neo4j(nodes, edges, label_prefix)
    result = f"Data successfully dumped into the database with label prefix '{label_prefix}'."
    new_history = f"{history}\n[NEW ENTRY] {result}" if history else f"[NEW ENTRY] {result}"
    # A freshly created graph has no earlier snapshots; revert_last_action treats
    # an empty snapshot list as "undo by deleting the graph".
    previous_states[label_prefix] = []
    return result, new_history, previous_states, None
def confirm_overwrite(confirmation, gradio_state, json_content, file_path, history, previous_states):
    """Overwrite an existing graph once the user has typed 'yes'.

    The new JSON is parsed and checked *before* anything is deleted, so a
    malformed payload can no longer wipe the existing graph. Before deleting,
    a snapshot of the current graph is appended to
    ``previous_states[label_prefix]`` (only the 3 most recent are kept) so that
    ``revert_last_action`` can restore it.

    Args:
        confirmation: Text typed by the user; matched against 'yes'
            case-insensitively, ignoring surrounding whitespace.
        gradio_state: Not used by this function; kept for interface compatibility.
        json_content: JSON text of the new graph.
        file_path: Path of the uploaded image, used to derive the label prefix.
        history: Newline-separated log of database actions.
        previous_states: Dict of label prefix -> list of snapshots; mutated in place.

    Returns:
        ``(message, history, previous_states, "")``; the trailing empty string
        clears the confirmation textbox.
    """
    if (confirmation or "").strip().lower() != 'yes':
        return "Operation cancelled. The existing graph was not overwritten.", history, previous_states, ""
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, ""

    # Parse and validate first: nothing destructive happens until the new
    # graph is known to be writable.
    try:
        json_data = json.loads(json_content)
    except json.JSONDecodeError:
        return "Invalid JSON data. Please check your input.", history, previous_states, ""
    json_data = strip_keys(json_data)
    try:
        nodes, edges = json_data['nodes'], json_data['edges']
    except (KeyError, TypeError):
        return ("Invalid graph structure: the JSON must contain 'nodes' and 'edges'.",
                history, previous_states, "")

    label_prefix = extract_label_prefix(file_path)
    # Snapshot the current graph so the overwrite can be reverted; cap at 3.
    snapshots = previous_states.setdefault(label_prefix, [])
    snapshots.append(db.get_graph_data(label_prefix))
    previous_states[label_prefix] = snapshots[-3:]

    db.delete_graph(label_prefix)
    db.dump_to_neo4j(nodes, edges, label_prefix)
    result = f"Data successfully overwritten in the database with label prefix '{label_prefix}'."
    new_history = f"{history}\n[OVERWRITE] {result}" if history else f"[OVERWRITE] {result}"
    return result, new_history, previous_states, ""
def revert_last_action(history, previous_states):
    """Undo the most recent database action recorded in ``history``.

    The label prefix is read from the single-quoted part of the last history
    line. If snapshots exist for that prefix, the latest one is restored and
    removed. If the snapshot list is empty and the graph was created in this
    session (a '[NEW ENTRY]' line for it exists in ``history``), the graph is
    deleted. Otherwise nothing changes; in particular, a graph that existed
    before this session is never deleted just because its snapshots ran out.

    Args:
        history: Newline-separated log of database actions.
        previous_states: Dict of label prefix -> list of snapshots (each a dict
            with 'nodes' and 'edges'); mutated in place.

    Returns:
        ``(message, history, previous_states)``.
    """
    if not history:
        return "No actions to revert.", history, previous_states
    lines = history.split('\n')
    last_action = lines[-1]

    # Take the text between the first and last quote so a missing quote does
    # not raise and prefixes containing an apostrophe still parse.
    _, opening, rest = last_action.partition("'")
    label_prefix, closing, _ = rest.rpartition("'")
    if not (opening and closing):
        return "Unable to revert the last action.", history, previous_states

    snapshots = previous_states.get(label_prefix)
    if snapshots:
        latest = snapshots[-1]
        db.delete_graph(label_prefix)
        db.dump_to_neo4j(latest['nodes'], latest['edges'], label_prefix)
        snapshots.pop()
        new_history = history + f"\n[REVERT] Reverted overwrite of graph with label prefix '{label_prefix}'"
        return f"Reverted last action: {last_action}", new_history, previous_states

    # An empty snapshot list only means "newly added" if this session created it.
    created_this_session = snapshots is not None and any(
        line.startswith("[NEW ENTRY]") and f"'{label_prefix}'" in line for line in lines
    )
    if created_this_session:
        db.delete_graph(label_prefix)
        new_history = history + f"\n[REVERT] Deleted newly added graph with label prefix '{label_prefix}'"
        del previous_states[label_prefix]
        return f"Reverted last action: {last_action}", new_history, previous_states

    return "Unable to revert the last action.", history, previous_states
def update_graph_from_edited_json(json_content, physics_enabled): | |
try: | |
json_data = json.loads(json_content) | |
json_data = strip_keys(json_data) | |
validate_json(json_data) | |
return visualize_graph(json_data, physics_enabled), "" | |
except json.JSONDecodeError as e: | |
return None, f"Invalid JSON format: {str(e)}" | |
except ValueError as e: | |
return None, f"Invalid graph structure: {str(e)}" | |
except Exception as e: | |
return None, f"An unexpected error occurred: {str(e)}" | |
def fetch_kg(image_file_path, model_choice_state):
    """Ask the selected model to turn the image into knowledge-graph JSON.

    Args:
        image_file_path: Path of the image to send; falsy means nothing uploaded.
        model_choice_state: "Gemini" or "OpenAI" (the choices offered by the UI).

    Returns:
        ``(json_text, error)``. On success ``json_text`` is the model output
        re-formatted by ``format_json`` and ``error`` is "". On failure
        ``error`` describes the problem. If the model reply is not valid JSON,
        the raw reply (when it is a string) is still placed in the editor so
        it can be fixed by hand.
    """
    if not image_file_path:
        return "", "No image uploaded or invalid file"
    # Reject unknown choices explicitly instead of failing later with an
    # UnboundLocalError on the model response.
    if model_choice_state not in ("Gemini", "OpenAI"):
        return "", f"Unsupported model choice: {model_choice_state!r}"

    mind_map_image = Image.open(image_file_path)
    if model_choice_state == 'Gemini':
        print('model choice is gemini')
        kg_json_text = fetch_gemini_response(mind_map_image)
    else:
        print('model choice is openai')
        kg_json_text = openaiprocess_image_to_json(mind_map_image)

    try:
        json_data = json.loads(kg_json_text)
    except (json.JSONDecodeError, TypeError) as err:
        raw_text = kg_json_text if isinstance(kg_json_text, str) else ""
        return raw_text, f"Model response was not valid JSON: {err}"
    return format_json(json_data), ""
def input_file_handler(file_path):
    """Run an uploaded file through ``process_image`` for the image preview.

    Returns:
        ``(image_path, error)`` as produced by ``process_image``; when nothing
        was uploaded, an empty path and a fixed error message instead.
    """
    if not file_path:
        return "", "No image uploaded or invalid file"
    processed_path, error_message = process_image(file_path)
    return processed_path, error_message
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Image to Knowledge Graph Transformation")
    with gr.Row():
        file_input = gr.File(label="Upload File", file_count="single",
                             type="filepath",
                             file_types=[".pdf", ".png", ".jpeg", ".jpg", ".heic"])
        image_file = gr.Image(label="Input Image", type="filepath", visible=False)
        json_editor = gr.Textbox(label="Edit JSON", lines=15, placeholder="JSON data will appear here after image upload")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                CCW_rotate_button = gr.Button('Rotate Image Counter-Clockwise')
                CW_rotate_button = gr.Button('Rotate Image Clockwise')
        with gr.Column():
            model_call = gr.Button('Transform Image into KG representation', scale=2)
            with gr.Row():
                physics_button = gr.Checkbox(value=True, label="Enable Graph Physics")
                model_choice = gr.Radio(label="Select Model", choices=["OpenAI", "Gemini"], value="Gemini", interactive=True)
    graph_output = gr.HTML(label="Graph Output")
    error_output = gr.Textbox(label="Error Messages", interactive=False)
    update_button = gr.Button("Update Graph")
    dump_button = gr.Button("Dump to Neo4j")
    revert_button = gr.Button("Revert Last Action")
    history_block = gr.Textbox(label="History", placeholder="Graphs pushed to the Database", interactive=False, lines=5, max_lines=50)
    history_state = gr.State("")
    previous_states = gr.State({})
    # Label prefix of a graph waiting for overwrite confirmation (None if none).
    # A single shared State is required: creating gr.State() inline in each
    # event's inputs/outputs would produce unrelated components.
    pending_overwrite = gr.State(None)
    confirmation_output = gr.Textbox(label="Confirmation Message", visible=False, interactive=False)
    confirmation_input = gr.Textbox(label="Type 'yes' to confirm overwrite", visible=False, interactive=True)
    confirm_button = gr.Button("Confirm Overwrite", visible=False)
    #-----------------------------------------
    # Added 2 examples for this deployment only
    # examples_list = ["image_examples/image1.png", "image_examples/image2.png"]
    # # same full chain of events as the file.upload() below
    # def process_input(file):
    #     # First, call input_file_handler
    #     processed_file, error = input_file_handler(file)
    #     # Then, update image visibility
    #     visible_image, hidden_file = update_image_visibility(processed_file)
    #     return processed_file, error, visible_image, hidden_file
    # example_component = gr.Examples(examples_list, inputs=file_input, fn=process_input, outputs=[image_file, error_output, image_file, file_input])
    #-------------------------------------------

    # After upload, swap the file picker for the image preview.
    file_input.upload(
        fn=input_file_handler,
        inputs=[file_input],
        outputs=[image_file, error_output],
    ).then(
        lambda image_file: (
            gr.Image(value=image_file, visible=True),
            gr.File(visible=False),
        ),
        inputs=[image_file],
        outputs=[image_file, file_input],
    )

    # Clearing the preview brings the file picker back.
    image_file.clear(
        lambda file_input, image_file: (
            gr.File(visible=True),
            gr.Image(visible=False),
        ),
        inputs=[file_input, image_file],
        outputs=[file_input, image_file],
    )

    def rotate_image_to_left(image_path):
        """Rotate the image file 90 degrees clockwise, overwriting it in place.

        PIL's ``rotate`` turns counter-clockwise for positive angles, so -90 is
        a clockwise turn; this handler is wired to the clockwise button.
        """
        if image_path:
            # Close the source before saving over the same path.
            with Image.open(image_path) as source:
                rotated = source.rotate(-90, expand=True)
            rotated.save(image_path)
        return image_path

    CW_rotate_button.click(
        fn=rotate_image_to_left,
        inputs=[image_file],
        outputs=[image_file],
    )

    def rotate_image_to_right(image_path):
        """Rotate the image file 90 degrees counter-clockwise, overwriting it in place.

        PIL's ``rotate`` turns counter-clockwise for positive angles; this
        handler is wired to the counter-clockwise button.
        """
        if image_path:
            with Image.open(image_path) as source:
                rotated = source.rotate(90, expand=True)
            rotated.save(image_path)
        return image_path

    CCW_rotate_button.click(
        fn=rotate_image_to_right,
        inputs=[image_file],
        outputs=[image_file],
    )

    def _show_dump_result(message, pending_prefix):
        """Show the dump message; reveal the overwrite prompt only if one is pending."""
        needs_confirmation = pending_prefix is not None
        return (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(visible=needs_confirmation),
            gr.Button(visible=needs_confirmation),
        )

    dump_button.click(
        dump_to_neo4j_with_confirmation,
        inputs=[json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, pending_overwrite],
    ).then(
        _show_dump_result,
        inputs=[confirmation_output, pending_overwrite],
        outputs=[confirmation_output, confirmation_input, confirm_button],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    gr.on(
        triggers=[confirm_button.click, confirmation_input.submit],
        fn=confirm_overwrite,
        inputs=[confirmation_input, pending_overwrite, json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, confirmation_input],
    ).then(
        # Hide the prompt again and clear the pending prefix.
        lambda message: (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(value='', visible=False),
            gr.Button(visible=False),
            None,
        ),
        inputs=[confirmation_output],
        outputs=[confirmation_output, confirmation_input, confirm_button, pending_overwrite],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    revert_button.click(
        revert_last_action,
        inputs=[history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states],
    ).then(
        lambda confirmation_output: gr.Textbox(value=confirmation_output, visible=True),
        inputs=[confirmation_output],
        outputs=[confirmation_output],
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block],
    )

    update_button.click(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output],
    )
    physics_button.change(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output],
    )
    model_call.click(
        fn=fetch_kg,
        inputs=[image_file, model_choice],
        outputs=[json_editor, error_output],
    )

if __name__ == "__main__":
    demo.launch()