SKB-Explorer / interactive /pyvis_graph.py
shirwu's picture
rephrase
a00d62c
import sys
import json
import torch
import gradio as gr
from pyvis.network import Network
sys.path.append(".")
import re
from src.benchmarks import get_semistructured_data
# Concurrency limit handed to the Gradio queue; the queue length is sized
# as a multiple of it in main().
CONCURRENCY_LIMIT = 1000

# Browser title of the app, also rendered as its top-level heading.
TITLE = "STaRK Semi-structured Knowledge Base Explorer"

# Dataset key -> name displayed on that dataset's tab.
BRAND_NAME = dict(
    amazon="STaRK-Amazon",
    mag="STaRK-MAG",
    primekg="STaRK-Prime",
)
# Fill colours for graph nodes, looked up by integer node-type id in
# generate_network(). The queried node itself is drawn in a separate colour.
NODE_COLORS = [
    "#4285F4",  # blue
    "#F4B400",  # yellow
    "#0F9D58",  # green
    "#00796B",  # teal
    "#03A9F4",  # light blue
    "#CDDC39",  # lime
    "#3F51B5",  # indigo
    "#00BCD4",  # cyan
    "#FFC107",  # amber
    "#8BC34A",  # light green
    "#9E9E9E",  # grey
    "#607D8B",  # blue grey
    "#FFEB3B",  # bright yellow
    "#E1F5FE",  # light blue 50
    "#F1F8E9",  # light green 50
    "#FFF3E0",  # orange 50
    "#FFFDE7",  # yellow 50
    "#E0F7FA",  # cyan 50
    "#E8F5E9",  # green 50
    "#E3F2FD",  # blue 50
    "#FFF8E1",  # amber 50
    "#E0F2F1",  # teal 50
    "#F9FBE7",  # lime 50
]
# Stroke colours for graph edges, looked up by integer edge-type id in
# generate_network().
EDGE_COLORS = [
    "#1B5E20",  # green 900
    "#004D40",  # teal 900
    "#1A237E",  # indigo 900
    "#3E2723",  # brown 900
    "#880E4F",  # pink 900
    "#01579B",  # light blue 900
    "#F57F17",  # yellow 900
    "#FF6F00",  # amber 900
    "#4A148C",  # purple 900
    "#0D47A1",  # blue 900
    "#006064",  # cyan 900
    "#827717",  # lime 900
    "#E8EAF6",  # indigo 50
    "#ECEFF1",  # blue grey 50
    "#9C27B0",  # purple
    "#311B92",  # deep purple 900
    "#673AB7",  # deep purple
    "#EDE7F6",  # deep purple 50
]
# Extra <head> markup handed to gr.Blocks(head=...): the vis-network script
# and stylesheet from cdnjs, plus the CSS for the graph panel.
VISJS_HEAD = """
<script src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.9/dist/vis-network.min.js" integrity="sha512-4/EGWWWj7LIr/e+CvsslZkRk0fHDpf04dydJHoHOH32Mpw8jYU28GNI6mruO7fh/1kq15kSvwhKJftMSlgm0FA==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.9/dist/dist/vis-network.min.css" integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA==" crossorigin="anonymous" referrerpolicy="no-referrer" />
<style type="text/css"> .graph-area { flex-basis: 30% !important; } .network-graph { width: 100%; height: 600px; background-color: #ffffff; border: 1px solid lightgray; position: relative; float: left; } </style>
"""
# Inline the project's drawing script (it provides the drawGraph() function
# invoked from get_subgraph_html). The path is relative to the working
# directory. The encoding is given explicitly so that decoding the script
# does not depend on the platform's locale default.
with open("interactive/draw_graph.js", "r", encoding="utf-8") as f:
    VISJS_HEAD += f"<script>{f.read()}</script>"
def relabel(x, edge_index, batch, pos=None):
    """Keep only the nodes referenced by ``edge_index`` and renumber them.

    The distinct node ids found in ``edge_index`` (as returned by
    ``torch.unique``) are mapped to consecutive ids starting at 0, and the
    per-node tensors are reduced to those nodes.

    Args:
        x: Per-node tensor; its first dimension is the total node count.
        edge_index: Two-row tensor of edge endpoints expressed in original ids.
        batch: Per-node tensor, indexed the same way as ``x``.
        pos: Optional per-node tensor, indexed the same way as ``x``.

    Returns:
        ``(x, edge_index, batch, pos)`` restricted to the kept nodes, with
        ``edge_index`` rewritten in the new ids. ``pos`` stays ``None`` when
        it was not supplied.
    """
    kept = torch.unique(edge_index)
    first_row, _ = edge_index

    # Lookup table "original id -> new id"; nodes outside every edge stay -1.
    lookup = first_row.new_full((x.size(0),), -1)
    lookup[kept] = torch.arange(kept.size(0), device=first_row.device)

    kept_pos = None if pos is None else pos[kept]
    return x[kept], lookup[edge_index], batch[kept], kept_pos
def _incident_edges(kb, node_id, limit):
    """Return at most ``limit`` randomly chosen edges of ``kb`` touching ``node_id``."""
    edge_index = kb.edge_index
    touches = (edge_index[0] == node_id) | (edge_index[1] == node_id)
    incident, incident_types = edge_index[:, touches], kb.edge_types[touches]
    if incident.size(1) > limit:
        keep = torch.randperm(incident.size(1))[:limit]
        incident, incident_types = incident[:, keep], incident_types[keep]
    return incident, incident_types


def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
    """Build the pyvis node/edge data for the neighbourhood of one entity.

    Args:
        kb: Knowledge base. This function reads ``edge_index``, ``edge_types``,
            ``node_types``, ``node_type_dict``, ``edge_type_dict`` and calls
            ``num_nodes()``.
        node_id: Id of the entity to centre the graph on. Converted with
            ``int()``, so an integral float is accepted.
        max_nodes: Despite its name, this caps the number of *edges* that are
            randomly sampled around the entity (or from the whole graph for
            ``"inf"``). Converted with ``int()``.
        num_hops: ``"1"`` (edges of the entity), ``"2"`` (additionally one
            random edge per neighbour) or ``"inf"`` (random sample of the
            whole graph). Compared after ``str()``, so ``1`` and ``2`` work too.

    Returns:
        The result of ``Network.get_network_data()``; its first two items are
        the node list and the edge list.

    Raises:
        ValueError: If ``num_hops`` is not one of the supported values.
        IndexError: If ``node_id`` is outside ``[0, kb.num_nodes())``.
    """
    hops = str(num_hops)
    if hops not in ("1", "2", "inf"):
        raise ValueError(
            f"num_hops must be '1', '2' or 'inf', got {num_hops!r}"
        )
    max_nodes = int(max_nodes)
    node_id = int(node_id)
    if not 0 <= node_id < kb.num_nodes():
        raise IndexError(
            f"node_id {node_id} is out of range for a knowledge base "
            f"with {kb.num_nodes()} nodes"
        )

    # Knowledge bases that have a 'gene/protein' node type are drawn without
    # arrows; for all others each edge gets an explicit arrow below.
    indirected = 'gene/protein' in kb.node_type_dict.values()
    net = Network(directed=False) if indirected else Network()

    if hops == "1":
        edge_index, edge_types = _incident_edges(kb, node_id, max_nodes)
    elif hops == "2":
        edge_index, edge_types = _incident_edges(kb, node_id, max_nodes)
        index_parts, type_parts = [edge_index], [edge_types]
        # Expand every sampled neighbour by one randomly chosen edge.
        for neighbor in torch.unique(edge_index).tolist():
            if neighbor == node_id:
                continue
            hop_index, hop_types = _incident_edges(kb, neighbor, 1)
            index_parts.append(hop_index)
            type_parts.append(hop_types)
        edge_index = torch.cat(index_parts, dim=1)
        edge_types = torch.cat(type_parts, dim=0)
    else:  # "inf"
        edge_index, edge_types = kb.edge_index, kb.edge_types
        if edge_index.size(1) > max_nodes:
            keep = torch.randperm(edge_index.size(1))[:max_nodes]
            edge_index, edge_types = edge_index[:, keep], edge_types[keep]
            # A random sample may miss the queried node, so attach one of
            # its own edges (if it has any).
            own_index, own_types = _incident_edges(kb, node_id, 1)
            edge_index = torch.cat([edge_index, own_index], dim=1)
            edge_types = torch.cat([edge_types, own_types], dim=0)

    # A trailing self-loop guarantees the queried node survives relabel()
    # even when it has no edges. It deliberately has no entry in edge_types.
    edge_index = torch.cat(
        [edge_index, torch.LongTensor([[node_id], [node_id]])], dim=1
    )
    node_ids, relabel_edge_index, _, _ = relabel(
        torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
    )

    font = {"align": "middle", "size": 10}
    for idx, n_id in enumerate(node_ids.tolist()):
        type_id = kb.node_types[n_id].item()
        type_name = kb.node_type_dict[type_id]
        if n_id == node_id:
            # Highlight the queried entity and show its id in the label.
            net.add_node(
                idx,
                node_id=n_id,
                color="#DB4437",
                size=20,
                label=f"{type_name}<{n_id}>",
                font=dict(font),
            )
        else:
            net.add_node(
                idx,
                node_id=n_id,
                size=15,
                # Wrap around so an unexpectedly large type id cannot raise.
                color=NODE_COLORS[type_id % len(NODE_COLORS)],
                label=f"{type_name}",
                font=dict(font),
            )

    sources, targets = relabel_edge_index.tolist()
    # zip() stops at the end of edge_types, i.e. right before the untyped
    # self-loop appended above.
    for src, dst, type_id in zip(sources, targets, edge_types.tolist()):
        if src == dst:
            continue
        options = {
            "color": EDGE_COLORS[type_id % len(EDGE_COLORS)],
            "label": kb.edge_type_dict[type_id]
            .replace('___', " ")
            .replace('_', " "),
            "width": 1,
            "font": dict(font),
        }
        if not indirected:
            options["arrows"] = "to"
            options["arrowStrikethrough"] = False
        net.add_edge(src, dst, **options)

    return net.get_network_data()
def get_text_html(kb, node_id):
    """Render the textual description of one entity as a scrollable HTML panel.

    Args:
        kb: Knowledge base; ``get_doc_info(node_id, add_rel=False,
            compact=False)`` is called on it and must return a string.
        node_id: Id of the entity to describe. A float with an integral
            value (e.g. ``7.0``) is converted to ``int`` first.

    Returns:
        An HTML string consisting of the MathJax loader script followed by a
        fixed-height, vertically scrollable ``<div>`` holding the text.
    """
    # NOTE(review): the id comes from a gr.Number widget, which appears to
    # deliver whole numbers as floats; normalise so the heading reads
    # "Entity 7" and the lookup receives an integer.
    if isinstance(node_id, float) and node_id.is_integer():
        node_id = int(node_id)

    text = kb.get_doc_info(node_id, add_rel=False, compact=False)
    # Preserve the document's line breaks and spacing once rendered as HTML.
    text = text.replace("\n", "<br>").replace(" ", "&nbsp;")
    text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
    # Rewrite $...$ as \(...\) so inline formulas use MathJax's delimiters.
    text = re.sub(r"\$([^$]+)\$", r"\\(\1\\)", text)
    # The former polyfill.io <script> tag was dropped on purpose: that domain
    # was involved in a supply-chain compromise and MathJax 3 does not need an
    # ES6 polyfill in current browsers.
    return f"""<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<div style="width: 100%; height: 600px; overflow-x: hidden; overflow-y: scroll; overflow-wrap: break-word; hyphens: auto; padding: 10px; margin: 0 auto; border: 1px solid #ccc; line-height: 1.5;
font-family: SF Pro Text, SF Pro Icons, Helvetica Neue, Helvetica, Arial, sans-serif;">{text}</div>"""
def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops='1'):
    """Return the HTML snippet that draws an entity's subgraph in the browser.

    The graph is produced by ``generate_network`` and serialised to JSON
    together with the dataset name.

    Args:
        kb: Knowledge base forwarded to ``generate_network``.
        kb_name: Dataset key; used in the container's element id and passed
            to ``drawGraph`` as ``dataset``.
        node_id: Id of the entity to centre the graph on.
        max_nodes: Forwarded to ``generate_network``.
        num_hops: Forwarded to ``generate_network``.

    Returns:
        An HTML string with the graph container and a hidden image element.
    """
    network_data = generate_network(kb, node_id, max_nodes, num_hops)
    payload = json.dumps(
        {"nodes": network_data[0], "edges": network_data[1], "dataset": kb_name}
    )
    # Workaround for how Gradio applies HTML updates: the hidden image points
    # at a dummy URL, so its onerror handler fires and calls drawGraph().
    return f"""
<div id="{kb_name}-network" class="network-graph"></div>
<img src="/dummy.img" style="display: none;" onerror='drawGraph({payload});'>
"""
def main() -> None:
    """Load every dataset listed in BRAND_NAME and launch the Gradio explorer.

    One tab is built per dataset. Each tab has a visible query row (entity
    id, path limit, hop count, button), a graph panel and a text panel, plus
    a hidden row whose button refreshes only the text panel. Both buttons are
    also exposed as named API endpoints.
    """
    # All knowledge bases are loaded up front, before the UI is built.
    kbs = {k: get_semistructured_data(k, indirected=False) for k in BRAND_NAME.keys()}
    with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
        gr.Markdown(f"# {TITLE}")
        for name, kb in kbs.items():
            with gr.Tab(BRAND_NAME[name]):
                # Query controls.
                with gr.Row():
                    entity_id = gr.Number(
                        label="Entity ID",
                        elem_id=f"{name}-entity-id-input"
                    )
                    max_paths = gr.Slider(
                        1, 200, 10, step=1, label="Max Number of Paths"
                    )
                    num_hops = gr.Dropdown(
                        ["1", "2", "inf"], value="2", label="Number of Hops"
                    )
                    query_btn = gr.Button(
                        value="Display Semi-structured Data",
                        variant="primary",
                        elem_id=f"{name}-fetch-btn"
                    )
                # Output panels: graph on the left, entity text on the right.
                with gr.Row():
                    graph_area = gr.HTML(elem_classes="graph-area")
                    text_area = gr.HTML(elem_classes="text-area")
                query_btn.click(
                    # kb and name are bound as default arguments so that each
                    # tab's handler keeps its own dataset instead of the
                    # loop's last one.
                    lambda e, n, h, kb=kb, name=name: (
                        get_subgraph_html(kb, name, e, n, h),
                        get_text_html(kb, e),
                    ),
                    inputs=[entity_id, max_paths, num_hops],
                    outputs=[graph_area, text_area],
                    api_name=f"{name}-fetch-graph"
                )
                # Hidden controls that fetch only the text of an entity.
                # NOTE(review): presumably driven from draw_graph.js through
                # their element ids - confirm against that script.
                with gr.Row(visible=False):
                    entity_for_text = gr.Number(
                        label="Text Entity ID", elem_id=f"{name}-entity-id-text-input"
                    )
                    query_text_btn = gr.Button(
                        value="Show Text", elem_id=f"{name}-fetch-text-btn"
                    )
                query_text_btn.click(
                    lambda e, kb=kb: get_text_html(kb, e),
                    inputs=[entity_for_text],
                    outputs=text_area,
                    api_name=f"{name}-fetch-text"
                )
    # The queue holds twice as many pending events as may run concurrently.
    demo.queue(max_size=2*CONCURRENCY_LIMIT, default_concurrency_limit=CONCURRENCY_LIMIT)
    # share=True additionally requests a public share link from Gradio.
    demo.launch(share=True)
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()