# hubert-ecg-large / hubert_ecg.py
# Uploaded by Edoardo-BS ("Upload HuBERTECG", commit cd67436, verified)
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import HubertConfig, HubertModel
class HuBERTECGConfig(HubertConfig):
    """HuBERT configuration extended with ensemble/codebook settings.

    On top of everything ``HubertConfig`` accepts, this stores how many
    ensemble members there are and the vocabulary size associated with each.
    """

    model_type = "hubert_ecg"

    def __init__(
        self,
        ensemble_length: int = 1,
        vocab_sizes: Optional[List[int]] = None,
        **kwargs,
    ):
        """Create the configuration.

        Args:
            ensemble_length: Number of ensemble members.
            vocab_sizes: One vocabulary size per ensemble member. ``None``
                selects the default ``[100]``. A list or tuple is copied into
                a new list; any other single value is wrapped in a
                one-element list.
            **kwargs: Forwarded unchanged to ``HubertConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.ensemble_length = ensemble_length

        if vocab_sizes is None:
            # Build a fresh list per instance; a literal list in the signature
            # would be shared by every config constructed with the default.
            vocab_sizes = [100]
        elif isinstance(vocab_sizes, (list, tuple)):
            # Copy so later mutation of the caller's sequence cannot change
            # this config, and so a tuple becomes a flat list of sizes.
            vocab_sizes = list(vocab_sizes)
        else:
            vocab_sizes = [vocab_sizes]
        self.vocab_sizes = vocab_sizes
class HuBERTECG(HubertModel):
    """``HubertModel`` with per-ensemble projection heads and codebook embeddings.

    For each ensemble member the model owns a linear layer mapping encoder
    outputs to ``config.classifier_proj_size`` dimensions and an embedding
    table with one row per codebook entry. ``logits`` scores projected
    encoder outputs against those rows.
    """

    config_class = HuBERTECGConfig

    def __init__(self, config: HuBERTECGConfig):
        """Build the base model plus one head and one codebook per ensemble member.

        Args:
            config: Configuration providing ``ensemble_length``,
                ``vocab_sizes``, ``hidden_size`` and ``classifier_proj_size``.

        Raises:
            ValueError: If ``config.ensemble_length`` is not positive or is
                not equal to ``len(config.vocab_sizes)``.
        """
        super().__init__(config)
        self.config = config
        self.pretraining_vocab_sizes = config.vocab_sizes

        # Raised explicitly instead of asserted so that the check is not
        # stripped when Python runs with -O.
        if not (config.ensemble_length > 0 and config.ensemble_length == len(config.vocab_sizes)):
            raise ValueError(
                f"ensemble_length {config.ensemble_length} must be equal to len(vocab_sizes) {len(config.vocab_sizes)}"
            )

        # final projection layer to map encodings into the space of the codebook
        self.final_proj = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.classifier_proj_size)
                for _ in range(config.ensemble_length)
            ]
        )

        # embedding for codebooks; the check above guarantees this list has
        # the same length as final_proj
        self.label_embedding = nn.ModuleList(
            [
                nn.Embedding(vocab_size, config.classifier_proj_size)
                for vocab_size in config.vocab_sizes
            ]
        )

    def logits(self, transformer_output: torch.Tensor, temperature: float = 0.1) -> List[torch.Tensor]:
        """Score encoder outputs against every codebook entry, per ensemble member.

        Args:
            transformer_output: Encoder output of shape (B, T, D).
            temperature: Positive divisor applied to the cosine similarities.
                Defaults to 0.1.

        Returns:
            A list with one tensor of shape (B, T, V) per ensemble member,
            where V is that member's vocabulary size.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        ensemble_logits = []
        for final_projection, label_emb in zip(self.final_proj, self.label_embedding):
            # compute a projected output for this ensemble member
            projected_output = final_projection(transformer_output)
            # (B, T, 1, P) against (1, 1, V, P) broadcasts to (B, T, V) once
            # the similarity is reduced over the last dimension.
            similarity = torch.cosine_similarity(
                projected_output.unsqueeze(2),
                label_emb.weight.unsqueeze(0).unsqueeze(0),
                dim=-1,
            )
            ensemble_logits.append(similarity / temperature)
        return ensemble_logits  # [(B, T, V)] * ensemble_length