import torch
import tqdm

from torch import nn
from transformers import MT5EncoderModel, MT5PreTrainedModel


class MT5EncoderWithProjection(MT5PreTrainedModel):
    """MT5 encoder that masked-mean-pools token states into one embedding per sequence."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.mt5_encoder = MT5EncoderModel(config)
        # Bias-free linear projection applied to the pooled embedding (d_model -> d_model).
        self.projection = nn.Linear(config.d_model, config.d_model, bias=False)
        self.post_init()

    def forward(self, **input_args):
        hidden_states = self.mt5_encoder(**input_args).last_hidden_state
        # Masked mean pooling: zero out padded positions, then divide the
        # summed token states by the number of real tokens in each sequence.
        mask = input_args['attention_mask']
        batch_embeddings = torch.sum(hidden_states * mask[:, :, None], dim=1) / torch.sum(mask, dim=1)[:, None]
        batch_embeddings = self.projection(batch_embeddings)
        return batch_embeddings
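

# A minimal usage sketch, assuming an off-the-shelf MT5 checkpoint
# ("google/mt5-small" below is illustrative; any MT5 size works the same way):
# build the wrapper from the checkpoint's config, warm-start the inner encoder
# from pretrained weights, and embed a padded batch. The projection layer stays
# randomly initialized here and would be learned during fine-tuning.
if __name__ == "__main__":
    from transformers import AutoTokenizer, MT5Config

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    config = MT5Config.from_pretrained("google/mt5-small")
    model = MT5EncoderWithProjection(config)
    model.mt5_encoder = MT5EncoderModel.from_pretrained("google/mt5-small")
    model.eval()

    batch = tokenizer(
        ["a short sentence", "une phrase un peu plus longue"],
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model(**batch)
    print(embeddings.shape)  # torch.Size([2, config.d_model])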