import logging

import torch
import torch.nn as nn
from transformers import AutoModel

from utils.nn_utils import gelu
from modules.token_embedders.bert_encoder import BertLinear

logger = logging.getLogger(__name__)


class PretrainedEncoder(nn.Module):
    """Encodes tokens with a pre-trained transformer model, optionally
    fine-tuning the pre-trained weights.
    """

    def __init__(self, pretrained_model_name, trainable=False, output_size=0, activation=gelu, dropout=0.0):
        """Initializes the pre-trained model.

        Arguments:
            pretrained_model_name {str} -- pre-trained model name

        Keyword Arguments:
            trainable {bool} -- whether to fine-tune the pre-trained model (default: {False})
            output_size {int} -- output size; 0 keeps the pre-trained model's hidden size (default: {0})
            activation {nn.Module} -- activation function (default: {gelu})
            dropout {float} -- dropout rate (default: {0.0})
        """

        super().__init__()
        self.pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
        logger.info("Loaded pre-trained model {} successfully.".format(pretrained_model_name))

        self.output_size = output_size

        if trainable:
            logger.info("Fine-tuning pre-trained model {}.".format(pretrained_model_name))
        else:
            logger.info("Keeping pre-trained model {} fixed.".format(pretrained_model_name))

        # Freeze (or unfreeze) all pre-trained parameters.
        for param in self.pretrained_model.parameters():
            param.requires_grad = trainable

        if self.output_size > 0:
            # Project the transformer's hidden states to the requested output size.
            self.mlp = BertLinear(input_size=self.pretrained_model.config.hidden_size,
                                  output_size=self.output_size,
                                  activation=activation)
        else:
            # No projection: expose the pre-trained model's hidden size directly.
            self.output_size = self.pretrained_model.config.hidden_size
            self.mlp = lambda x: x

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = lambda x: x

    def get_output_dims(self):
        return self.output_size

    def forward(self, seq_inputs, token_type_inputs=None):
        """Computes forward-propagation results, i.e. token embeddings.

        Args:
            seq_inputs (tensor): sequence inputs (token ids)
            token_type_inputs (tensor, optional): token type inputs. Defaults to None.

        Returns:
            tuple: token-level outputs and pooled output of the pre-trained model
        """

        if token_type_inputs is None:
            token_type_inputs = torch.zeros_like(seq_inputs)
        # Attention mask: positions holding the pad id 0 are masked out.
        mask_inputs = (seq_inputs != 0).long()

        outputs = self.pretrained_model(input_ids=seq_inputs,
                                        token_type_ids=token_type_inputs,
                                        attention_mask=mask_inputs)
        last_hidden_state = outputs[0]
        pooled_output = outputs[1]

        return self.dropout(self.mlp(last_hidden_state)), self.dropout(self.mlp(pooled_output))
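

# Minimal usage sketch (an illustrative addition, not part of the original
# module): builds a frozen encoder and runs a dummy padded batch through it.
# The model name "bert-base-uncased" and the input shapes are assumptions.
if __name__ == "__main__":
    encoder = PretrainedEncoder(pretrained_model_name="bert-base-uncased",
                                trainable=False,
                                output_size=0,
                                dropout=0.1)

    # Dummy token ids: a batch of 2 sequences of length 8, padded with id 0.
    dummy_inputs = torch.randint(low=1, high=1000, size=(2, 8))
    dummy_inputs[:, -2:] = 0  # simulate padding positions

    token_repr, pooled_repr = encoder(dummy_inputs)
    print(token_repr.shape)   # (2, 8, hidden_size), e.g. (2, 8, 768) for bert-base
    print(pooled_repr.shape)  # (2, hidden_size), e.g. (2, 768)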