File size: 4,067 Bytes
7eca486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
import spacy
from spacy import displacy


# Node fill colours keyed by spaCy entity label. Labels missing from this
# table are drawn with a neutral grey by generate_knowledge_graph.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    # The OntoNotes-style label emitted by the English pipelines is "FAC";
    # "FACILITY" is kept so existing lookups keep working.
    "FAC": "#9cc9cc",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}

def generate_knowledge_graph(texts: List[str], filename: str):
    """Build a directed knowledge graph from *texts* and render it with pyvis.

    Relation triplets are extracted from every text with
    ``generate_partial_graph``; heads and tails become (lowercased) nodes and
    each distinct ``(head, tail, relation)`` becomes one labelled edge. Nodes
    whose name matches a spaCy named entity are coloured by entity label.

    Args:
        texts: Input passages; each one is processed independently for
            triplet extraction and all of them jointly for entity detection.
        filename: Passed unchanged to ``Network.show``.

    Returns:
        The list of lowercased node names, heads first then tails, each
        listed once in first-seen order.
    """
    nlp = spacy.load("en_core_web_sm")
    # Run NER on the original casing (the pretrained English pipeline is
    # trained on cased text) and lowercase only the entity strings, because
    # node names are lowercased below.
    doc = nlp("\n".join(texts))
    entity_labels = {}
    for ent in doc.ents:
        # First occurrence wins, mirroring a list.index()-style lookup.
        entity_labels.setdefault(ent.text.lower(), ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))
    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    nodes = list(dict.fromkeys(heads + tails))
    net = Network(directed=True, width="700px", height="700px")

    for node in nodes:
        label = entity_labels.get(node)
        if label is None:
            net.add_node(node, shape="circle")
        else:
            color = DEFAULT_LABEL_COLORS.get(label, "#666666")
            net.add_node(node, title=label, shape="circle", color=color)

    # Edges are keyed on the same lowercased endpoints that are drawn, so
    # case variants of one triplet collapse into a single edge; a tuple key
    # avoids accidental collisions between concatenated strings.
    seen_edges = set()
    for triplet in triplets:
        head = triplet["head"].lower()
        tail = triplet["tail"].lower()
        relation = triplet["type"]
        edge_key = (head, tail, relation.lower())
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        net.add_edge(head, tail, title=relation, label=relation)

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes


@lru_cache(maxsize=1)
def _get_triplet_extractor():
    """Build the REBEL text2text pipeline once and reuse it afterwards."""
    return pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large',
    )


@lru_cache(maxsize=16)
def generate_partial_graph(text: str):
    """Extract relation triplets from a single piece of text.

    The text is fed through the ``Babelscape/rebel-large`` pipeline, the
    generated token ids are decoded back to a string and that string is
    parsed with ``extract_triplets``.

    Results are memoised per input text (up to 16 entries); the returned
    list is the cached object itself, so callers should treat it as
    read-only.

    Args:
        text: The passage to extract triplets from.

    Returns:
        Whatever ``extract_triplets`` returns for the decoded model output:
        a list of ``{'head', 'type', 'tail'}`` dicts.
    """
    # Loading the model is by far the most expensive step, so the pipeline
    # is created lazily once instead of on every cache miss.
    triplet_extractor = _get_triplet_extractor()
    generated = triplet_extractor(text, return_tensors=True, return_text=False)
    token_ids = generated[0]["generated_token_ids"]["output_ids"]
    extracted_text = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(extracted_text[0])


def extract_triplets(text):
    """Parse REBEL-style linearised output into relation triplets.

    The expected layout is a sequence of groups of the form
    ``<triplet> HEAD <subj> TAIL <obj> RELATION``, where further
    ``<subj> TAIL <obj> RELATION`` pairs may follow to reuse the same head.
    ``<s>``, ``<pad>`` and ``</s>`` are stripped before parsing, and tokens
    that appear before the first marker are ignored.

    Only complete triplets (non-empty head, relation and tail) are returned,
    so output that was cut off mid-triplet does not yield partial or
    stale-relation entries.

    Args:
        text: Decoded model output.

    Returns:
        A list of dicts with the keys ``'head'``, ``'type'`` and ``'tail'``.
    """
    for special in ("<s>", "<pad>", "</s>"):
        text = text.replace(special, "")
    # Make sure markers are standalone tokens even when the decoder glued
    # them to a neighbouring word.
    for marker in ("<triplet>", "<subj>", "<obj>"):
        text = text.replace(marker, f" {marker} ")

    triplets = []
    subject, relation, object_ = [], [], []
    current = None  # the word list that plain tokens are appended to

    def emit():
        # Record the pending triplet only when all three parts are present.
        if subject and relation and object_:
            triplets.append({
                'head': ' '.join(subject),
                'type': ' '.join(relation),
                'tail': ' '.join(object_),
            })

    for token in text.split():
        if token == "<triplet>":
            emit()
            relation.clear()
            subject.clear()
            current = subject
        elif token == "<subj>":
            emit()
            # Reset the relation so it cannot be re-emitted with the next
            # tail if the text ends before another <obj> section.
            relation.clear()
            object_.clear()
            current = object_
        elif token == "<obj>":
            relation.clear()
            current = relation
        elif current is not None:
            current.append(token)
    emit()

    return triplets