""" |
|
Backbone modules. |
|
""" |
|
from collections import OrderedDict
from typing import Dict, List

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, AutoTokenizer

from utils.misc import NestedTensor, is_main_process


class BERT(nn.Module):
    """Language backbone that wraps a Hugging Face BERT encoder."""

    def __init__(self, name: str, train_bert: bool, hidden_dim: int, max_len: int, enc_num: int):
        super().__init__()

        # Hidden size of BERT-base; consumers read this to size their projections.
        self.num_channels = 768
        # If enc_num > 0, the full encoder is run in forward();
        # otherwise only the raw word embeddings are used.
        self.enc_num = enc_num
        # Note: hidden_dim and max_len are accepted but not used by this module.

        self.bert = AutoModel.from_pretrained(name)

        if not train_bert:
            # Freeze the language model so it is not updated during training.
            for parameter in self.bert.parameters():
                parameter.requires_grad_(False)

    def forward(self, tensor_list: NestedTensor):
        if self.enc_num > 0:
            # Run the BERT encoder; tensor_list.mask is the attention mask
            # (1 for real tokens, 0 for padding).
            bert_output = self.bert(tensor_list.tensors, token_type_ids=None, attention_mask=tensor_list.mask)
            xs = bert_output.last_hidden_state
        else:
            # Bypass the encoder and look up raw word embeddings directly.
            xs = self.bert.embeddings.word_embeddings(tensor_list.tensors)

        # Flip the mask so that True marks padding positions, matching the
        # NestedTensor convention used elsewhere in the codebase.
        mask = tensor_list.mask.to(torch.bool)
        mask = ~mask
        out = NestedTensor(xs, mask)

        return out


def build_bert(args):
    # Fine-tune BERT only when a positive learning rate is assigned to it.
    train_bert = args.lr_bert > 0
    bert = BERT(args.bert_model, train_bert, args.hidden_dim, args.max_query_len, args.bert_enc_num)

    return bert
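

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the original module).
    # It assumes the "bert-base-uncased" checkpoint and that NestedTensor simply
    # stores its two arguments as .tensors and .mask, as in DETR-style codebases;
    # enc_num=12 and max_length=20 are arbitrary example values.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    encoded = tokenizer(
        ["a person riding a bicycle"],
        padding="max_length",
        truncation=True,
        max_length=20,
        return_tensors="pt",
    )
    model = BERT("bert-base-uncased", train_bert=False, hidden_dim=256, max_len=20, enc_num=12)
    model.eval()
    with torch.no_grad():
        out = model(NestedTensor(encoded["input_ids"], encoded["attention_mask"]))
    print(out.tensors.shape)  # expected: torch.Size([1, 20, 768])
    print(out.mask.shape)     # expected: torch.Size([1, 20])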