# Notebook-style dependency install; skip if transformers is already available.
!pip install -q transformers

import torch
from transformers import RemBertForTokenClassification, RemBertTokenizerFast
from transformers import XLMRobertaForTokenClassification, XLMRobertaTokenizerFast
main_path = "Misha24-10/MultiCoNER-2-recognition-model"

# Ensemble member 1: an XLM-RoBERTa-large fine-tune.
model_1 = XLMRobertaForTokenClassification.from_pretrained(
    main_path, subfolder="xlm_roberta_large_mountain")
tokenizer_1 = XLMRobertaTokenizerFast.from_pretrained(
    main_path, subfolder="xlm_roberta_large_mountain")

# Ensemble members 2 and 3: two RemBERT fine-tunes.
model_2 = RemBertForTokenClassification.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_v3")
tokenizer_2 = RemBertTokenizerFast.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_v3")

model_3 = RemBertForTokenClassification.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_sky")
tokenizer_3 = RemBertTokenizerFast.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_sky")
def compute_last_layer_probs(model, tokenizer, sentence):
    """Run one model over `sentence` and return logits for the first sub-token of each word."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)  # keep model and inputs on the same device
    model.eval()
    # Encode the raw sentence once just to measure its length in tokens.
    number_of_tokens = tokenizer.encode_plus(sentence, return_tensors='pt')['input_ids'].shape[-1]
    list_of_words = sentence.split()
    inputs = tokenizer(list_of_words, is_split_into_words=True, padding='max_length',
                       max_length=min(number_of_tokens, 512), truncation=True,
                       return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    # 1 at the first sub-token of every word, -100 elsewhere (see align_word_ids below).
    label_ids = torch.tensor(align_word_ids(inputs.word_ids()), device=device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    # One logit vector per word: shape (1, num_words, num_labels).
    return logits[:, label_ids == 1, :]
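# Note (illustrative): for the 10-word example sentence at the bottom of this
# file, each model returns a (1, 10, num_labels) tensor here, so the three
# models' outputs can be summed element-wise in weighted_voting.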
# Per-model voting weights (equal by default; tune to favour stronger models).
weights = {'model_1': 1, 'model_2': 1, 'model_3': 1}
label_all_tokens = False  # if True, mark every sub-token; if False, only the first per word

def align_word_ids(word_ids, return_word_ids=False):
    """Map tokenizer word_ids to a mask: 1 at the first sub-token of each word, -100 elsewhere."""
    previous_word_idx = None
    label_ids = []
    index_list = []
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            label_ids.append(-100)  # special token ([CLS], [SEP], padding)
        elif word_idx != previous_word_idx:
            label_ids.append(1)  # first sub-token of a new word
            index_list.append(idx)
        else:
            label_ids.append(1 if label_all_tokens else -100)  # continuation sub-token
        previous_word_idx = word_idx
    if return_word_ids:
        return label_ids, index_list
    return label_ids
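# Worked example (hypothetical word_ids from a tokenizer): for
#   [None, 0, 0, 1, 2, None]
# align_word_ids returns
#   [-100, 1, -100, 1, 1, -100]
# i.e. exactly one 1 per input word, so the mask in compute_last_layer_probs
# selects one sub-token per word.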
def weighted_voting(sentence):
    """Ensemble the three models via a weighted sum of their per-word logits."""
    predictions = []
    for idx, (model, tokenizer) in enumerate([(model_1, tokenizer_1),
                                              (model_2, tokenizer_2),
                                              (model_3, tokenizer_3)]):
        logits = compute_last_layer_probs(model, tokenizer, sentence)
        predictions.append(logits * weights[f'model_{idx + 1}'])
    final_logits = sum(predictions)
    final_predictions = torch.argmax(final_logits, dim=2)
    # The three models are assumed to share one label map, so model_1's id2label suffices.
    labels = [model_1.config.id2label[i] for i in final_predictions.tolist()[0]]
    return labels
# Smoke test on a lower-cased example sentence.
sent_ex = "Elon Musk 's brother sits on the boards of tesla".lower()
print(weighted_voting(sent_ex))
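# A minimal readability sketch (not in the original script): pair each
# whitespace-separated word with its predicted label. Assumes, as built above,
# that weighted_voting returns exactly one label per word.
for word, label in zip(sent_ex.split(), weighted_voting(sent_ex)):
    print(f"{word}\t{label}")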