"""Gradio viewer for the Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset dataset.

For a selected example the app shows the image, an HTML rendering of its
graph (produced by ``graph_visualization.visualize_graph``), the graph as
JSON, the graph as plain text, and the image dimensions.
"""

import base64
import io
import json

import gradio as gr
from datasets import Dataset, load_dataset
from PIL import Image

from graph_visualization import visualize_graph

# Dataset revision (branch) name. It is not passed to load_dataset() below;
# it is kept so a revision can be pinned again with `revision=branch_name`.
# A previously used value was "edges-sorted-ascending".
branch_name = "Sorted_edges"

# Load the dataset once at import time; the UI callback reads from it.
# (An earlier variant loaded "Zaherrr/OOP_KG_Dataset" with
# revision=branch_name.)
dataset = load_dataset("Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset", split="data")

print(f"This is the dataset: {dataset}")
print(dataset.info)
print(f"This is an example: {dataset[-5]}")


def reshape_json_data_to_fit_visualize_graph(graph_data):
    """Convert column-oriented graph data into lists of node/edge records.

    Args:
        graph_data: Mapping with a "nodes" entry holding parallel "id" and
            "label" sequences and an "edges" entry holding parallel "source"
            and "target" sequences.

    Returns:
        A dict ``{"nodes": [...], "edges": [...]}`` where each node is
        ``{"id": ..., "label": ...}`` and each edge is
        ``{"source": ..., "target": ..., "type": "->"}``.
    """
    nodes = graph_data["nodes"]
    edges = graph_data["edges"]

    # NOTE(review): assumes the paired sequences have equal length (zip stops
    # at the shorter one) -- confirm against the dataset schema.
    transformed_nodes = [
        {"id": node_id, "label": label}
        for node_id, label in zip(nodes["id"], nodes["label"])
    ]
    # Every edge gets the same fixed "->" type marker.
    transformed_edges = [
        {"source": source, "target": target, "type": "->"}
        for source, target in zip(edges["source"], edges["target"])
    ]

    return {"nodes": transformed_nodes, "edges": transformed_edges}


def display_example(index):
    """Build all UI outputs for the dataset example at ``index``.

    Args:
        index: Position of the example in ``dataset``. The slider may deliver
            the value as a float, so it is coerced to ``int`` before indexing.

    Returns:
        A 5-tuple ``(image, graph_html, json_data, graph_dict, dimensions)``
        in the order expected by the ``outputs`` list wired up in
        ``create_interface``.
    """
    example = dataset[int(index)]

    img = example["image"]
    img_width, img_height = img.size

    # Reshape the example's graph into the record form passed to
    # visualize_graph.
    graph_data = {"nodes": example["nodes"], "edges": example["edges"]}
    transformed_graph_data = reshape_json_data_to_fit_visualize_graph(graph_data)

    graph_html = visualize_graph(transformed_graph_data)
    # Swap the full-viewport height for a fixed one so the graph fits in the
    # page layout (presumably the HTML embeds this style -- verify against
    # visualize_graph).
    graph_html = graph_html.replace("height: 100vh;", "height: 500px;")

    json_data = json.dumps(transformed_graph_data, indent=2)
    dimensions = f"Width: {img_width}px, Height: {img_height}px"

    return img, graph_html, json_data, transformed_graph_data, dimensions


def create_interface():
    """Assemble the Gradio Blocks UI and return it (without launching)."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Knowledge Graph Visualizer for the "
            "[Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset]"
            "(https://huggingface.co/datasets/Zaherrr/OOP_KG_MarkMap_Synthetic_Dataset)"
            " dataset"
        )

        with gr.Row():
            index_slider = gr.Slider(
                minimum=0,
                maximum=len(dataset) - 1,
                step=1,
                label="Example Index",
            )

        with gr.Row():
            image_output = gr.Image(type="pil", label="Image", height=500)
            graph_output = gr.HTML(label="Knowledge Graph")

        with gr.Row():
            dimensions_output = gr.Textbox(
                label="Image Dimensions (pixels)",
                placeholder="Width and Height will appear here",
                interactive=False,
            )

        with gr.Row():
            json_output = gr.Code(language="json", label="Graph JSON Data")
            text_output = gr.Textbox(
                label="Graph Text Data",
                placeholder="Text data will appear here",
                interactive=False,
            )

        # The order of `outputs` must match the tuple returned by
        # display_example.
        index_slider.change(
            fn=display_example,
            inputs=[index_slider],
            outputs=[
                image_output,
                graph_output,
                json_output,
                text_output,
                dimensions_output,
            ],
        )

    return demo


# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()