Zaherrr commited on
Commit
b2ad712
·
verified ·
1 Parent(s): e224c53

Upload 8 files

Browse files
any_to_image.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import fitz # PyMuPDF
3
+ import os
4
+ import pyheif
5
+
6
+
7
def pdf_to_images(pdf_path):
    """Render every page of a PDF file to a PIL RGB image.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        A list of PIL.Image.Image objects, one per page, in page order.
        An empty list when the file does not exist.
    """
    # Ensure the PDF file exists before handing it to PyMuPDF.
    if not os.path.exists(pdf_path):
        print(f"The file {pdf_path} does not exist.")
        return []

    pdf_document = fitz.open(pdf_path)
    try:
        images = []
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            # pix.samples is a flat RGB byte buffer PIL can wrap directly.
            pix = page.get_pixmap()
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            images.append(img)
    finally:
        # Close even when rendering raises, so the file handle is never
        # leaked (previously it was closed only on the success path).
        pdf_document.close()

    return images
35
+
36
+
37
def heic_to_image(heic_path):
    """Decode a HEIC file into a single PIL image.

    A file with an upper-case ".HEIC" suffix is first renamed on disk to the
    lower-case ".heic" form.  Returns an empty list when the file is missing
    or cannot be decoded (callers treat that as "no image").
    """
    # Bail out early when the path does not point at a real file.
    if not os.path.exists(heic_path):
        print(f"The file {heic_path} does not exist.")
        return []

    # Normalise an upper-case extension by renaming the file in place.
    if heic_path.endswith(".HEIC"):
        renamed = heic_path[:-5] + ".heic"
        os.rename(heic_path, renamed)
        print(f"Renamed: {heic_path} to {renamed}")
        heic_path = renamed

    try:
        heif_file = pyheif.read(heic_path)
        # Wrap the decoded raw buffer in a PIL image.
        image = Image.frombytes(
            heif_file.mode,
            heif_file.size,
            heif_file.data,
            "raw",
            heif_file.mode,
            heif_file.stride,
        )
    except Exception as e:
        print(f"An error occurred while processing the HEIC file: {e}")
        return []

    return image
70
+
71
+
72
def process_image(file_path):
    """Normalise an uploaded file into a displayable image path.

    PDFs are rasterized (only the first page is surfaced), upper-case image
    extensions are renamed to their lower-case form on disk, and
    already-supported formats pass straight through.

    Args:
        file_path: Path of the uploaded file (typically gradio's cached copy).

    Returns:
        (image_path, error): ``image_path`` is a path to a displayable image
        (None on failure); ``error`` is "" on success, a message otherwise.
    """
    if file_path.endswith(".pdf"):
        images = pdf_to_images(file_path)
        if not images:
            return None, "No image uploaded or invalid file"

        # Keep a copy of every page in data_processed/ for inspection.
        output_dir = "data_processed"
        os.makedirs(output_dir, exist_ok=True)
        label_prefix = os.path.splitext(os.path.basename(file_path))[0]
        for i, img in enumerate(images, start=1):
            page_path = os.path.join(output_dir, f"{label_prefix}_page_{i}.png")
            img.save(page_path, "PNG")
            print(f"Saved {page_path}")

        # Save back next to gradio's cached path so gradio's cache is reused.
        # splitext (rather than str.replace) so only the final extension is
        # swapped even if ".pdf" appears elsewhere in the path.
        file_path = os.path.splitext(file_path)[0] + ".png"
        # Only the first page is shown in the UI for now.
        # TODO: Accommodate multiple pages.
        images[0].save(file_path, "PNG")
        print(f"Saved {file_path}")
        return file_path, ""

    if file_path.endswith((".png", ".jpg", ".jpeg")):
        print(
            f"file_path from the image processing function for compatible images: {file_path}"
        )
        return file_path, ""

    # Upper-case extensions: rename the file so downstream lower-case
    # suffix checks match (replaces three duplicated rename branches).
    for upper, lower in ((".JPEG", ".jpeg"), (".JPG", ".jpg"), (".PNG", ".png")):
        if file_path.endswith(upper):
            new_file_path = file_path[: -len(upper)] + lower
            os.rename(file_path, new_file_path)
            print(f"Renamed: {file_path} to {new_file_path}")
            return new_file_path, ""

    if file_path.endswith(".heic") or file_path.endswith(".HEIC"):
        image = heic_to_image(file_path)

        output_dir = "data_processed"
        os.makedirs(output_dir, exist_ok=True)
        image_path = os.path.join(
            output_dir, f"{os.path.splitext(os.path.basename(file_path))[0]}.png"
        )
        image.save(image_path, "PNG")
        # Save back to gradio's cached path so its cache is reused.
        image.save(file_path, "PNG")
        print(f"Saved {image_path}")
        print(f"Saved {file_path}")
        return file_path, ""

    return None, "No image uploaded or invalid file"
164
+
165
+
166
# Example usage: rasterize a sample PDF and dump its pages to disk.
if __name__ == "__main__":
    sample_pdf = "data/Augustin REMY Mindmap OOP .pdf"
    pages = pdf_to_images(sample_pdf)

    # Save the rendered pages so they can be inspected manually.
    if pages:
        out_dir = "data_processed"
        os.makedirs(out_dir, exist_ok=True)

        for index, page_image in enumerate(pages, start=1):
            image_path = os.path.join(out_dir, f"page_{index}.png")
            page_image.save(image_path, "PNG")
            print(f"Saved {image_path}")
database_operations.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from neo4j import GraphDatabase
2
+
3
class Neo4jDatabase:
    """Thin wrapper around the Neo4j driver for storing/loading KG graphs.

    Each uploaded mind map lives in its own namespace: its nodes carry both
    a per-file ``label_prefix`` label and the shared ``Node`` label.

    NOTE(review): ``label_prefix`` is interpolated directly into Cypher
    (labels cannot be query parameters), so callers must pass only sanitized
    identifiers — see ``utils.extract_label_prefix``.
    """

    def __init__(self, uri, username, password):
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    # Context-manager support so the driver is always released:
    #     with Neo4jDatabase(...) as db: ...
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # never swallow exceptions

    def close(self):
        """Release the underlying driver and its connection pool."""
        self.driver.close()

    def dump_to_neo4j(self, nodes, edges, label_prefix):
        """Create the given nodes and edges under ``label_prefix``.

        ``nodes`` are dicts with 'id' and 'label'; ``edges`` are dicts with
        'source', 'target' and 'type'.  Values go in as query parameters.
        """
        with self.driver.session() as session:
            for node in nodes:
                session.run(f"CREATE (n:{label_prefix}:Node {{id: $id, label: $label}})", id=node['id'], label=node['label'])

            for edge in edges:
                session.run(f"""
                MATCH (a:{label_prefix}:Node {{id: $source}}), (b:{label_prefix}:Node {{id: $target}})
                CREATE (a)-[r:RELATION {{type: $type}}]->(b)
                """, source=edge['source'], target=edge['target'], type=edge['type'])

    def check_existing_graph(self, label_prefix):
        """Return True when any node labelled ``label_prefix`` exists."""
        with self.driver.session() as session:
            result = session.run(f"MATCH (n:{label_prefix}) RETURN count(n) as count")
            return result.single()["count"] > 0

    def get_graph_data(self, label_prefix):
        """Fetch the stored graph as ``{"nodes": [...], "edges": [...]}``."""
        with self.driver.session() as session:
            node_records = session.run(f"MATCH (n:{label_prefix}) RETURN n.id AS id, n.label AS label")
            edge_records = session.run(f"MATCH (a:{label_prefix})-[r]->(b:{label_prefix}) RETURN a.id AS source, b.id AS target, type(r) AS type")

            # Materialize inside the session: results cannot be consumed
            # after the session closes.
            nodes = [{"id": record["id"], "label": record["label"]} for record in node_records]
            edges = [{"source": record["source"], "target": record["target"], "type": record["type"]} for record in edge_records]

        return {"nodes": nodes, "edges": edges}

    def delete_graph(self, label_prefix):
        """Detach-delete every node under ``label_prefix``."""
        with self.driver.session() as session:
            session.run(f"MATCH (n:{label_prefix}) DETACH DELETE n")
gemini_image_to_json.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import google.generativeai as genai
2
+
3
+ from dotenv import load_dotenv
4
+ import os
5
+
6
+ load_dotenv()
7
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
8
+
9
+ genai.configure(api_key=GOOGLE_API_KEY)
10
+
11
+ # gemini-1.5-pro only gives 50 requests per day. check https://ai.google.dev/pricing for more details
12
+ # model = genai.GenerativeModel('gemini-1.5-pro',
13
+ model = genai.GenerativeModel(
14
+ "gemini-1.5-flash",
15
+ # Set the `response_mime_type` to output JSON
16
+ # Pass the schema object to the `response_schema` field
17
+ generation_config={
18
+ "response_mime_type": "application/json",
19
+ "temperature": 0.0,
20
+ },
21
+ )
22
+ # "response_schema": Recipe, 'max_output_tokens':4000})
23
+
24
+ PROMPT = """
25
+ You are responsible for extracting the entities (nodes) and relationships (edges) from the images of mind maps. The mind maps are for Object Oriented Programming.
26
+ Don't make up facts, just extracts them. Do not create new entity types that aren't mentioned in the image, and at the same time don't miss anything.
27
+ Give the output in JSON format with this schema:
28
+ {
29
+ "nodes": [{"id": "1", "label": string},{"id": "2", "label": string}],"edges": [{"source": SOURCE_ID, "target": TARGET_ID, "type": "->"},{"source": SOURCE_ID, "target": TARGET_ID, "type": "->"}]
30
+ }
31
+
32
+ Now extract the entities and relationships from this image:
33
+ """
34
+
35
+
36
def fetch_gemini_response(mind_map_image):
    """Send the extraction prompt plus the mind-map image to Gemini.

    Returns the model's raw response text (expected to be KG JSON, per the
    module-level generation config).
    """
    print("fetching gemini response")
    result = model.generate_content([PROMPT, mind_map_image], stream=False)
    return result.text
graph_visualization.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pyvis.network import Network
3
+
4
def create_graph(nodes, edges, physics_enabled=True):
    """Build a pyvis Network from node/edge dicts.

    The node labelled 'OOP' is coloured blue, every other node green.
    Layout is ForceAtlas2; ``physics_enabled`` toggles the live simulation.
    """
    graph = Network(notebook=True, height='100vh', width='100vw',
                    bgcolor='#222222', font_color='white',
                    cdn_resources='remote')

    for item in nodes:
        # Highlight the root 'OOP' node; everything else is green.
        node_color = 'blue' if item['label'] == 'OOP' else 'green'
        graph.add_node(item['id'], label=item['label'],
                       title=item['label'], color=node_color)

    for link in edges:
        graph.add_edge(link['source'], link['target'], title=link['type'])

    # ForceAtlas2 layout parameters.
    graph.force_atlas_2based(gravity=-50, central_gravity=0.01,
                             spring_length=100, spring_strength=0.08,
                             damping=0.4)

    # vis.js options serialized into the generated page.
    vis_options = {
        "nodes": {
            "physics": physics_enabled
        },
        "edges": {
            "smooth": True
        },
        "interaction": {
            "hover": True,
            "zoomView": True
        },
        "physics": {
            "enabled": physics_enabled,
            "stabilization": {
                "enabled": True,
                "iterations": 200
            }
        }
    }
    graph.set_options(json.dumps(vis_options))
    return graph
48
+
49
def visualize_graph(json_data, physics_enabled=True):
    """Render KG JSON (string or dict) as an iframe-embedded pyvis page.

    Returns an HTML snippet suitable for a gradio gr.HTML component.
    """
    # Accept either a JSON string or an already-parsed dict.
    if isinstance(json_data, str):
        data = json.loads(json_data)
    else:
        data = json_data
    nodes = data['nodes']
    edges = data['edges']
    net = create_graph(nodes, edges, physics_enabled)
    html = net.generate_html()
    # The page is embedded below via srcdoc='...', which is delimited with
    # single quotes, so the document must contain none.  Swapping every
    # single quote to a double quote keeps pyvis output working, but is
    # fragile: apostrophes inside node labels are rewritten too.
    html = html.replace("'", "\"")
    # Force the network container to fill the iframe viewport.
    html = html.replace('<div id="mynetwork"', '<div id="mynetwork" style="height: 100vh; width: 100%;"')
    return f"""<iframe style="width: 100%; height: 100vh; border: none; margin: 0; padding: 0;" srcdoc='{html}'></iframe>"""
main.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ from PIL import Image
5
+ from database_operations import Neo4jDatabase
6
+ from graph_visualization import visualize_graph
7
+ from utils import extract_label_prefix, strip_keys, format_json, validate_json
8
+ from models.gemini_image_to_json import fetch_gemini_response
9
+ from models.openai_image_to_json import openaiprocess_image_to_json
10
+ from any_to_image import pdf_to_images, process_image
11

# Initialize Neo4j database
# NOTE(review): connection URI and credentials are hard-coded; they should
# come from environment variables like the API keys in the model modules.
db = Neo4jDatabase("bolt://localhost:7687", "neo4j", "password123")
14
+
15
def dump_to_neo4j_with_confirmation(json_content, file_path, history, previous_states):
    """Push the edited KG JSON into Neo4j, asking before overwriting.

    Returns (message, history, previous_states, label_prefix_or_None); a
    non-None final element signals the UI to show the overwrite prompt.
    """
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, None

    try:
        json_data = json.loads(json_content)
    except json.JSONDecodeError:
        return "Invalid JSON data. Please check your input.", history, previous_states, None

    # The file name (sanitized) doubles as the graph's Neo4j label.
    label_prefix = extract_label_prefix(file_path)

    if db.check_existing_graph(label_prefix):
        # NOTE(review): previous_state is fetched but unused here; the
        # snapshot actually used for undo is taken in confirm_overwrite.
        previous_state = db.get_graph_data(label_prefix)
        return f"A graph with label prefix '{label_prefix}' already exists in the database. Do you want to overwrite it?", history, previous_states, label_prefix
    else:
        # Fresh graph: clean the keys and write it straight away.
        json_data = strip_keys(json_data)
        db.dump_to_neo4j(json_data['nodes'], json_data['edges'], label_prefix)
        result = f"Data successfully dumped into the database with label prefix '{label_prefix}'."
        new_history = f"{history}\n[NEW ENTRY] {result}" if history else f"[NEW ENTRY] {result}"
        # Empty undo stack marks this prefix as newly created (see
        # revert_last_action, which deletes it outright in that case).
        previous_states[label_prefix] = []
        return result, new_history, previous_states, None
36
+
37
def confirm_overwrite(confirmation, gradio_state, json_content, file_path, history, previous_states):
    """Overwrite an existing graph once the user has typed 'yes'.

    Snapshots the current graph for undo (up to 3 snapshots per prefix),
    deletes it, then writes the new data.  Returns
    (message, history, previous_states, cleared_confirmation_text).
    """
    if confirmation.lower() == 'yes':
        try:
            label_prefix = extract_label_prefix(file_path)
            # Snapshot the existing graph before destroying it.
            previous_state = db.get_graph_data(label_prefix)

            if label_prefix not in previous_states:
                previous_states[label_prefix] = []
                previous_states[label_prefix].append(previous_state)
            else:
                previous_states[label_prefix].append(previous_state)

            # Cap the undo stack at the three most recent snapshots.
            if len(previous_states[label_prefix]) > 3:
                previous_states[label_prefix] = previous_states[label_prefix][-3:]

            db.delete_graph(label_prefix)

            json_data = json.loads(json_content)
            json_data = strip_keys(json_data)
            db.dump_to_neo4j(json_data['nodes'], json_data['edges'], label_prefix)
            result = f"Data successfully overwritten in the database with label prefix '{label_prefix}'."
            new_history = f"{history}\n[OVERWRITE] {result}" if history else f"[OVERWRITE] {result}"
            return result, new_history, previous_states, ""
        except json.JSONDecodeError:
            # NOTE(review): if this fires after delete_graph, the old graph
            # is already gone but no new data was written.
            return "Invalid JSON data. Please check your input.", history, previous_states, ""
    else:
        return "Operation cancelled. The existing graph was not overwritten.", history, previous_states, ""
67
+
68
def revert_last_action(history, previous_states):
    """Undo the most recent dump/overwrite recorded in ``history``.

    Restores the latest snapshot for that prefix, or deletes the graph
    entirely when it was newly created (empty snapshot stack).
    """
    if not history:
        return "No actions to revert.", history, previous_states

    # The affected prefix is parsed back out of the last history line.
    # NOTE(review): assumes that line contains a single-quoted prefix
    # (true for [NEW ENTRY]/[OVERWRITE]/[REVERT] lines written elsewhere);
    # an unexpected line would raise IndexError here.
    last_action = history.split('\n')[-1]
    label_prefix = last_action.split("'")[1]

    if label_prefix in previous_states and previous_states[label_prefix]:
        # Overwrite case: drop the current graph, restore the last snapshot.
        db.delete_graph(label_prefix)
        db.dump_to_neo4j(previous_states[label_prefix][-1]['nodes'], previous_states[label_prefix][-1]['edges'], label_prefix)
        new_history = history + f"\n[REVERT] Reverted overwrite of graph with label prefix '{label_prefix}'"
        previous_states[label_prefix].pop()
        return f"Reverted last action: {last_action}", new_history, previous_states
    elif label_prefix in previous_states and not previous_states[label_prefix]:
        # New-graph case: nothing to restore, just delete it.
        db.delete_graph(label_prefix)
        new_history = history + f"\n[REVERT] Deleted newly added graph with label prefix '{label_prefix}'"
        del previous_states[label_prefix]
        return f"Reverted last action: {last_action}", new_history, previous_states
    else:
        return "Unable to revert the last action.", history, previous_states
88
+
89
+ def update_graph_from_edited_json(json_content, physics_enabled):
90
+ try:
91
+ json_data = json.loads(json_content)
92
+ json_data = strip_keys(json_data)
93
+ validate_json(json_data)
94
+ return visualize_graph(json_data, physics_enabled), ""
95
+ except json.JSONDecodeError as e:
96
+ return None, f"Invalid JSON format: {str(e)}"
97
+ except ValueError as e:
98
+ return None, f"Invalid graph structure: {str(e)}"
99
+ except Exception as e:
100
+ return None, f"An unexpected error occurred: {str(e)}"
101
+
102
def fetch_kg(image_file_path, model_choice_state):
    """Extract a knowledge graph from the uploaded image via the chosen model.

    Args:
        image_file_path: Path to the prepared image ("" / None when absent).
        model_choice_state: "Gemini" or "OpenAI" (from the radio control).

    Returns:
        (formatted_json, error): pretty-printed KG JSON plus "" on success,
        or "" plus an error message on failure.
    """
    if not image_file_path:
        return "", "No image uploaded or invalid file"

    mind_map_image = Image.open(image_file_path)

    if model_choice_state == 'Gemini':
        print(f'model choice is gemini')
        kg_json_text = fetch_gemini_response(mind_map_image)
    elif model_choice_state == 'OpenAI':
        print(f'model choice is openai')
        kg_json_text = openaiprocess_image_to_json(mind_map_image)
    else:
        # Previously an unrecognised choice crashed with an unbound
        # kg_json_text (NameError); fail with a readable message instead.
        return "", f"Unknown model choice: {model_choice_state}"

    json_data = json.loads(kg_json_text)
    return format_json(json_data), ""
116
+
117
def input_file_handler(file_path):
    """Route an uploaded file through process_image.

    Returns (image_path, error); ("", message) when no file was provided.
    """
    if not file_path:
        return "", "No image uploaded or invalid file"
    # process_image already returns the (path, error) pair gradio expects.
    return process_image(file_path)
123
+
124
# Gradio interface: upload -> preview/rotate -> model extraction -> editable
# JSON -> graph preview -> Neo4j persistence with overwrite/undo.
with gr.Blocks() as demo:
    gr.Markdown("## Image to Knowledge Graph Transformation")

    with gr.Row():
        # file_input is shown first; once processed it is hidden and the
        # image preview (image_file) takes its place.
        file_input = gr.File(label="Upload File", file_count="single",
                             type="filepath",
                             file_types=[".pdf", ".png", ".jpeg", ".jpg", ".heic"])
        image_file = gr.Image(label="Input Image", type="filepath", visible=False)
        json_editor = gr.Textbox(label="Edit JSON", lines=15, placeholder="JSON data will appear here after image upload")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                CCW_rotate_button = gr.Button('Rotate Image Counter-Clockwise')
                CW_rotate_button = gr.Button('Rotate Image Clockwise')
        with gr.Column():
            model_call = gr.Button('Transform Image into KG representation', scale=2)
            with gr.Row():
                physics_button = gr.Checkbox(value=True, label="Enable Graph Physics")
                model_choice = gr.Radio(label="Select Model", choices=["OpenAI", "Gemini"], value="Gemini", interactive=True)

    graph_output = gr.HTML(label="Graph Output")
    error_output = gr.Textbox(label="Error Messages", interactive=False)

    update_button = gr.Button("Update Graph")
    dump_button = gr.Button("Dump to Neo4j")
    revert_button = gr.Button("Revert Last Action")

    # history_block mirrors history_state for display; the state object is
    # what the callbacks actually read and write.
    history_block = gr.Textbox(label="History", placeholder="Graphs pushed to the Database", interactive=False, lines=5, max_lines=50)
    history_state = gr.State("")
    previous_states = gr.State({})

    # Hidden until a dump hits an existing graph and needs confirmation.
    confirmation_output = gr.Textbox(label="Confirmation Message", visible=False, interactive=False)
    confirmation_input = gr.Textbox(label="Type 'yes' to confirm overwrite", visible=False, interactive=True)
    confirm_button = gr.Button("Confirm Overwrite", visible=False)

    # Upload: normalise the file, then swap the File widget for the Image.
    file_input.upload(
        fn=input_file_handler,
        inputs=[file_input],
        outputs=[image_file, error_output]
    ).then(
        lambda image_file: (
            gr.Image(value=image_file, visible=True),
            gr.File(visible=False)
        ),
        inputs=[image_file],
        outputs=[image_file, file_input]
    )

    # Clearing the image swaps the File widget back in.
    image_file.clear(
        lambda file_input, image_file: (
            gr.File(visible=True),
            gr.Image(visible=False)
        ),
        inputs=[file_input, image_file],
        outputs=[file_input, image_file]
    )

    # NOTE(review): PIL's rotate() is counter-clockwise for positive angles,
    # so rotate(-90) below is a CLOCKWISE turn — the function names are
    # swapped relative to what they do, but the button wiring compensates
    # (CW button -> rotate(-90), CCW button -> rotate(90)), so the UI
    # behaves correctly.  Rotation happens in place on the cached file.
    def rotate_image_to_left(image_path):
        if image_path:
            image = Image.open(image_path)
            image = image.rotate(-90, expand=True)
            image.save(image_path)
        return image_path

    CW_rotate_button.click(
        fn=rotate_image_to_left,
        inputs=[image_file],
        outputs=[image_file]
    )

    def rotate_image_to_right(image_path):
        if image_path:
            image = Image.open(image_path)
            image = image.rotate(90, expand=True)
            image.save(image_path)
        return image_path

    CCW_rotate_button.click(
        fn=rotate_image_to_right,
        inputs=[image_file],
        outputs=[image_file]
    )

    # Dump flow: attempt the write, then reveal the confirmation widgets
    # (dump_to_neo4j_with_confirmation decides whether they are needed),
    # then refresh the visible history log.
    dump_button.click(
        dump_to_neo4j_with_confirmation,
        inputs=[json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, gr.State()]
    ).then(
        lambda message, history, previous_states, label_prefix: (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(visible=True),
            gr.Button(visible=True),
            history,
            previous_states,
            label_prefix
        ),
        inputs=[confirmation_output, history_state, previous_states, gr.State()],
        outputs=[confirmation_output, confirmation_input, confirm_button, history_state, previous_states, gr.State()]
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block]
    )

    # Confirmation can come from the button or from submitting the textbox;
    # afterwards the confirmation widgets are hidden again.
    gr.on(
        triggers=[confirm_button.click, confirmation_input.submit],
        fn=confirm_overwrite,
        inputs=[confirmation_input, gr.State(), json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, confirmation_input]
    ).then(
        lambda confirmation_output, confirmation_input: (
            gr.Textbox(value=confirmation_output, visible=True),
            gr.Textbox(value='', visible=False),
            gr.Button(visible=False)
        ),
        inputs=[confirmation_output, confirmation_input],
        outputs=[confirmation_output, confirmation_input, confirm_button]
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block]
    )

    revert_button.click(
        revert_last_action,
        inputs=[history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states]
    ).then(
        lambda confirmation_output: gr.Textbox(value=confirmation_output, visible=True),
        inputs=[confirmation_output],
        outputs=[confirmation_output]
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block]
    )

    update_button.click(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output]
    )

    # Toggling physics re-renders the graph immediately.
    physics_button.change(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output]
    )

    model_call.click(
        fn=fetch_kg,
        inputs=[image_file, model_choice],
        outputs=[json_editor, error_output]
    )

if __name__ == "__main__":
    demo.launch()
openai_image_to_json.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import requests
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ import os
6
+ from dotenv import load_dotenv
7
+ import json
8
+
9
+ # Load the .env file
10
+ load_dotenv()
11
+
12
+ # Get the API key from the environment
13
+ api_key = os.getenv('OPENAI_API_KEY')
14
+
15
def encode_image(image):
    """Return the image as a base64-encoded JPEG string.

    RGBA input is flattened to RGB first, because JPEG has no alpha channel.
    """
    if image.mode == 'RGBA':
        image = image.convert('RGB')

    jpeg_buffer = BytesIO()
    image.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode('utf-8')
24
+
25
def openaiprocess_image_to_json(image):
    """Ask GPT-4o to extract a mind-map knowledge graph from an image.

    Args:
        image: PIL image of the mind map.

    Returns:
        The model's response content — expected to be a JSON string with
        "nodes"/"edges" keys; the caller is responsible for json.loads().

    Raises:
        ValueError: when the API response contains no choices.
    """
    print(f'fetching openai response')

    # Encode the image as base64 JPEG for the data-URL payload.
    base64_image = encode_image(image)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    PROMPT = '''
    You are responsible for extracting the entities (nodes) and relationships (edges) from the images of mind maps. The mind maps are for Object Oriented Programming.
    Don't make up facts, just extracts them. Do not create new entity types that aren't mentioned in the image, and at the same time don't miss anything.
    Give the output in JSON format as follows:
    {
    "nodes": [
    {"id": "1", "label": string},
    {"id": "2", "label": string},...
    ],
    "edges": [
    {"source": SOURCE_ID, "target": TARGET_ID, "type": "->"},
    {"source": SOURCE_ID, "target": TARGET_ID, "type": "->"},...
    ]
    }
    Only return valid python dictionary, dont include (line jump)n in it, dont include spaces, only a dictionary. Do not include any other text outside the Dictionary structure. Make sure that i will get a valid Python dictionary.
    make sure that what you return as json_string i can use it in python in this function: json.loads(json_string)
    Now extract the entities and relationships from this image:
    '''

    payload = {
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    }
                ]
            }
        ]
    }

    # Send the request to the OpenAI API.
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)

    response_data = response.json()
    print(response_data)

    if "choices" in response_data and response_data["choices"]:
        # Return the raw content string; callers parse it themselves.
        # (A former try/except json.JSONDecodeError here wrapped a plain
        # assignment and could never fire — removed as dead code.)
        return response_data["choices"][0]["message"]["content"]

    raise ValueError("No valid response from OpenAI API")
requirements.txt ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.4.0
4
+ asttokens==2.4.1
5
+ cachetools==5.5.0
6
+ certifi==2024.8.30
7
+ cffi==1.17.1
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ contourpy==1.3.0
11
+ cycler==0.12.1
12
+ decorator==5.1.1
13
+ exceptiongroup==1.2.2
14
+ executing==2.1.0
15
+ fastapi==0.114.1
16
+ ffmpy==0.4.0
17
+ filelock==3.16.0
18
+ fonttools==4.53.1
19
+ fsspec==2024.9.0
20
+ google-ai-generativelanguage==0.6.9
21
+ google-api-core==2.19.2
22
+ google-api-python-client==2.145.0
23
+ google-auth==2.34.0
24
+ google-auth-httplib2==0.2.0
25
+ google-generativeai==0.8.1
26
+ googleapis-common-protos==1.65.0
27
+ gradio==4.44.0
28
+ gradio_client==1.3.0
29
+ grpcio==1.66.1
30
+ grpcio-status==1.66.1
31
+ h11==0.14.0
32
+ httpcore==1.0.5
33
+ httplib2==0.22.0
34
+ httpx==0.27.2
35
+ huggingface-hub==0.24.7
36
+ idna==3.8
37
+ importlib_resources==6.4.5
38
+ ipython==8.27.0
39
+ jedi==0.19.1
40
+ Jinja2==3.1.4
41
+ jsonpickle==3.3.0
42
+ kiwisolver==1.4.7
43
+ markdown-it-py==3.0.0
44
+ MarkupSafe==2.1.5
45
+ matplotlib==3.9.2
46
+ matplotlib-inline==0.1.7
47
+ mdurl==0.1.2
48
+ neo4j==5.24.0
49
+ networkx==3.3
50
+ numpy==2.1.1
51
+ orjson==3.10.7
52
+ packaging==24.1
53
+ pandas==2.2.2
54
+ parso==0.8.4
55
+ pexpect==4.9.0
56
+ pillow==10.4.0
57
+ prompt_toolkit==3.0.47
58
+ proto-plus==1.24.0
59
+ protobuf==5.28.1
60
+ ptyprocess==0.7.0
61
+ pure_eval==0.2.3
62
+ pyasn1==0.6.1
63
+ pyasn1_modules==0.4.1
64
+ pycparser==2.22
65
+ pydantic==2.9.1
66
+ pydantic_core==2.23.3
67
+ pydub==0.25.1
68
+ Pygments==2.18.0
69
+ pyheif==0.8.0
70
+ PyMuPDF==1.24.10
71
+ PyMuPDFb==1.24.10
72
+ pyparsing==3.1.4
73
+ python-dateutil==2.9.0.post0
74
+ python-dotenv==1.0.1
75
+ python-multipart==0.0.9
76
+ pytz==2024.2
77
+ pyvis==0.3.2
78
+ PyYAML==6.0.2
79
+ requests==2.32.3
80
+ rich==13.8.1
81
+ rsa==4.9
82
+ ruff==0.6.4
83
+ semantic-version==2.10.0
84
+ shellingham==1.5.4
85
+ six==1.16.0
86
+ sniffio==1.3.1
87
+ stack-data==0.6.3
88
+ starlette==0.38.5
89
+ tomlkit==0.12.0
90
+ tqdm==4.66.5
91
+ traitlets==5.14.3
92
+ typer==0.12.5
93
+ typing_extensions==4.12.2
94
+ tzdata==2024.1
95
+ uritemplate==4.1.1
96
+ urllib3==2.2.3
97
+ uvicorn==0.30.6
98
+ wcwidth==0.2.13
99
+ websockets==12.0
utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
def extract_label_prefix(file_name):
    """Derive a Neo4j-safe label from a file name.

    Strips the directory and extension, then replaces spaces and hyphens
    with underscores so the result is usable as a Cypher label.
    """
    stem, _ = os.path.splitext(os.path.basename(file_name))
    return stem.replace(" ", "_").replace("-", "_")
7
+
8
def strip_keys(d):
    """Recursively strip whitespace from every dict key.

    Lists and nested dicts are traversed; non-container values are
    returned unchanged.
    """
    if isinstance(d, list):
        return [strip_keys(element) for element in d]
    if isinstance(d, dict):
        return {key.strip(): strip_keys(value) for key, value in d.items()}
    return d
15
+
16
def format_json(json_data):
    """Pretty-print KG data as a JSON string with one node/edge per line.

    Expects a dict with 'nodes' and 'edges' lists; output is shaped for
    comfortable hand-editing in the UI textbox.
    """
    out = "{\n \"nodes\": [\n"
    out += "".join(f" {json.dumps(item)},\n" for item in json_data['nodes'])
    # Drop the trailing comma/newline before closing the nodes array.
    out = out.rstrip(',\n') + "\n ],\n \"edges\": [\n"
    out += "".join(f" {json.dumps(item)},\n" for item in json_data['edges'])
    return out.rstrip(',\n') + "\n ]\n}"
25
+
26
def validate_json(json_data):
    """Validate KG JSON structure, raising ValueError on any problem.

    Requires a dict with list-valued 'nodes' (each having 'id' and 'label')
    and 'edges' (each having 'source', 'target' and 'type', where 'type'
    must be exactly "->").  Returns None when everything checks out.
    """
    if not isinstance(json_data, dict) or 'nodes' not in json_data or 'edges' not in json_data:
        raise ValueError("JSON must contain 'nodes' and 'edges' keys")

    nodes, edges = json_data['nodes'], json_data['edges']
    if not isinstance(nodes, list) or not isinstance(edges, list):
        raise ValueError("'nodes' and 'edges' must be lists")

    for node in nodes:
        if 'id' not in node or 'label' not in node:
            raise ValueError("Each node must have 'id' and 'label' properties")

    for edge in edges:
        if 'source' not in edge or 'target' not in edge or 'type' not in edge:
            raise ValueError("Each edge must have 'source', 'target', and 'type' properties")

        if edge['type'] != "->":
            raise ValueError("Edge type must be '->' strictly")