import logging

import torch
import torch.nn as nn
from transformers import BertModel

from utils.nn_utils import gelu

logger = logging.getLogger(__name__)


class BertEncoder(nn.Module):
    """This class uses a pretrained `Bert` model to encode tokens and optionally fine-tunes the `Bert` model."""

    def __init__(self, bert_model_name, trainable=False, output_size=0, activation=gelu, dropout=0.0):
        """This function initializes the pretrained `Bert` model

        Arguments:
            bert_model_name {str} -- bert model name

        Keyword Arguments:
            trainable {bool} -- whether to fine-tune the bert parameters (default: {False})
            output_size {int} -- output size; 0 keeps the bert hidden size (default: {0})
            activation {function} -- activation function (default: {gelu})
            dropout {float} -- dropout rate (default: {0.0})
        """
        super().__init__()
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        logger.info("Load bert model {} successfully.".format(bert_model_name))

        self.output_size = output_size

        if trainable:
            logger.info("Start fine-tuning bert model {}.".format(bert_model_name))
        else:
            logger.info("Keep fixed bert model {}.".format(bert_model_name))

        for param in self.bert_model.parameters():
            param.requires_grad = trainable

        if self.output_size > 0:
            self.mlp = BertLinear(input_size=self.bert_model.config.hidden_size,
                                  output_size=self.output_size,
                                  activation=activation)
        else:
            self.output_size = self.bert_model.config.hidden_size
            self.mlp = lambda x: x

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = lambda x: x

    def get_output_dims(self):
        return self.output_size

    def forward(self, seq_inputs, token_type_inputs=None):
        """This function calculates forward propagation results to get token embeddings

        Arguments:
            seq_inputs {tensor} -- sequence inputs (tokenized)

        Keyword Arguments:
            token_type_inputs {tensor} -- token type inputs (default: {None})

        Returns:
            tensor, tensor -- bert outputs for all tokens, bert pooled output
        """
        if token_type_inputs is None:
            token_type_inputs = torch.zeros_like(seq_inputs)

        # Token id 0 is the padding token, so non-zero positions form the attention mask.
        mask_inputs = (seq_inputs != 0).long()

        outputs = self.bert_model(input_ids=seq_inputs,
                                  attention_mask=mask_inputs,
                                  token_type_ids=token_type_inputs)
        last_hidden_state = outputs[0]
        pooled_output = outputs[1]

        return self.dropout(self.mlp(last_hidden_state)), self.dropout(self.mlp(pooled_output))


class BertLayerNorm(nn.Module):
    """This class is the LayerNorm module for Bert"""

    def __init__(self, hidden_size, eps=1e-12):
        """This function sets `BertLayerNorm` parameters

        Arguments:
            hidden_size {int} -- input size

        Keyword Arguments:
            eps {float} -- epsilon (default: {1e-12})
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """This function propagates forwardly

        Arguments:
            x {tensor} -- input tensor

        Returns:
            tensor -- LayerNorm outputs
        """
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class BertLinear(nn.Module):
    """This class is the Linear module for Bert"""

    def __init__(self, input_size, output_size, activation=gelu, dropout=0.0):
        """This function sets `BertLinear` model parameters

        Arguments:
            input_size {int} -- input size
            output_size {int} -- output size

        Keyword Arguments:
            activation {function} -- activation function (default: {gelu})
            dropout {float} -- dropout rate (default: {0.0})
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.linear = nn.Linear(input_size, output_size)
        # Initialize weights with the same scheme as the original BERT (std 0.02).
        self.linear.weight.data.normal_(mean=0.0, std=0.02)
        self.linear.bias.data.zero_()
        self.activation = activation
        self.layer_norm = BertLayerNorm(self.output_size)

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = lambda x: x

    def get_input_dims(self):
        return self.input_size

    def get_output_dims(self):
        return self.output_size

    def forward(self, x):
        """This function propagates forwardly

        Arguments:
            x {tensor} -- input tensor

        Returns:
            tensor -- Linear outputs
        """
        output = self.activation(self.linear(x))
        return self.dropout(self.layer_norm(output))
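

# Usage sketch (illustrative, not part of the original module): encode a small batch
# with a frozen BertEncoder and project the outputs to 256 dimensions. The checkpoint
# name "bert-base-uncased", the BertTokenizer import, and the example sentences are
# assumptions for demonstration; substitute your own tokenizer and checkpoint.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encoder = BertEncoder("bert-base-uncased", trainable=False, output_size=256, dropout=0.1)

    batch = tokenizer(["BERT encodes tokens.", "A shorter example."],
                      padding=True, return_tensors="pt")
    token_repr, pooled_repr = encoder(batch["input_ids"], batch["token_type_ids"])
    # token_repr: (batch_size, seq_len, 256); pooled_repr: (batch_size, 256)
    print(token_repr.shape, pooled_repr.shape)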