How to use
1. modelとtokenizerの呼び出し
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
model = AutoModel.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
2. modelのoutput
text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。'
model.eval()
a = tokenizer(text,truncation =True,max_length=512, entities=["船井電機・ホールディングス株式会社"], entity_spans=[(0,4)],return_tensors='pt')
outputs = model(**a)
mlm_logits = outputs.logits
tep_logits = outputs.topic_entity_logits
mep_logits = outputs.entity_logits
print(mlm_logits.shape) >> torch.Size([1, 49, 32770])
print(tep_logits.shape) >> torch.Size([1, 20972])
print(mep_logits.shape) >> torch.Size([1, 1, 20972])
- modelのencode結果は,logits, topic_entity_logits, entity_logitsを属性として持ちます.
- logitsは通常のBERTなどの言語モデルと同様の扱い方です.
- topic_entity_logits(文章における各enitityの関連度) と entity_logits(entityの埋め込み表現)に関しては,このモデル固有のものであり,以下に扱い方を解説します.
3. topic_entity_logits(文章における各enitityの関連度を取得)
tokenizer = AutoTokenizer.from_pretrained("uzabase/UBKE-LUKE", trust_remote_code=True)
model = AutoModelForPreTraining.from_pretrained("uzabase/UBKE-LUKE", output_hidden_states=True, trust_remote_code=True)
text = '船井電機は民事再生法の適用を東京地裁へ申請しました。同社は10月に裁判所から破産手続きの開始決定を受けており、会長は11月に即時抗告を申し立て。'
model.eval()
a = tokenizer(text,truncation =True,max_length=512, return_tensors='pt')
outputs = model(**a)
tep_logits = outputs.topic_entity_logits
print("tep_logits shape: ", tep_logits.shape) # >> torch.Size([1, 20972]) each dimentions correspond to entities
ent = { tokenizer.entity_vocab[i]:i for i in tokenizer.entity_vocab}
print("Entity Recognition Results:")
topk_logits, topk_entity_ids = tep_logits.topk(10, dim=1)
for logit, entity_id in zip(topk_logits[0].tolist(), topk_entity_ids[0].tolist()):
print("\t", ent[entity_id], logit)
>>>
Entity Recognition Results:
船井電機・ホールディングス株式会社 1.8898193836212158
セイノーホールディングス 1.668973684310913
東洋電機 1.658090353012085
横河電機 1.6363312005996704
船井総研ホールディングス 1.618525743484497
西菱電機 1.587844967842102
フォスター電機 1.5436134338378906
東洋電機製造 1.493951678276062
ヒロセ電機 1.458113193511963
サクサ 1.4461733102798462
- modelのencode結果は,topic_entity_logits属性を持ちます.
- topic_entity_logitsは, torch.Size([batch_size, entity_size])のtoroch.tensorです.
- 各次元のlogit値は,入力文章における各entityの関連度を表現しています.
4. entity_logits(entityの埋め込み表現)
- entityの一覧は,tokenizerがentity_vocabに辞書形式で持ちます.
tokenizer.entity_vocab # => {"": 0, ... ,"AGC": 48, ....
tokenizer.entity_vocab["味の素"] # => 8469(味の素のentity_id)
- entity_spans及びentitties引数をtokenizerに渡し,tokenをencodeすることで,entityの埋め込み表現を得ます.
model.eval()
tokens = tokenizer("味の素", entities=["味の素"], entity_spans=[(0, 3)], truncation=True, max_length=512, return_tensors="pt")
print(tokens["entity_ids"]) # => tensor([[8469]])
with torch.no_grad():
outputs = model(**tokens)
outputs.entity_logits.shape # 味の素のentity_vector
- entityの埋め込み表現の内積(やコサイン類似度)を計算することで,entity同士の類似度を計算可能です.
def encode(entity_text):
model.eval()
tokens = tokenizer(entity_text, entities=[entity_text], entity_spans=[(0, len(entity_text))],
truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
outputs = model(**tokens)
return outputs.entity_logits[0][0]
azinomoto = encode("味の素")
nisshin = encode("日清食品ホールディングス")
kameda = encode("亀田製菓")
sony = encode("ソニーホールディングス")
print(azinomoto @ nisshin) # => tensor(24834.6836)
print(azinomoto @ kameda) # => tensor(17547.6895)
print(azinomoto @ sony) # => tensor(8699.2871)
Licenses
The model parameters model.safetensors
is licensed under CC BY-NC.
モデルの重みファイル model.safetensors
はCC BY-NCライセンスで利用可能です。
Other files are subject to the same license as LUKE itself. その他のファイルはLUKE自体と同じライセンスが適用されます。
Reference
- 開発の背景などについてはブログを参照してください
- もしUBKE-LUKEの活用に興味をお持ちの方は [email protected] までご連絡ください
- Downloads last month
- 132
Inference API (serverless) does not yet support model repos that contain custom code.