Spaces:
Sleeping
Sleeping
import sys | |
import json | |
import torch | |
import gradio as gr | |
from pyvis.network import Network | |
sys.path.append(".") | |
import re | |
from src.benchmarks import get_semistructured_data | |
# Used for Gradio's default per-event concurrency limit; the queue size is
# set to twice this value in main().
CONCURRENCY_LIMIT = 1000

# Page title, also rendered as the top-level Markdown heading.
TITLE = "STaRK Semi-structured Knowledge Base Explorer"

# Dataset key -> tab label. Insertion order determines the order in which
# the knowledge bases are loaded and their tabs are created.
BRAND_NAME = dict(
    amazon="STaRK-Amazon",
    mag="STaRK-MAG",
    primekg="STaRK-Prime",
)
# Fill colours for graph nodes. generate_network() indexes this list with the
# integer node-type id, so the ORDER of entries is significant.
NODE_COLORS = [
    # --- saturated tones ---
    "#4285F4",  # Blue
    "#F4B400",  # Yellow
    "#0F9D58",  # Green
    "#00796B",  # Teal
    "#03A9F4",  # Light Blue
    "#CDDC39",  # Lime
    "#3F51B5",  # Indigo
    "#00BCD4",  # Cyan
    "#FFC107",  # Amber
    "#8BC34A",  # Light Green
    "#9E9E9E",  # Grey
    "#607D8B",  # Blue Grey
    "#FFEB3B",  # Bright Yellow
    # --- pale ("50") tones ---
    "#E1F5FE",  # Light Blue 50
    "#F1F8E9",  # Light Green 50
    "#FFF3E0",  # Orange 50
    "#FFFDE7",  # Yellow 50
    "#E0F7FA",  # Cyan 50
    "#E8F5E9",  # Green 50
    "#E3F2FD",  # Blue 50
    "#FFF8E1",  # Amber 50
    "#E0F2F1",  # Teal 50
    "#F9FBE7",  # Lime 50
]
# Stroke colours for graph edges. generate_network() indexes this list with
# the integer edge-type id, so the ORDER of entries is significant.
EDGE_COLORS = [
    "#1B5E20",  # Green 900
    "#004D40",  # Teal 900
    "#1A237E",  # Indigo 900
    "#3E2723",  # Brown 900
    "#880E4F",  # Pink 900
    "#01579B",  # Light Blue 900
    "#F57F17",  # Yellow 900
    "#FF6F00",  # Amber 900
    "#4A148C",  # Purple 900
    "#0D47A1",  # Blue 900
    "#006064",  # Cyan 900
    "#827717",  # Lime 900
    "#E8EAF6",  # Indigo 50
    "#ECEFF1",  # Blue Grey 50
    "#9C27B0",  # Purple
    "#311B92",  # Deep Purple 900
    "#673AB7",  # Deep Purple
    "#EDE7F6",  # Deep Purple 50
]
# Markup injected into the page <head> by gr.Blocks(head=...): the vis-network
# script and stylesheet from cdnjs, plus the CSS for the graph container.
VISJS_HEAD = """
<script src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.9/dist/vis-network.min.js" integrity="sha512-4/EGWWWj7LIr/e+CvsslZkRk0fHDpf04dydJHoHOH32Mpw8jYU28GNI6mruO7fh/1kq15kSvwhKJftMSlgm0FA==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.9/dist/dist/vis-network.min.css" integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA==" crossorigin="anonymous" referrerpolicy="no-referrer" />
<style type="text/css"> .graph-area { flex-basis: 30% !important; } .network-graph { width: 100%; height: 600px; background-color: #ffffff; border: 1px solid lightgray; position: relative; float: left; } </style>
"""

# Inline the project's graph-drawing script (it provides the drawGraph()
# function invoked from the generated HTML). The path is relative to the
# current working directory, and this runs at import time.
# An explicit encoding keeps the read independent of the platform locale.
with open("interactive/draw_graph.js", "r", encoding="utf-8") as js_file:
    VISJS_HEAD += f"<script>{js_file.read()}</script>"
def relabel(x, edge_index, batch, pos=None):
    """Restrict node-level tensors to the nodes referenced by ``edge_index``.

    The distinct node ids appearing in ``edge_index`` are collected (in the
    sorted order produced by ``torch.unique``) and renumbered ``0..k-1``.

    Args:
        x: Per-node tensor; its first dimension gives the total node count.
        edge_index: Integer tensor whose first dimension has size 2
            (source row, target row).
        batch: Per-node tensor, indexed the same way as ``x``.
        pos: Optional per-node tensor; passed through as ``None`` if absent.

    Returns:
        Tuple ``(x, edge_index, batch, pos)`` where the per-node tensors keep
        only the referenced nodes and ``edge_index`` uses the new compact ids.
    """
    total_nodes = x.size(0)
    kept = torch.unique(edge_index)

    kept_x = x[kept]
    kept_batch = batch[kept]

    # Only the source row is needed, as a template for dtype and device.
    src, _ = edge_index

    # Lookup table old id -> new id; ids that are not kept stay at -1.
    old_to_new = src.new_full((total_nodes,), -1)
    old_to_new[kept] = torch.arange(kept.size(0), device=src.device)
    compact_edges = old_to_new[edge_index]

    kept_pos = None if pos is None else pos[kept]
    return kept_x, compact_edges, kept_batch, kept_pos
def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
    """Build pyvis network data for a sampled neighbourhood of ``node_id``.

    Args:
        kb: Knowledge base exposing ``edge_index``, ``edge_types``,
            ``node_types``, ``node_type_dict``, ``edge_type_dict`` and
            ``num_nodes()``.
        node_id: Id of the entity to centre the graph on.
        max_nodes: Maximum number of edges sampled around ``node_id``
            (coerced with ``int``).
        num_hops: ``"1"``, ``"2"`` or ``"inf"`` (the value is passed through
            ``str``, so the integers 1 and 2 are accepted as well).
            * ``"1"``: up to ``max_nodes`` edges touching ``node_id``.
            * ``"2"``: the above, plus one edge per direct neighbour.
            * ``"inf"``: up to ``max_nodes`` edges sampled from the whole
              knowledge base, plus one edge touching ``node_id``.

    Returns:
        The value of ``Network.get_network_data()`` for the built graph.

    Raises:
        ValueError: If ``num_hops`` is not one of the supported values.
    """
    max_nodes = int(max_nodes)
    hops = str(num_hops)

    # Knowledge bases with a 'gene/protein' node type are drawn without arrows.
    undirected = 'gene/protein' in kb.node_type_dict.values()
    net = Network(directed=False) if undirected else Network()

    def get_one_hop(kb, node_id, max_nodes):
        """Return up to ``max_nodes`` randomly chosen edges touching ``node_id``."""
        edge_index = kb.edge_index
        mask = (edge_index[0] == node_id) | (edge_index[1] == node_id)
        incident_edges = edge_index[:, mask]
        incident_types = kb.edge_types[mask]
        # Randomly sample max_nodes edges when there are more than that.
        if incident_edges.size(1) > max_nodes:
            keep = torch.randperm(incident_edges.size(1))[:max_nodes]
            incident_edges = incident_edges[:, keep]
            incident_types = incident_types[keep]
        return incident_edges, incident_types

    if hops == "1":
        edge_index, edge_types = get_one_hop(kb, node_id, max_nodes)
    elif hops == "2":
        edge_index, edge_types = get_one_hop(kb, node_id, max_nodes)
        neighbors = [n for n in torch.unique(edge_index).tolist() if n != node_id]
        index_parts, type_parts = [edge_index], [edge_types]
        for neighbor in neighbors:
            e_index, e_type = get_one_hop(kb, neighbor, max_nodes=1)
            index_parts.append(e_index)
            type_parts.append(e_type)
        # Concatenate once rather than re-allocating on every iteration.
        edge_index = torch.cat(index_parts, dim=1)
        edge_types = torch.cat(type_parts, dim=0)
    elif hops == "inf":
        edge_index, edge_types = kb.edge_index, kb.edge_types
        # Sample max_nodes edges from the whole knowledge base.
        if edge_index.size(1) > max_nodes:
            keep = torch.randperm(edge_index.size(1))[:max_nodes]
            edge_index = edge_index[:, keep]
            edge_types = edge_types[keep]
        add_edge_index, add_edge_types = get_one_hop(kb, node_id, max_nodes=1)
        edge_index = torch.cat([edge_index, add_edge_index], dim=1)
        edge_types = torch.cat([edge_types, add_edge_types], dim=0)
    else:
        raise ValueError(
            f"num_hops must be '1', '2' or 'inf', got {num_hops!r}"
        )

    # Append a self-loop so node_id is kept by relabel() even when it has no
    # edges. edge_types is deliberately NOT extended: the self-loop is the
    # last column and self-loops are skipped below before edge_types is read.
    self_loop = torch.tensor(
        [[node_id], [node_id]], dtype=edge_index.dtype, device=edge_index.device
    )
    edge_index = torch.cat([edge_index, self_loop], dim=1)

    node_ids, relabel_edge_index, _, _ = relabel(
        torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
    )

    for idx, raw_id in enumerate(node_ids.tolist()):
        type_id = kb.node_types[raw_id].item()
        type_name = kb.node_type_dict[type_id]
        if node_id == raw_id:
            # The queried entity: highlighted, larger, labelled with its id.
            net.add_node(
                idx,
                node_id=raw_id,
                color="#DB4437",
                size=20,
                label=f"{type_name}<{raw_id}>",
                font={"align": "middle", "size": 10},
            )
        else:
            net.add_node(
                idx,
                node_id=raw_id,
                size=15,
                # Wrap around so an unexpectedly large type id cannot raise.
                color=NODE_COLORS[type_id % len(NODE_COLORS)],
                label=f"{type_name}",
                font={"align": "middle", "size": 10},
            )

    for idx, (src, dst) in enumerate(relabel_edge_index.t().tolist()):
        if src == dst:
            continue
        type_id = edge_types[idx].item()
        options = {
            "color": EDGE_COLORS[type_id % len(EDGE_COLORS)],
            "label": kb.edge_type_dict[type_id]
            .replace('___', " ")
            .replace('_', " "),
            "width": 1,
            "font": {"align": "middle", "size": 10},
        }
        if not undirected:
            options["arrows"] = "to"
            options["arrowStrikethrough"] = False
        net.add_edge(src, dst, **options)

    return net.get_network_data()
def get_text_html(kb, node_id):
    """Render the textual description of an entity as a scrollable HTML panel.

    Args:
        kb: Knowledge base exposing
            ``get_doc_info(node_id, add_rel=False, compact=False)``.
        node_id: Id of the entity whose text is shown.

    Returns:
        An HTML string: the script tags that load MathJax, followed by a
        fixed-height scrollable ``<div>`` containing the entity text.
    """
    text = kb.get_doc_info(node_id, add_rel=False, compact=False)
    # NOTE(review): the second replace appears to be a no-op as captured; the
    # original may have substituted "&nbsp;" or a non-breaking space before
    # the source was extracted — confirm against the repository.
    text = text.replace("\n", "<br>").replace(" ", " ")
    # Add a title.
    text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
    # Rewrite $...$ as \(...\), the inline-math delimiters MathJax recognises.
    text = re.sub(r"\$([^$]+)\$", r"\\(\1\\)", text)

    # polyfill.io is no longer a trustworthy origin (it served malicious code
    # in 2024); the cdnjs mirror is the drop-in replacement.
    scripts = (
        '<script src="https://cdnjs.cloudflare.com/polyfill/v3/polyfill.min.js?features=es6"></script>\n'
        '<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>\n'
    )
    # Show the text as-is inside a fixed-height box that can be scrolled.
    style = (
        "width: 100%; height: 600px; overflow-x: hidden; overflow-y: scroll; "
        "overflow-wrap: break-word; hyphens: auto; padding: 10px; "
        "margin: 0 auto; border: 1px solid #ccc; line-height: 1.5;\n"
        "font-family: SF Pro Text, SF Pro Icons, Helvetica Neue, Helvetica, "
        "Arial, sans-serif;"
    )
    return f'{scripts}<div style="{style}">{text}</div>'
def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops='1'):
    """Return the HTML snippet that draws the subgraph around ``node_id``.

    Args:
        kb: Knowledge base forwarded to ``generate_network``.
        kb_name: Dataset key; used in the container element id and passed to
            the front-end as ``dataset``.
        node_id: Id of the entity to centre the graph on.
        max_nodes: Forwarded to ``generate_network``.
        num_hops: Forwarded to ``generate_network``.

    Returns:
        An HTML string with the graph container and a hidden ``<img>`` whose
        ``onerror`` handler calls ``drawGraph`` with the graph as JSON.
    """
    from html import escape  # local import: only this function needs it

    network = generate_network(kb, node_id, max_nodes, num_hops)
    payload = json.dumps(
        {"nodes": network[0], "edges": network[1], "dataset": kb_name}
    )
    # The JSON sits inside a single-quoted HTML attribute, so it is
    # entity-escaped; otherwise a quote, '&' or '<' in any label would cut the
    # attribute short. The browser decodes the entities before running the JS.
    attr_payload = escape(payload, quote=True)

    # A dirty hack to trigger the drawGraph function ;)
    # Have to do it this way because of the way Gradio handles HTML updates:
    # the image source is invalid on purpose so that onerror fires.
    return f"""
<div id="{kb_name}-network" class="network-graph"></div>
<img src="/dummy.img" style="display: none;" onerror='drawGraph({attr_payload});'>
"""
def main():
    """Load every knowledge base, build the tabbed Gradio UI and launch it.

    One tab is created per entry of ``BRAND_NAME``. Each tab has a visible
    query row (entity id, path limit, hop count, button), a result row (graph
    and text panels), and a hidden row whose button refreshes only the text
    panel.

    NOTE(review): the widget nesting below was reconstructed from a source
    whose indentation had been flattened — confirm against the repository.
    """
    kbs = {
        dataset: get_semistructured_data(dataset, indirected=False)
        for dataset in BRAND_NAME
    }

    def make_graph_handler(kb, name):
        # Bind the current kb/name so each tab queries its own dataset.
        def fetch_graph(entity, limit, hops):
            return (
                get_subgraph_html(kb, name, entity, limit, hops),
                get_text_html(kb, entity),
            )
        return fetch_graph

    def make_text_handler(kb):
        def fetch_text(entity):
            return get_text_html(kb, entity)
        return fetch_text

    with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
        gr.Markdown(f"# {TITLE}")
        for name, kb in kbs.items():
            with gr.Tab(BRAND_NAME[name]):
                with gr.Row():
                    entity_id = gr.Number(
                        label="Entity ID",
                        elem_id=f"{name}-entity-id-input",
                    )
                    max_paths = gr.Slider(
                        minimum=1,
                        maximum=200,
                        value=10,
                        step=1,
                        label="Max Number of Paths",
                    )
                    num_hops = gr.Dropdown(
                        choices=["1", "2", "inf"],
                        value="2",
                        label="Number of Hops",
                    )
                    query_btn = gr.Button(
                        value="Display Semi-structured Data",
                        variant="primary",
                        elem_id=f"{name}-fetch-btn",
                    )
                with gr.Row():
                    graph_area = gr.HTML(elem_classes="graph-area")
                    text_area = gr.HTML(elem_classes="text-area")

                query_btn.click(
                    make_graph_handler(kb, name),
                    inputs=[entity_id, max_paths, num_hops],
                    outputs=[graph_area, text_area],
                    api_name=f"{name}-fetch-graph",
                )

                # Hidden inputs used to fetch just the text.
                with gr.Row(visible=False):
                    entity_for_text = gr.Number(
                        label="Text Entity ID",
                        elem_id=f"{name}-entity-id-text-input",
                    )
                    query_text_btn = gr.Button(
                        value="Show Text",
                        elem_id=f"{name}-fetch-text-btn",
                    )

                query_text_btn.click(
                    make_text_handler(kb),
                    inputs=[entity_for_text],
                    outputs=text_area,
                    api_name=f"{name}-fetch-text",
                )

    demo.queue(
        max_size=2 * CONCURRENCY_LIMIT,
        default_concurrency_limit=CONCURRENCY_LIMIT,
    )
    demo.launch(share=True)


if __name__ == "__main__":
    main()