import logging

import torch
import torch.nn as nn

from transformers import AutoModel

from utils.nn_utils import gelu
from modules.token_embedders.bert_encoder import BertLinear

logger = logging.getLogger(__name__)


class PretrainedEncoder(nn.Module):
    """This class uses a pre-trained model to encode tokens,
    then optionally fine-tunes the pre-trained model.
    """
    def __init__(self, pretrained_model_name, trainable=False, output_size=0, activation=gelu, dropout=0.0):
        """This function initializes the pre-trained model.

        Arguments:
            pretrained_model_name {str} -- pre-trained model name

        Keyword Arguments:
            trainable {bool} -- fine-tune the pre-trained model if True, keep it frozen otherwise (default: {False})
            output_size {int} -- output size; 0 keeps the pre-trained model's hidden size (default: {0})
            activation {nn.Module} -- activation function (default: {gelu})
            dropout {float} -- dropout rate (default: {0.0})
        """
        super().__init__()
        self.pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
        logger.info("Loaded pre-trained model {} successfully.".format(pretrained_model_name))

        self.output_size = output_size

        if trainable:
            logger.info("Start fine-tuning pre-trained model {}.".format(pretrained_model_name))
        else:
            logger.info("Keep pre-trained model {} fixed.".format(pretrained_model_name))

        # Freeze or unfreeze all pre-trained parameters according to `trainable`.
        for param in self.pretrained_model.parameters():
            param.requires_grad = trainable

        if self.output_size > 0:
            # Project the pre-trained hidden states to the requested output size.
            self.mlp = BertLinear(input_size=self.pretrained_model.config.hidden_size,
                                  output_size=self.output_size,
                                  activation=activation)
        else:
            # Keep the pre-trained hidden size and skip the projection.
            self.output_size = self.pretrained_model.config.hidden_size
            self.mlp = nn.Identity()

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

    def get_output_dims(self):
        return self.output_size

    def forward(self, seq_inputs, token_type_inputs=None):
        """Calculates forward propagation results, i.e. the token embeddings.

        Arguments:
            seq_inputs {tensor} -- sequence inputs (token ids)

        Keyword Arguments:
            token_type_inputs {tensor} -- token type inputs (default: {None})

        Returns:
            tuple -- token-level representations and the pooled sequence representation
        """

        if token_type_inputs is None:
            token_type_inputs = torch.zeros_like(seq_inputs)
        # The attention mask assumes that the padding token id is 0.
        mask_inputs = (seq_inputs != 0).long()

        outputs = self.pretrained_model(input_ids=seq_inputs,
                                        token_type_ids=token_type_inputs,
                                        attention_mask=mask_inputs)
        # outputs[0]: last hidden states for all tokens, outputs[1]: pooled [CLS] output.
        last_hidden_state = outputs[0]
        pooled_output = outputs[1]

        return self.dropout(self.mlp(last_hidden_state)), self.dropout(self.mlp(pooled_output))
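

if __name__ == "__main__":
    # Usage sketch (illustrative only): drives the encoder end to end.
    # The checkpoint name "bert-base-uncased", the projection size 256, and the
    # sample sentences are assumptions made for this example, not values
    # mandated by the repository.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    encoder = PretrainedEncoder("bert-base-uncased", trainable=False, output_size=256, dropout=0.1)

    batch = tokenizer(["an example sentence", "another one"], padding=True, return_tensors="pt")
    token_repr, seq_repr = encoder(batch["input_ids"], batch.get("token_type_ids"))
    # token_repr: (batch_size, seq_len, 256); seq_repr: (batch_size, 256)
    print(token_repr.shape, seq_repr.shape)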