---
base_model:
  - google/flan-t5-large
library_name: transformers
license: mit
---

A Text-to-Triple Model Trained on the WikiOfGraph Dataset

Base Model: Flan-T5-Large by Google

Trained by Patrick Jiang @ UIUC

Wandb Training Report (Dec 5, 2024)

Example Input:

"William Gerald Standridge (November 27, 1953 – April 12, 2014) was an American stock car racing driver. He was a competitor in the NASCAR Winston Cup Series and Busch Series."

Output:

```
(S> William gerald standridge| P> Nationality| O> American),
(S> William gerald standridge| P> Occupation| O> Stock car racing driver),
(S> William gerald standridge| P> Competitor| O> Busch series),
(S> William gerald standridge| P> Competitor| O> Nascar winston cup series),
(S> William gerald standridge| P> Birth date| O> November 27, 1953),
(S> William gerald standridge| P> Death date| O> April 12, 2014)
```
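The model returns one flat string of triples in the form `(S> subject| P> predicate| O> object)`, separated by commas. To work with the triples programmatically you can parse that string. Below is a minimal sketch, assuming the output always follows this delimiter format; `parse_triples` is a hypothetical helper, not part of the model or of `transformers`. Because an object can itself contain commas ("November 27, 1953") or parentheses, the pattern only treats `)` as the end of a triple when it is followed by `, (S>` or by the end of the string.

```python
import re
from typing import List, Tuple

Triple = Tuple[str, str, str]

# One "(S> subject| P> predicate| O> object)" group. The closing ")" must be
# followed by ", (S>" (the next triple) or the end of the string, so commas
# and parentheses inside the object are kept.
TRIPLE_RE = re.compile(r"\(S>\s*(.*?)\|\s*P>\s*(.*?)\|\s*O>\s*(.*?)\)\s*(?=,\s*\(S>|$)")


def parse_triples(text: str) -> List[Triple]:
    """Split the model's output string into (subject, predicate, object) tuples."""
    triples: List[Triple] = []
    for match in TRIPLE_RE.finditer(text):
        subject, predicate, obj = (field.strip() for field in match.groups())
        triples.append((subject, predicate, obj))
    return triples
```

Malformed output (for example a triple missing its `P>` field) is skipped rather than raising an error, which may or may not be what you want downstream.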

How to Run?

```python
# Requires torch, transformers, sentencepiece (used by T5Tokenizer)
# and accelerate (used by device_map="auto")
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

def generate_triples(input_text: str, model_path: str = "pat-jj/text2triple-flan-t5"):
    # Load the tokenizer and model (this happens on every call; reuse them
    # when processing many inputs)
    tokenizer = T5Tokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(
        model_path,
        device_map="auto",           # Place weights on the available GPU(s) or CPU
        torch_dtype=torch.bfloat16   # bfloat16 halves memory use compared with float32
    )

    # Tokenize the input, padding/truncating to 512 tokens; the attention mask
    # tells the model to ignore the padding
    inputs = tokenizer(
        input_text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

    # Move inputs to the same device as the model
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    # Generate with beam search; no gradients are needed for inference
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=512,
            num_beams=4,          # Keep the 4 best partial outputs at each step
            early_stopping=True,  # Stop once num_beams finished candidates exist
            length_penalty=0.6,   # Values < 1.0 favor shorter outputs
            use_cache=True        # Reuse the KV cache for faster decoding (default)
        )

    # Decode the best beam into a single string of triples
    generated_triples = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_triples
```
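As written, `generate_triples` truncates its input to 512 tokens (`truncation=True`, `max_length=512`), so anything past that point is silently ignored. For longer documents, one option, sketched below under our own assumptions (it is not part of this model card or of `transformers`), is to pack whole sentences into chunks that fit the budget and call the model once per chunk. The hypothetical `chunk_text` helper counts tokens with a caller-supplied function; the default whitespace word count is only a rough proxy, since T5's SentencePiece tokenizer usually emits more tokens than there are words.

```python
import re
from typing import Callable, List


def chunk_text(
    text: str,
    max_tokens: int = 512,
    count_tokens: Callable[[str], int] = lambda s: len(s.split()),
) -> List[str]:
    """Greedily pack whole sentences into chunks of at most max_tokens tokens.

    A single sentence longer than max_tokens still becomes its own chunk
    (and will still be truncated by the model).
    """
    # Naive sentence split on ".", "!" or "?" followed by whitespace;
    # abbreviations such as "U.S. Army" will be split incorrectly.
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks: List[str] = []
    current: List[str] = []
    for sentence in sentences:
        candidate = " ".join(current + [sentence])
        if current and count_tokens(candidate) > max_tokens:
            # Adding this sentence would exceed the budget: close the chunk
            chunks.append(" ".join(current))
            current = [sentence]
        else:
            current.append(sentence)
    if current:
        chunks.append(" ".join(current))
    return chunks
```

To count real tokens, pass something like `count_tokens=lambda s: len(tokenizer(s)["input_ids"])` with a loaded `T5Tokenizer`. Splitting has a cost: a pronoun such as "He" at the start of a later chunk loses its antecedent, so triples extracted from that chunk may get a vague or wrong subject. Keep chunks as large as the budget allows, and, since `generate_triples` reloads the model on every call, consider loading the tokenizer and model once when processing many chunks.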

```python
# Example usage
input_text = """Albert Einstein was born in Ulm, Germany in 1879. He developed the theory of relativity and won the Nobel Prize in Physics in 1921.
Einstein worked as a professor at Princeton University until his death in 1955."""

generated_triples = generate_triples(input_text)
print("Generated triples:", generated_triples)
```

Output:

```
Generated triples: (S> Albert einstein| P> Birth place| O> Ulm, germany), (S> Albert einstein| P> Birth year| O> 1879), (S> Albert einstein| P> Award| O> Nobel prize in physics), (S> Albert einstein| P> Death year| O> 1955), (S> Albert einstein| P> Occupation| O> Professor), (S> Albert einstein| P> Workplace| O> Princeton university)
```
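In both examples the model appears to lowercase every word of an entity after the first ("Albert einstein", "Princeton university", "Nascar winston cup series"). If you need the original spelling, one option, sketched below as an assumption rather than a feature of the model, is to look each subject and object up case-insensitively in the input text and reuse the matched span. `restore_casing` and `restore_triples` are hypothetical helpers; they take already-split `(subject, predicate, object)` tuples.

```python
import re
from typing import List, Tuple

Triple = Tuple[str, str, str]


def restore_casing(value: str, source: str) -> str:
    """Return the source text's spelling of value if it occurs there as a whole phrase."""
    if not value:
        return value
    # (?<!\w) and (?!\w) stop "Ulm" from matching inside a longer word like "Ulmer"
    pattern = rf"(?<!\w){re.escape(value)}(?!\w)"
    match = re.search(pattern, source, flags=re.IGNORECASE)
    return match.group(0) if match else value


def restore_triples(triples: List[Triple], source: str) -> List[Triple]:
    # Predicates are schema labels such as "Birth place" that normally do not
    # occur verbatim in the input, so only subjects and objects are looked up.
    return [(restore_casing(s, source), p, restore_casing(o, source)) for s, p, o in triples]


if __name__ == "__main__":
    source = ("Albert Einstein was born in Ulm, Germany in 1879. He developed the theory of "
              "relativity and won the Nobel Prize in Physics in 1921.\n"
              "Einstein worked as a professor at Princeton University until his death in 1955.")
    triples = [("Albert einstein", "Award", "Nobel prize in physics"),
               ("Albert einstein", "Occupation", "Professor")]
    print(restore_triples(triples, source))
    # → [('Albert Einstein', 'Award', 'Nobel Prize in Physics'), ('Albert Einstein', 'Occupation', 'professor')]
```

Values that do not appear verbatim in the input are returned unchanged. The source spelling always wins when it is found, so a common noun written in lowercase in the text ("a professor") comes back lowercase even though the model capitalized it.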

WikiOfGraph dataset paper:

Daehee Kim et al., "Ontology-Free General-Domain Knowledge Graph-to-Text Generation Dataset Synthesis using Large Language Model", 2024.