|
import torch |
|
import torch.nn as nn |
|
from transformers import HubertConfig, HubertModel |
|
from typing import List |
|
|
|
class HuBERTECGConfig(HubertConfig):
    """Configuration for HuBERT-ECG models.

    Extends ``HubertConfig`` with two extra settings, ``ensemble_length`` and
    ``vocab_sizes``. All other keyword arguments are forwarded unchanged to
    ``HubertConfig``.
    """

    model_type = "hubert_ecg"

    def __init__(self, ensemble_length: int = 1, vocab_sizes: List[int] = [100], **kwargs):
        """Store the ensemble settings on top of the base HuBERT configuration.

        Args:
            ensemble_length: Value stored as ``self.ensemble_length``.
            vocab_sizes: A list or tuple of vocabulary sizes, or a single
                size, which is wrapped in a one-element list.
            **kwargs: Forwarded to ``HubertConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.ensemble_length = ensemble_length
        # Always store a fresh list: the default argument object (and any
        # caller-owned sequence) must never be shared between config
        # instances, otherwise an in-place edit of one config's vocab_sizes
        # would silently leak into every other config built afterwards.
        if isinstance(vocab_sizes, (list, tuple)):
            self.vocab_sizes = list(vocab_sizes)
        else:
            self.vocab_sizes = [vocab_sizes]
|
|
|
class HuBERTECG(HubertModel):
    """HuBERT model with an ensemble of projection / label-embedding heads.

    For each entry of ``config.vocab_sizes`` the model owns one linear
    projection (``final_proj[i]``) and one embedding table
    (``label_embedding[i]``). ``logits`` scores projected transformer
    outputs against the embedding table of each head.
    """

    config_class = HuBERTECGConfig

    def __init__(self, config: HuBERTECGConfig):
        """Build the base HuBERT model and the per-head projection layers.

        Args:
            config: Model configuration; ``config.ensemble_length`` must be
                positive and equal to ``len(config.vocab_sizes)``.

        Raises:
            ValueError: If ``ensemble_length`` is not positive or does not
                match the number of vocabulary sizes.
        """
        super().__init__(config)
        self.config = config

        self.pretraining_vocab_sizes = config.vocab_sizes

        # Explicit exceptions instead of ``assert``: assertions are removed
        # when Python runs with -O, which would let an inconsistent config
        # build a silently mismatched ensemble.
        if config.ensemble_length <= 0:
            raise ValueError(
                f"ensemble_length must be a positive integer, got {config.ensemble_length}"
            )
        if config.ensemble_length != len(config.vocab_sizes):
            raise ValueError(
                f"ensemble_length {config.ensemble_length} must be equal to "
                f"len(vocab_sizes) {len(config.vocab_sizes)}"
            )

        # One projection head per ensemble member:
        # hidden_size -> classifier_proj_size.
        self.final_proj = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.classifier_proj_size)
                for _ in range(config.ensemble_length)
            ]
        )

        # One embedding table per vocabulary; rows live in the same
        # classifier_proj_size space as the projected outputs. The check
        # above guarantees this list has the same length as final_proj.
        self.label_embedding = nn.ModuleList(
            [
                nn.Embedding(vocab_size, config.classifier_proj_size)
                for vocab_size in config.vocab_sizes
            ]
        )

    def logits(self, transformer_output: torch.Tensor, temperature: float = 0.1) -> List[torch.Tensor]:
        """Compute temperature-scaled cosine-similarity logits for every head.

        Args:
            transformer_output: Tensor whose last dimension is
                ``config.hidden_size``. The broadcasting below expects a
                3-D input, presumably ``(batch, sequence, hidden)`` --
                confirm against the caller.
            temperature: Divisor applied to the cosine similarities. The
                default ``0.1`` reproduces the previously hard-coded value.

        Returns:
            A list with one tensor per ensemble head. For a 3-D input, the
            i-th tensor has ``config.vocab_sizes[i]`` entries on its last
            dimension.
        """
        ensemble_logits = []
        for final_projection, label_emb in zip(self.final_proj, self.label_embedding):
            projected_output = final_projection(transformer_output)
            # Insert a singleton axis before the feature axis so each
            # projected vector broadcasts against every embedding row; the
            # similarity is then reduced over the feature axis (dim=-1).
            similarity = torch.cosine_similarity(
                projected_output.unsqueeze(2),
                label_emb.weight.unsqueeze(0).unsqueeze(0),
                dim=-1,
            )
            ensemble_logits.append(similarity / temperature)

        return ensemble_logits