from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import KwargsForCausalLM
from transformers.processing_utils import Unpack

from configuration_speechunit import SpeechUnitConfig


# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel
class SpeechUnitPreTrainedModel(PreTrainedModel):
    config_class = SpeechUnitConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SpeechUnitModel):
            # Initialize the backbone from the pretrained checkpoint rather
            # than randomly: copy every parameter of the (possibly truncated)
            # llama_model from the full pretrained LlamaModel, matched by name.
            src_model = LlamaModel.from_pretrained(self.config.base_model_id)
            with torch.no_grad():
                for name, param in module.llama_model.named_parameters():
                    param.copy_(src_model.state_dict()[name])
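
# The weight copy above runs automatically: post_init() applies _init_weights
# to every submodule and finally to the model itself, which hits the
# SpeechUnitModel branch. A hypothetical instantiation (the config values
# below are illustrative assumptions, not taken from this file):
#
#   config = SpeechUnitConfig(
#       base_model_id="meta-llama/Llama-2-7b-hf",  # assumed checkpoint id
#       num_hidden_layers=4,
#       num_heads=8,
#       output_dim=16400,
#   )
#   model = SpeechUnitModel(config)  # downloads and copies the LLaMA weights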

class SpeechUnitModel(SpeechUnitPreTrainedModel):
    def __init__(self, config: SpeechUnitConfig):
        super().__init__(config)
 
        # Initialize LLaMA model and load weights
        llama_config = LlamaConfig.from_pretrained(config.base_model_id)
        llama_config.num_hidden_layers = config.num_hidden_layers
        self.llama_model = LlamaModel._from_config(llama_config)
 
        # Embedding layers
        original_vocab_size, embed_dim = self.llama_model.embed_tokens.weight.shape
 
        # Audio embeddings (16400 = codebook size + 2 for BOS and EOS tokens).
        # Note: post_init() below re-applies _init_weights, whose nn.Embedding
        # branch overwrites this xavier initialization with a normal init.
        self.audio_embed = nn.Embedding(16400, embed_dim)
        nn.init.xavier_uniform_(self.audio_embed.weight.data)
 
        # Learnable weights for token integration
        self.token_weights = nn.Parameter(torch.ones(config.num_heads))
 
        # Prediction heads
        self.heads = nn.ModuleList([nn.Linear(embed_dim, config.output_dim) for _ in range(config.num_heads)])
        
        self.post_init()
 
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # See the reference implementation in LlamaForCausalLM.forward:
        # https://github.com/huggingface/transformers/blob/b05df6611e6e3e6834acca2b50baeb7cdd5fbe3c/src/transformers/models/llama/modeling_llama.py#L784
        # Not implemented yet; raise rather than silently returning None.
        raise NotImplementedError
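

# A minimal, non-authoritative sketch of what `forward` could look like,
# following the LlamaForCausalLM.forward referenced above. The input layout
# (one token stream per codebook head), the softmax-weighted merge via
# `token_weights`, and the summed cross-entropy loss are all assumptions,
# not taken from this file.
def _sketch_forward(
    model: SpeechUnitModel,
    input_ids: torch.LongTensor,                   # (batch, num_heads, seq_len)
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.LongTensor] = None,     # (batch, num_heads, seq_len)
) -> CausalLMOutputWithPast:
    # Embed each codebook stream with the shared audio embedding table.
    embeds = model.audio_embed(input_ids)          # (B, H, T, D)

    # Merge the per-head streams into one sequence of input embeddings,
    # weighting each head by the learnable token_weights.
    weights = torch.softmax(model.token_weights, dim=0)
    inputs_embeds = (embeds * weights.view(1, -1, 1, 1)).sum(dim=1)  # (B, T, D)

    # Run the (truncated) LLaMA backbone on the merged embeddings.
    outputs = model.llama_model(
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
    )
    hidden_states = outputs.last_hidden_state      # (B, T, D)

    # One set of logits per prediction head: (B, H, T, output_dim).
    logits = torch.stack([head(hidden_states) for head in model.heads], dim=1)

    loss = None
    if labels is not None:
        # Standard causal shift: predict token t+1 from position t, with the
        # loss averaged over all heads and positions.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

    return CausalLMOutputWithPast(loss=loss, logits=logits)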