How to use

1. modelとtokenizerの呼び出し

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
model = AutoModel.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)

2. modelのoutput

text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。'
model.eval()
a = tokenizer(text,truncation =True,max_length=512, entities=["船井電機・ホールディングス株式会社"], entity_spans=[(0,4)],return_tensors='pt')
outputs = model(**a)
mlm_logits = outputs.logits
tep_logits = outputs.topic_entity_logits
mep_logits = outputs.entity_logits
print(mlm_logits.shape) >> torch.Size([1, 49, 32770])
print(tep_logits.shape) >> torch.Size([1, 20972])
print(mep_logits.shape) >> torch.Size([1, 1, 20972]) 
  • modelのencode結果は,logits, topic_entity_logits, entity_logitsを属性として持ちます.
  • logitsは通常のBERTなどの言語モデルと同様の扱い方です.
  • topic_entity_logits(文章における各enitityの関連度) と entity_logits(entityの埋め込み表現)に関しては,このモデル固有のものであり,以下に扱い方を解説します.

3. topic_entity_logits(文章における各enitityの関連度を取得)

tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
model = AutoModelForPreTraining.from_pretrained("uzabase/UBKE-LUKE", output_hidden_states=True, trust_remote_code=True)

text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。'

model.eval()
a = tokenizer(text,truncation =True,max_length=512, return_tensors='pt')
outputs = model(**a)
tep_logits = outputs.topic_entity_logits
print("tep_logits shape: ", tep_logits.shape) # >> torch.Size([1, 20972]) each dimentions correspond to entities

ent = { tokenizer.entity_vocab[i]:i for i in tokenizer.entity_vocab}

print("Entity Recognition Results:")
topk_logits, topk_entity_ids = tep_logits.topk(10, dim=1)
for logit, entity_id in zip(topk_logits[0].tolist(), topk_entity_ids[0].tolist()):
    print("\t", ent[entity_id], logit)
>>>
Entity Recognition Results:
     船井電機・ホールディングス株式会社 1.8898193836212158
     セイノーホールディングス 1.668973684310913
     東洋電機 1.658090353012085
     横河電機 1.6363312005996704
     船井総研ホールディングス 1.618525743484497
     西菱電機 1.587844967842102
     フォスター電機 1.5436134338378906
     東洋電機製造 1.493951678276062
     ヒロセ電機 1.458113193511963
     サクサ 1.4461733102798462
  • modelのencode結果は,topic_entity_logits属性を持ちます.
  • topic_entity_logitsは, torch.Size([batch_size, entity_size])のtoroch.tensorです.
  • 各次元のlogit値は,入力文章における各entityの関連度を表現しています.

4. entity_logits(entityの埋め込み表現)

  • entityの一覧は,tokenizerがentity_vocabに辞書形式で持ちます.
tokenizer.entity_vocab # => {"": 0, ... ,"AGC": 48, ....
tokenizer.entity_vocab["味の素"] # => 8469(味の素のentity_id) 
  • entity_spans及びentitties引数をtokenizerに渡し,tokenをencodeすることで,entityの埋め込み表現を得ます.
model.eval()
tokens = tokenizer("味の素", entities=["味の素"], entity_spans=[(0, 3)], truncation=True, max_length=512, return_tensors="pt")
print(tokens["entity_ids"]) # => tensor([[8469]])
with torch.no_grad():
    outputs = model(**tokens)
outputs.entity_logits.shape # 味の素のentity_vector
  • entityの埋め込み表現の内積(やコサイン類似度)を計算することで,entity同士の類似度を計算可能です.
def encode(entity_text):
    model.eval()
    tokens = tokenizer(entity_text, entities=[entity_text], entity_spans=[(0, len(entity_text))],
                       truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokens)
    return outputs.entity_logits[0][0]
azinomoto = encode("味の素")
nisshin = encode("日清食品ホールディングス")
kameda = encode("亀田製菓")
sony = encode("ソニーホールディングス")
print(azinomoto @ nisshin) # => tensor(24834.6836)
print(azinomoto @ kameda) # => tensor(17547.6895)
print(azinomoto @ sony) # => tensor(8699.2871)

Licenses

The model parameters model.safetensors is licensed under CC BY-NC. モデルの重みファイル model.safetensors はCC BY-NCライセンスで利用可能です。

Other files are subject to the same license as LUKE itself. その他のファイルはLUKE自体と同じライセンスが適用されます。

Reference

  • 開発の背景などについてはブログを参照してください
  • もしUBKE-LUKEの活用に興味をお持ちの方は [email protected] までご連絡ください
Downloads last month
132
Safetensors
Model size
139M params
Tensor type
F32
·
Inference Examples
Inference API (serverless) does not yet support model repos that contain custom code.