---
base_model:
- google/flan-t5-large
library_name: transformers
license: mit
---
# **A Text-to-Triple Model Trained on the WikiOfGraph Dataset**

Base Model: [Flan-T5-Large](https://huggingface.co/google/flan-t5-large) by Google

Trained by [Patrick Jiang](https://pat-jj.github.io/) @ UIUC

[Wandb Training Report](https://api.wandb.ai/links/patjj/2njqb94u) (Dec 5, 2024)

## **Example Input:** 
"William Gerald Standridge (November 27, 1953 – April 12, 2014) was an American stock car racing driver. He was a competitor in the NASCAR Winston Cup Series and Busch Series."

## **Output:** 
(S> William gerald standridge| P> Nationality| O> American), 
\
(S> William gerald standridge| P> Occupation| O> Stock car racing driver), 
\
(S> William gerald standridge| P> Competitor| O> Busch series), 
\
(S> William gerald standridge| P> Competitor| O> Nascar winston cup series), 
\
(S> William gerald standridge| P> Birth date| O> November 27, 1953),
\
(S> William gerald standridge| P> Death date| O> April 12, 2014)

## **How to Run?**
```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

def generate_triples(input_text: str, model_path: str = "pat-jj/text2triple-flan-t5"):
    # Initialize tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16  # Use bfloat16 for efficiency
    )
    
    # Tokenize input with proper padding and attention mask
    inputs = tokenizer(
        input_text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )
    
    # Move inputs to the same device as model
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    # Generate with beam search; gradients are not needed for inference
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=512,
            num_beams=4,  # Use beam search
            early_stopping=True,
            length_penalty=0.6,  # Below the default of 1.0, so longer beams get less of a length bonus
            use_cache=True  # Use KV cache for faster generation
        )
    
    # Decode and return the generated triples
    generated_triples = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_triples
```

## Example usage
```python
input_text = """Albert Einstein was born in Ulm, Germany in 1879. He developed the theory of relativity and won the Nobel Prize in Physics in 1921.
Einstein worked as a professor at Princeton University until his death in 1955."""

generated_triples = generate_triples(input_text)
print("Generated triples:", generated_triples)
```
## Output:
```
Generated triples: (S> Albert einstein| P> Birth place| O> Ulm, germany), (S> Albert einstein| P> Birth year| O> 1879), (S> Albert einstein| P> Award| O> Nobel prize in physics), (S> Albert einstein| P> Death year| O> 1955), (S> Albert einstein| P> Occupation| O> Professor), (S> Albert einstein| P> Workplace| O> Princeton university)
```
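
The model returns all triples as one flat string, so downstream code usually needs to split it back into structured records. The following is a minimal sketch of such a parser, assuming the output always follows the `(S> ...| P> ...| O> ...)` pattern shown above. The names `Triple`, `TRIPLE_PATTERN`, and `parse_triples` are illustrative and not part of this model or of `transformers`.

```python
import re
from typing import List, NamedTuple


class Triple(NamedTuple):
    """One (subject, predicate, object) record; a hypothetical helper type."""
    subject: str
    predicate: str
    obj: str


# Assumed format: "(S> subject| P> predicate| O> object)", groups separated by commas.
# The lookahead requires the closing parenthesis to be followed by the next
# "(S>" group or the end of the string, so objects may contain parentheses.
TRIPLE_PATTERN = re.compile(
    r"\(S>\s*(.*?)\|\s*P>\s*(.*?)\|\s*O>\s*(.*?)\)(?=\s*,?\s*(?:\(S>|$))"
)


def parse_triples(generated: str) -> List[Triple]:
    """Split the flat generated string into a list of Triple records."""
    return [
        Triple(s.strip(), p.strip(), o.strip())
        for s, p, o in TRIPLE_PATTERN.findall(generated)
    ]
```

Calling `parse_triples(generated_triples)` on the string from the example above should yield one `Triple` per parenthesized group. Text that does not match the assumed pattern is silently skipped, so callers may want to compare the number of parsed triples against the number of `(S>` markers in the raw string.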



## **WikiOfGraph Dataset Paper**
Daehee Kim et al., "Ontology-Free General-Domain Knowledge Graph-to-Text Generation Dataset Synthesis using Large Language Model", 2024.