# Compact_Facts/modules/token_embedders/pretrained_encoder.py
import logging

import torch
import torch.nn as nn
from transformers import AutoModel

from utils.nn_utils import gelu
from modules.token_embedders.bert_encoder import BertLinear

logger = logging.getLogger(__name__)


class PretrainedEncoder(nn.Module):
    """Encodes tokens with a pre-trained transformer model, optionally
    fine-tuning the pre-trained weights.
    """

    def __init__(self, pretrained_model_name, trainable=False, output_size=0, activation=gelu, dropout=0.0):
        """Initializes the pre-trained model.

        Arguments:
            pretrained_model_name {str} -- pre-trained model name

        Keyword Arguments:
            trainable {bool} -- whether to fine-tune the pre-trained model (default: {False})
            output_size {int} -- output size; 0 keeps the model's hidden size (default: {0})
            activation {nn.Module} -- activation function (default: {gelu})
            dropout {float} -- dropout rate (default: {0.0})
        """
        super().__init__()
        self.pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
        logger.info("Loaded pre-trained model {} successfully.".format(pretrained_model_name))

        self.output_size = output_size

        if trainable:
            logger.info("Fine-tuning pre-trained model {}.".format(pretrained_model_name))
        else:
            logger.info("Keeping pre-trained model {} frozen.".format(pretrained_model_name))
        for param in self.pretrained_model.parameters():
            param.requires_grad = trainable

        if self.output_size > 0:
            # Project the encoder outputs to the requested size.
            self.mlp = BertLinear(input_size=self.pretrained_model.config.hidden_size,
                                  output_size=self.output_size,
                                  activation=activation)
        else:
            # No projection: expose the encoder's hidden size directly.
            self.output_size = self.pretrained_model.config.hidden_size
            self.mlp = nn.Identity()

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

    def get_output_dims(self):
        return self.output_size

    def forward(self, seq_inputs, token_type_inputs=None):
        """Runs the forward pass and returns token embeddings.

        Arguments:
            seq_inputs {tensor} -- sequence inputs (token ids)

        Keyword Arguments:
            token_type_inputs {tensor} -- token type ids (default: {None})

        Returns:
            tuple -- (per-token representations, pooled sequence representation)
        """
        if token_type_inputs is None:
            token_type_inputs = torch.zeros_like(seq_inputs)

        # Attention mask assumes that the padding token id is 0.
        mask_inputs = (seq_inputs != 0).long()

        outputs = self.pretrained_model(input_ids=seq_inputs,
                                        token_type_ids=token_type_inputs,
                                        attention_mask=mask_inputs)
        last_hidden_state = outputs[0]
        pooled_output = outputs[1]

        return self.dropout(self.mlp(last_hidden_state)), self.dropout(self.mlp(pooled_output))
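

# Usage sketch, not part of the original module: the checkpoint name
# "bert-base-uncased" and the dummy input ids below are illustrative
# assumptions; any AutoModel-compatible checkpoint whose padding id is 0
# should behave the same way.
if __name__ == "__main__":
    encoder = PretrainedEncoder("bert-base-uncased", trainable=False, dropout=0.1)
    encoder.eval()  # disable dropout for the demo forward pass

    # Batch of one padded sequence (ids are placeholders, pad id is 0).
    dummy_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102, 0, 0]])
    seq_repr, pooled_repr = encoder(dummy_ids)

    # Expected shapes: (1, 8, hidden_size) and (1, hidden_size).
    print(seq_repr.shape, pooled_repr.shape)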