import torch
from typing import List, Union

from transformers import BertForMaskedLM, BertTokenizerFast


class BertForLexPrediction(BertForMaskedLM):

    def __init__(self, config):
        super().__init__(config)

    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast):
        if isinstance(sentences, str):
            sentences = [sentences]

        # tokenize the batch and predict the logits for every position
        inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = self.forward(**inputs, return_dict=True).logits

        # for each input word, pair it with the model's top-scoring vocabulary token
        # (argmax of the logits at that position)
        input_ids = inputs['input_ids']
        batch_ret = []
        for batch_idx in range(len(sentences)):
            ret = []
            batch_ret.append(ret)
            for tok_idx in range(input_ids.shape[1]):
                token_id = input_ids[batch_idx, tok_idx]
                # skip special tokens: CLS, SEP, PAD
                if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
                    continue
                token = tokenizer.convert_ids_to_tokens(int(token_id))
                # wordpieces are appended to the previous word, keeping that word's prediction
                if token.startswith('##'):
                    ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
                    continue
                prediction = tokenizer.convert_ids_to_tokens(int(torch.argmax(logits[batch_idx, tok_idx])))
                ret.append((token, prediction))
        return batch_ret
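

if __name__ == '__main__':
    # Minimal usage sketch: load a fine-tuned checkpoint into the class above and call predict().
    # 'path/to/lex-prediction-checkpoint' is a placeholder, not a real model id; substitute the
    # checkpoint you have actually trained or downloaded.
    checkpoint = 'path/to/lex-prediction-checkpoint'
    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    model = BertForLexPrediction.from_pretrained(checkpoint)
    model.eval()

    # predict() accepts a single string or a list of strings and returns, per sentence,
    # a list of (input_word, predicted_token) pairs
    predictions = model.predict('The quick brown fox jumps over the lazy dog', tokenizer)
    for word, predicted in predictions[0]:
        print(f'{word} -> {predicted}')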