import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, PreTrainedModel, BertConfig, PretrainedConfig, AutoModel
from typing import List, Optional


class ConcatModelConfig(PretrainedConfig):
    """HF config for :class:`ConcatModel`.

    Carries no extra fields of its own; it exists so the model can be
    registered/loaded through the standard ``PreTrainedModel`` machinery
    under the ``model_type`` below.
    """

    model_type = "arctic-s-bge-small"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


# See https://huggingface.co./Marqo/marqo-chimera-arctic-bge-m
class ConcatModel(PreTrainedModel):
    """Two parallel BERT encoders whose sentence embeddings are concatenated.

    Each encoder produces a 384-dim embedding taken from the first token of
    the last hidden state (presumably the [CLS] token — standard BERT layout,
    TODO confirm against the checkpoints loaded into this model), which is
    L2-normalized and then concatenated, giving a 768-dim output per input.
    """

    config_class = ConcatModelConfig

    @staticmethod
    def _bert_config() -> BertConfig:
        """Return the shared encoder architecture.

        Both sub-models use an identical configuration; building it in one
        place avoids the duplicated 13-field literal the class previously
        carried twice.
        """
        return BertConfig(
            vocab_size=30522,
            hidden_size=384,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=1536,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
        )

    def __init__(self, config: ConcatModelConfig):
        super().__init__(config)
        # Keys "model_0"/"model_1" are part of the checkpoint layout — do not rename.
        self.model = nn.ModuleDict(
            {f"model_{i}": BertModel(self._bert_config()) for i in range(2)}
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Encode a batch with every sub-model and concatenate the results.

        Args:
            input_ids: Token ids, shape ``(batch, seq_len)``.
            attention_mask: Attention mask, shape ``(batch, seq_len)``.
            token_type_ids: Optional segment ids, shape ``(batch, seq_len)``.
            **kwargs: Ignored; accepted for pipeline compatibility.

        Returns:
            Tensor of shape ``(batch, n_models * hidden_size)`` — per-model
            L2-normalized first-token embeddings, concatenated on the last dim.
        """
        embeddings = []
        for _, model in self.model.items():
            model_output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            # First token of the last hidden state; unit-normalize so each
            # sub-model contributes equally to the concatenated embedding.
            pooled_output = model_output[0][:, 0]
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)
        return torch.cat(embeddings, dim=-1)

    def load_weights_from_automodels(
        self, in_models: List[str], has_pooling_layer: List[bool]
    ):
        """Replace the sub-models with pretrained ``AutoModel`` checkpoints.

        Args:
            in_models: Model names/paths loadable by ``AutoModel.from_pretrained``.
            has_pooling_layer: Per-model flag forwarded as ``add_pooling_layer``;
                must be the same length as ``in_models``.

        Raises:
            ValueError: If the two lists differ in length (previously this
                surfaced as an opaque ``IndexError``).
        """
        if len(in_models) != len(has_pooling_layer):
            raise ValueError(
                "in_models and has_pooling_layer must have the same length, "
                f"got {len(in_models)} and {len(has_pooling_layer)}"
            )
        model_list = []
        for i, model_name in enumerate(in_models):
            # NOTE(review): trust_remote_code=True executes code shipped with
            # the checkpoint — only load models from trusted sources.
            model = AutoModel.from_pretrained(
                model_name,
                add_pooling_layer=has_pooling_layer[i],
                trust_remote_code=True,
            )
            model.eval()
            model_list.append(model)
        self.model = nn.ModuleDict(
            {f"model_{i}": model for i, model in enumerate(model_list)}
        )