Spaces:
Runtime error
Runtime error
Remsky
commited on
Commit
Β·
4289090
verified
Β·
0
Parent(s):
Super-squash branch 'main' using huggingface_hub
Browse files- .gitattributes +35 -0
- README.md +13 -0
- app.py +94 -0
- lib/__init__.py +0 -0
- lib/graph_extract.py +142 -0
- lib/samples.py +46 -0
- lib/visualize.py +111 -0
- requirements.txt +7 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Triplex Knowledge Graph Visualizer
|
3 |
+
emoji: πΈοΈ
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
app_file: app.py
|
8 |
+
pinned: true
|
9 |
+
models:
|
10 |
+
- SciPhi/Triplex
|
11 |
+
preload_from_hub:
|
12 |
+
- SciPhi/Triplex
|
13 |
+
---
|
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import random
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import spaces
|
6 |
+
|
7 |
+
from lib.graph_extract import triplextract, parse_triples
|
8 |
+
from lib.visualize import create_cytoscape_plot
|
9 |
+
from lib.samples import snippets
|
10 |
+
|
11 |
+
WORD_LIMIT = 300
|
12 |
+
|
13 |
+
@spaces.GPU
|
14 |
+
def process_text(text, entity_types, predicates):
|
15 |
+
if not text:
|
16 |
+
return None, "Please enter some text."
|
17 |
+
|
18 |
+
words = text.split()
|
19 |
+
if len(words) > WORD_LIMIT:
|
20 |
+
return None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"
|
21 |
+
|
22 |
+
entity_types = [et.strip() for et in entity_types.split(",") if et.strip()]
|
23 |
+
predicates = [p.strip() for p in predicates.split(",") if p.strip()]
|
24 |
+
|
25 |
+
if not entity_types:
|
26 |
+
return None, "Please enter at least one entity type."
|
27 |
+
if not predicates:
|
28 |
+
return None, "Please enter at least one predicate."
|
29 |
+
|
30 |
+
try:
|
31 |
+
prediction = triplextract(text, entity_types, predicates)
|
32 |
+
if prediction.startswith("Error"):
|
33 |
+
return None, prediction
|
34 |
+
|
35 |
+
entities, relationships = parse_triples(prediction)
|
36 |
+
|
37 |
+
if not entities and not relationships:
|
38 |
+
return (
|
39 |
+
None,
|
40 |
+
"No entities or relationships found. Try different text or check your input.",
|
41 |
+
)
|
42 |
+
|
43 |
+
fig = create_cytoscape_plot(entities, relationships)
|
44 |
+
return (
|
45 |
+
fig,
|
46 |
+
f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
|
47 |
+
)
|
48 |
+
except Exception as e:
|
49 |
+
print(f"Error in process_text: {e}")
|
50 |
+
return None, f"An error occurred: {str(e)}"
|
51 |
+
|
52 |
+
def update_inputs(sample_name):
|
53 |
+
sample = snippets[sample_name]
|
54 |
+
return sample.text_input, sample.entity_types, sample.predicates
|
55 |
+
|
56 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
57 |
+
gr.Markdown("# Knowledge Graph Extractor")
|
58 |
+
|
59 |
+
default_sample_name = random.choice(list(snippets.keys()))
|
60 |
+
default_sample = snippets[default_sample_name]
|
61 |
+
|
62 |
+
with gr.Row():
|
63 |
+
with gr.Column(scale=1):
|
64 |
+
sample_dropdown = gr.Dropdown(
|
65 |
+
choices=list(snippets.keys()),
|
66 |
+
label="Select Sample",
|
67 |
+
value=default_sample_name
|
68 |
+
)
|
69 |
+
input_text = gr.Textbox(
|
70 |
+
label="Input Text",
|
71 |
+
lines=5,
|
72 |
+
value=default_sample.text_input
|
73 |
+
)
|
74 |
+
entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
|
75 |
+
predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
|
76 |
+
submit_btn = gr.Button("Extract Knowledge Graph")
|
77 |
+
with gr.Column(scale=2):
|
78 |
+
output_graph = gr.Plot(label="Knowledge Graph")
|
79 |
+
error_message = gr.Textbox(label="Textual Output")
|
80 |
+
|
81 |
+
sample_dropdown.change(
|
82 |
+
update_inputs,
|
83 |
+
inputs=[sample_dropdown],
|
84 |
+
outputs=[input_text, entity_types, predicates]
|
85 |
+
)
|
86 |
+
|
87 |
+
submit_btn.click(
|
88 |
+
process_text,
|
89 |
+
inputs=[input_text, entity_types, predicates],
|
90 |
+
outputs=[output_graph, error_message],
|
91 |
+
)
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
demo.launch()
|
lib/__init__.py
ADDED
File without changes
|
lib/graph_extract.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
4 |
+
import torch
|
5 |
+
import warnings
|
6 |
+
import spaces
|
7 |
+
|
8 |
+
flash_attn_installed = False
|
9 |
+
try:
|
10 |
+
import subprocess
|
11 |
+
print("Installing flash-attn...")
|
12 |
+
subprocess.run(
|
13 |
+
"pip install flash-attn --no-build-isolation",
|
14 |
+
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
|
15 |
+
shell=True,
|
16 |
+
)
|
17 |
+
flash_attn_installed = True
|
18 |
+
except Exception as e:
|
19 |
+
print(f"Error installing flash-attn: {e}")
|
20 |
+
|
21 |
+
|
22 |
+
# Suppress specific warnings
|
23 |
+
warnings.filterwarnings(
|
24 |
+
"ignore",
|
25 |
+
message="You have modified the pretrained model configuration to control generation.",
|
26 |
+
)
|
27 |
+
warnings.filterwarnings(
|
28 |
+
"ignore",
|
29 |
+
message="You are not running the flash-attention implementation, expect numerical differences.",
|
30 |
+
)
|
31 |
+
|
32 |
+
print("Initializing application...")
|
33 |
+
|
34 |
+
model = AutoModelForCausalLM.from_pretrained(
|
35 |
+
"sciphi/triplex",
|
36 |
+
trust_remote_code=True,
|
37 |
+
attn_implementation="flash_attention_2" if flash_attn_installed else None,
|
38 |
+
torch_dtype=torch.bfloat16,
|
39 |
+
device_map="auto",
|
40 |
+
low_cpu_mem_usage=True,#advised if any device map given
|
41 |
+
).eval()
|
42 |
+
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
44 |
+
"sciphi/triplex",
|
45 |
+
trust_remote_code=True,
|
46 |
+
attn_implementation="flash_attention_2",
|
47 |
+
torch_dtype=torch.bfloat16,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
print("Model and tokenizer loaded successfully.")
|
52 |
+
|
53 |
+
# Set up generation config
|
54 |
+
generation_config = GenerationConfig.from_pretrained("sciphi/triplex")
|
55 |
+
generation_config.max_length = 2048
|
56 |
+
generation_config.pad_token_id = tokenizer.eos_token_id
|
57 |
+
@spaces.GPU
|
58 |
+
def triplextract(text, entity_types, predicates):
|
59 |
+
input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
|
60 |
+
**Entity Types:**
|
61 |
+
{entity_types}
|
62 |
+
**Predicates:**
|
63 |
+
{predicates}
|
64 |
+
**Text:**
|
65 |
+
{text}
|
66 |
+
"""
|
67 |
+
message = input_format.format(
|
68 |
+
entity_types = json.dumps({"entity_types": entity_types}),
|
69 |
+
predicates = json.dumps({"predicates": predicates}),
|
70 |
+
text = text)
|
71 |
+
|
72 |
+
# message = input_format.format(
|
73 |
+
# entity_types=entity_types, predicates=predicates, text=text
|
74 |
+
# )
|
75 |
+
|
76 |
+
messages = [{"role": "user", "content": message}]
|
77 |
+
|
78 |
+
print("Tokenizing input...")
|
79 |
+
input_ids = tokenizer.apply_chat_template(
|
80 |
+
messages, add_generation_prompt=True, return_tensors="pt"
|
81 |
+
).to(model.device)
|
82 |
+
|
83 |
+
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
84 |
+
|
85 |
+
print("Generating output...")
|
86 |
+
try:
|
87 |
+
with torch.no_grad():
|
88 |
+
output = model.generate(
|
89 |
+
input_ids=input_ids,
|
90 |
+
attention_mask=attention_mask,
|
91 |
+
generation_config=generation_config,
|
92 |
+
do_sample=True,
|
93 |
+
)
|
94 |
+
|
95 |
+
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
|
96 |
+
print("Decoding output completed.")
|
97 |
+
|
98 |
+
return decoded_output
|
99 |
+
except torch.cuda.OutOfMemoryError as e:
|
100 |
+
print(f"CUDA out of memory error: {e}")
|
101 |
+
return "Error: CUDA out of memory."
|
102 |
+
except Exception as e:
|
103 |
+
print(f"Error in generation: {e}")
|
104 |
+
return f"Error in generation: {str(e)}"
|
105 |
+
|
106 |
+
def parse_triples(prediction):
|
107 |
+
entities = {}
|
108 |
+
relationships = []
|
109 |
+
|
110 |
+
try:
|
111 |
+
data = json.loads(prediction)
|
112 |
+
items = data.get("entities_and_triples", [])
|
113 |
+
except json.JSONDecodeError:
|
114 |
+
json_match = re.search(r"```json\s*(.*?)\s*```", prediction, re.DOTALL)
|
115 |
+
if json_match:
|
116 |
+
try:
|
117 |
+
data = json.loads(json_match.group(1))
|
118 |
+
items = data.get("entities_and_triples", [])
|
119 |
+
except json.JSONDecodeError:
|
120 |
+
items = re.findall(r"\[(.*?)\]", prediction)
|
121 |
+
else:
|
122 |
+
items = re.findall(r"\[(.*?)\]", prediction)
|
123 |
+
|
124 |
+
for item in items:
|
125 |
+
if isinstance(item, str):
|
126 |
+
if ":" in item:
|
127 |
+
id, entity = item.split(",", 1)
|
128 |
+
id = id.strip("[]").strip()
|
129 |
+
entity_type, entity_value = entity.split(":", 1)
|
130 |
+
entities[id] = {
|
131 |
+
"type": entity_type.strip(),
|
132 |
+
"value": entity_value.strip(),
|
133 |
+
}
|
134 |
+
else:
|
135 |
+
parts = item.split()
|
136 |
+
if len(parts) >= 3:
|
137 |
+
source = parts[0].strip("[]")
|
138 |
+
relation = " ".join(parts[1:-1])
|
139 |
+
target = parts[-1].strip("[]")
|
140 |
+
relationships.append((source, relation.strip(), target))
|
141 |
+
|
142 |
+
return entities, relationships
|
lib/samples.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
|
3 |
+
Snippet = namedtuple('Snippet', ['text_input', 'entity_types', 'predicates'])
|
4 |
+
|
5 |
+
snippets = {
|
6 |
+
'paris': Snippet(
|
7 |
+
text_input="""Paris is the capital of France. It has a population of 2.16 million people.
|
8 |
+
The Eiffel Tower, located in Paris, is a famous landmark with a height of 324 meters.
|
9 |
+
Paris is known for its romantic atmosphere.""",
|
10 |
+
entity_types="LOCATION, POPULATION, STYLE",
|
11 |
+
predicates="HAS, IS"
|
12 |
+
),
|
13 |
+
|
14 |
+
'dickens': Snippet(
|
15 |
+
text_input="""It was the best of times, it was the worst of times, it was the age of wisdom,
|
16 |
+
it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity,
|
17 |
+
it was the season of Light, it was the season of Darkness, it was the spring of hope,
|
18 |
+
it was the winter of despair, we had everything before us, we had nothing before us,
|
19 |
+
we were all going direct to Heaven, we were all going direct the other way β in short,
|
20 |
+
the period was so far like the present period, that some of its noisiest authorities
|
21 |
+
insisted on its being received, for good or for evil, in the superlative degree of comparison only.""",
|
22 |
+
entity_types="TIME, EMOTION, LOCATION, EVENT, OUTCOME, PLACE",
|
23 |
+
predicates="WAS, HAD, WERE"
|
24 |
+
),
|
25 |
+
|
26 |
+
'tech_company': Snippet(
|
27 |
+
text_input="""Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in 1976.
|
28 |
+
Headquartered in Cupertino, California, Apple designs and produces consumer electronics,
|
29 |
+
software, and online services. The company's flagship products include the iPhone smartphone,
|
30 |
+
iPad tablet, and Mac personal computer. As of 2023, Apple has over 150,000 employees worldwide
|
31 |
+
and generates annual revenue exceeding $350 billion.""",
|
32 |
+
entity_types="COMPANY, PERSON, PRODUCT, LOCATION, DATE, NUMBER, EVENT, SUBJECT",
|
33 |
+
predicates="FOUNDED, HEADQUARTERED_IN, PRODUCES, HAS, EMPLOYEES, "
|
34 |
+
),
|
35 |
+
|
36 |
+
'climate_change': Snippet(
|
37 |
+
text_input="""Global warming is causing significant changes to Earth's climate. The average global
|
38 |
+
temperature has increased by approximately 1.1Β°C since the pre-industrial era. This warming is
|
39 |
+
primarily caused by human activities, particularly the emission of greenhouse gases like carbon dioxide.
|
40 |
+
The Paris Agreement, signed in 2015, aims to limit global temperature increase to well below 2Β°C above
|
41 |
+
pre-industrial levels. To achieve this goal, many countries are implementing policies to reduce carbon
|
42 |
+
emissions and transition to renewable energy sources.""",
|
43 |
+
entity_types="PHENOMENON, PLANET, TEMPERATURE, CAUSE, CHEMICAL, AGREEMENT, DATE, GOAL, POLICY",
|
44 |
+
predicates="CAUSES, INCREASED_BY, CAUSED_BY, SIGNED_IN, AIMS_TO, IMPLEMENTING"
|
45 |
+
)
|
46 |
+
}
|
lib/visualize.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
import networkx as nx
|
3 |
+
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
import networkx as nx
|
6 |
+
|
7 |
+
def create_cytoscape_plot(entities, relationships):
|
8 |
+
G = nx.DiGraph() # Use DiGraph for directed edges
|
9 |
+
|
10 |
+
for entity_id, entity_data in entities.items():
|
11 |
+
G.add_node(entity_id, **entity_data)
|
12 |
+
|
13 |
+
for source, relation, target in relationships:
|
14 |
+
G.add_edge(source, target, relation=relation)
|
15 |
+
|
16 |
+
pos = nx.spring_layout(G, k=0.5, iterations=50) # Adjust layout parameters
|
17 |
+
|
18 |
+
edge_trace = go.Scatter(
|
19 |
+
x=[],
|
20 |
+
y=[],
|
21 |
+
line=dict(width=1, color="#888"),
|
22 |
+
hoverinfo="text",
|
23 |
+
mode="lines",
|
24 |
+
text=[],
|
25 |
+
)
|
26 |
+
|
27 |
+
node_trace = go.Scatter(
|
28 |
+
x=[],
|
29 |
+
y=[],
|
30 |
+
mode="markers+text",
|
31 |
+
hoverinfo="text",
|
32 |
+
marker=dict(
|
33 |
+
showscale=True,
|
34 |
+
colorscale="Viridis",
|
35 |
+
reversescale=True,
|
36 |
+
color=[],
|
37 |
+
size=15,
|
38 |
+
colorbar=dict(
|
39 |
+
thickness=15,
|
40 |
+
title="Node Connections",
|
41 |
+
xanchor="left",
|
42 |
+
titleside="right",
|
43 |
+
),
|
44 |
+
line_width=2,
|
45 |
+
),
|
46 |
+
text=[],
|
47 |
+
textposition="top center",
|
48 |
+
)
|
49 |
+
|
50 |
+
edge_labels = []
|
51 |
+
|
52 |
+
for edge in G.edges():
|
53 |
+
x0, y0 = pos[edge[0]]
|
54 |
+
x1, y1 = pos[edge[1]]
|
55 |
+
edge_trace["x"] += (x0, x1, None)
|
56 |
+
edge_trace["y"] += (y0, y1, None)
|
57 |
+
|
58 |
+
# Calculate midpoint for edge label
|
59 |
+
mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
|
60 |
+
edge_labels.append(
|
61 |
+
go.Scatter(
|
62 |
+
x=[mid_x],
|
63 |
+
y=[mid_y],
|
64 |
+
mode="text",
|
65 |
+
text=[G.edges[edge]["relation"]],
|
66 |
+
textposition="middle center",
|
67 |
+
hoverinfo="none",
|
68 |
+
showlegend=False,
|
69 |
+
textfont=dict(size=8),
|
70 |
+
)
|
71 |
+
)
|
72 |
+
|
73 |
+
for node in G.nodes():
|
74 |
+
x, y = pos[node]
|
75 |
+
node_trace["x"] += (x,)
|
76 |
+
node_trace["y"] += (y,)
|
77 |
+
node_info = f"{entities[node]['value']} ({entities[node]['type']})"
|
78 |
+
node_trace["text"] += (node_info,)
|
79 |
+
node_trace["marker"]["color"] += (len(list(G.neighbors(node))),)
|
80 |
+
|
81 |
+
fig = go.Figure(
|
82 |
+
data=[edge_trace, node_trace] + edge_labels,
|
83 |
+
layout=go.Layout(
|
84 |
+
title="Knowledge Graph",
|
85 |
+
titlefont_size=16,
|
86 |
+
showlegend=False,
|
87 |
+
hovermode="closest",
|
88 |
+
margin=dict(b=20, l=5, r=5, t=40),
|
89 |
+
annotations=[],
|
90 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
91 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
92 |
+
width=800,
|
93 |
+
height=600,
|
94 |
+
),
|
95 |
+
)
|
96 |
+
|
97 |
+
# Enable dragging of nodes
|
98 |
+
fig.update_layout(
|
99 |
+
newshape=dict(line_color="#009900"),
|
100 |
+
# Enable zoom
|
101 |
+
xaxis=dict(
|
102 |
+
scaleanchor="y",
|
103 |
+
scaleratio=1,
|
104 |
+
),
|
105 |
+
yaxis=dict(
|
106 |
+
scaleanchor="x",
|
107 |
+
scaleratio=1,
|
108 |
+
),
|
109 |
+
)
|
110 |
+
|
111 |
+
return fig
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==4.39.0
|
2 |
+
plotly==5.23.0
|
3 |
+
matplotlib==3.7.2
|
4 |
+
torch==2.0.1
|
5 |
+
transformers==4.43.3
|
6 |
+
accelerate==0.33.0
|
7 |
+
networkx
|