|
from transformers import AutoModel, AutoTokenizer, AutoConfig |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
from transformers import CONFIG_MAPPING, MODEL_MAPPING |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
|
|
|
|
class JinaJudgeConfig(PretrainedConfig):
    """Configuration object for :class:`JinaJudge`.

    Args:
        n_classes: Number of logits produced by the classification head.
        hidden_dim: Stored and serialized with the config, but ``JinaJudge``
            does not read it in the code shown here; the decoder width is taken
            from the encoder's ``hidden_size`` instead.
        num_decoder_layers: Number of stacked decoder layers.
        nhead: Number of attention heads in each decoder layer.
        dropout_prob: Dropout probability passed to each decoder layer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "jina-judge"

    def __init__(
        self,
        n_classes=3,
        hidden_dim=512,
        num_decoder_layers=4,
        nhead=8,
        dropout_prob=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Head / decoder hyper-parameters; assigned in the same order as before so
        # the serialized attribute order is unchanged.
        self.n_classes, self.hidden_dim = n_classes, hidden_dim
        self.num_decoder_layers, self.nhead = num_decoder_layers, nhead
        self.dropout_prob = dropout_prob
|
|
|
|
|
class JinaJudge(PreTrainedModel):
    """Classifier that pools jina-embeddings-v3 token states with a learned query.

    A single learned query vector attends, through a stack of
    ``nn.TransformerDecoderLayer`` blocks, over the encoder's token states. The
    resulting vector is mapped to ``config.n_classes`` logits by a linear head.
    The tokenizer and the encoder configuration are fetched from the
    ``jinaai/jina-embeddings-v3`` hub repository with ``trust_remote_code=True``.
    ``config.hidden_dim`` is not used here; all widths follow the encoder's
    ``hidden_size``.
    """

    config_class = JinaJudgeConfig

    def __init__(self, config: JinaJudgeConfig):
        super().__init__(config)
        repo_id = "jinaai/jina-embeddings-v3"

        # NOTE(review): the tokenizer is kept as a module attribute and is loaded
        # from the hub on every construction.
        self.tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
        # from_config builds the encoder architecture (in bfloat16) without loading
        # pretrained weights; presumably the weights come from a JinaJudge
        # checkpoint via from_pretrained — confirm.
        self.encoder = AutoModel.from_config(
            AutoConfig.from_pretrained(repo_id, trust_remote_code=True),
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        # Flag consumed by the remote jina-embeddings-v3 code; presumably it keeps
        # the main (non-LoRA) weights trainable — confirm against that code.
        self.encoder.lora_main_params_trainable = True

        width = self.encoder.config.hidden_size

        # Submodules are created in a fixed order (decoder, query, head): that order
        # determines the RNG draws used for initialisation, and the attribute names
        # determine the state_dict keys that checkpoints rely on.
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=width,
                nhead=config.nhead,
                dim_feedforward=width,
                dropout=config.dropout_prob,
            ),
            num_layers=config.num_decoder_layers,
        )
        # Learned query token laid out as (tgt_len=1, batch=1, width); it is
        # expanded along the batch axis in forward().
        self.decoder_input_embedding = nn.Parameter(torch.randn(1, 1, width))
        # Wrapped in nn.Sequential, so the parameter names are "classification_head.0.*".
        self.classification_head = nn.Sequential(nn.Linear(width, config.n_classes))

    def forward(self, prompts):
        """Score ``prompts`` and return raw class logits.

        Args:
            prompts: Text accepted by the tokenizer, presumably a string or a list
                of strings.

        Returns:
            Logits of shape ``(batch, config.n_classes)``; no softmax is applied.
        """
        device = self.device
        batch = self.tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(device)

        # Encoder runs in bfloat16; upcast so the float32 decoder can consume it.
        # Assumes last_hidden_state is (batch, seq, hidden), as for HF encoders.
        memory = self.encoder(**batch).last_hidden_state.float()

        # PyTorch key-padding convention: True marks positions attention must ignore.
        pad_mask = batch["attention_mask"].eq(0).to(device)

        # The decoder is used sequence-first (no batch_first=True is passed), so the
        # query is (1, batch, width) and memory is transposed to (seq, batch, width).
        query = self.decoder_input_embedding.expand(1, memory.size(0), -1).to(device)
        pooled = self.decoder(
            tgt=query,
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=pad_mask,
        ).squeeze(0)  # drop the length-1 target axis -> (batch, width)

        return self.classification_head(pooled)
|
|
|
|
|
def _register_jina_judge():
    """Expose JinaJudge through the transformers Auto* factories.

    Maps the ``"jina-judge"`` model type to ``JinaJudgeConfig`` for ``AutoConfig``.
    Also maps ``JinaJudgeConfig`` to ``JinaJudge`` for ``AutoModel``.
    """
    AutoConfig.register("jina-judge", JinaJudgeConfig)
    AutoModel.register(JinaJudgeConfig, JinaJudge)


# Registration happens once, as an import-time side effect of this module.
_register_jina_judge()