File size: 3,531 Bytes
96ee597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""T2S model definition.

Copyright PolyAI Limited.
"""
import os

import numpy as np
from torch import nn
from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration

from data.collation import get_text_semantic_token_collater


def compute_custom_metrics(eval_prediction: "EvalPrediction"):
    """Compute top-1/5/10 token accuracy over non-ignored label positions.

    Args:
        eval_prediction: container whose ``predictions[0]`` holds decoder
            logits of shape (n_batch, n_semantic, n_tokens) and whose
            ``label_ids`` holds target token ids, with ignored/padded
            positions marked as -100 (the HF loss-ignore convention).

    Returns:
        dict with ``top_1_accuracy``, ``top_5_accuracy`` and
        ``top_10_accuracy`` computed over non-ignored positions only.
    """
    logits = eval_prediction.predictions[0]
    labels = eval_prediction.label_ids

    # Positions labelled -100 are ignored and excluded from all accuracies.
    mask = labels == -100

    top_1 = np.argmax(logits, axis=-1) == labels
    top_1[mask] = False

    # argsort is ascending, so the last k entries along the vocab axis are
    # the ids of the k highest-scoring tokens.
    top_5 = np.argsort(logits, axis=-1)[:, :, -5:]
    top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1)
    top_5[mask] = False

    top_10 = np.argsort(logits, axis=-1)[:, :, -10:]
    top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1)
    top_10[mask] = False

    n_valid = np.sum(~mask)
    return {
        "top_1_accuracy": np.sum(top_1) / n_valid,
        "top_5_accuracy": np.sum(top_5) / n_valid,
        "top_10_accuracy": np.sum(top_10) / n_valid,
    }


class T2S(nn.Module):
    """Text-to-semantic-token model.

    Wraps a ``T5ForConditionalGeneration`` whose architecture is selected
    by ``hp.model_size`` and whose vocabulary size is derived from the
    text/semantic token collater's symbol table.
    """

    def __init__(self, hp, text_tokens_file="ckpt/unique_text_tokens.k2symbols"):
        """Build the model.

        Args:
            hp: hyper-parameter object; only ``hp.model_size`` is read here.
            text_tokens_file: path to the k2 symbol table of text tokens.
                Defaults to the previously hard-coded checkpoint location,
                so existing callers are unaffected.
        """
        super().__init__()
        self.text_tokens_file = text_tokens_file
        self.collater = get_text_semantic_token_collater(self.text_tokens_file)
        self.model_size = hp.model_size
        # Vocabulary covers every symbol known to the collater.
        self.vocab_size = len(self.collater.idx2token)
        self.config = self._define_model_config(self.model_size)

        print(f"{self.config = }")
        self.t2s = T5ForConditionalGeneration(self.config)

    def _define_model_config(self, model_size):
        """Return a ``T5Config`` for a named model size.

        Args:
            model_size: one of ``"test"``, ``"tiny"``, ``"t5small"``,
                ``"large"``, ``"Large"``.

        Returns:
            A ``T5Config`` with the size's dimensions, the collater-derived
            vocab size, ``decoder_start_token_id=0`` and ``eos_token_id=2``.

        Raises:
            ValueError: if ``model_size`` is not a known size name.
        """
        # (d_ff, d_model, d_kv, num_heads, num_decoder_layers, num_layers)
        sizes = {
            # minimal config for tests/smoke runs
            "test": (16, 8, 32, 1, 1, 1),
            # n_params = 16M
            "tiny": (1024, 256, 32, 4, 4, 4),
            # n_params = 60M
            "t5small": (2048, 512, 64, 8, 6, 6),
            # n_params = 100M
            "large": (2048, 512, 64, 8, 14, 14),
            # n_params = 114M
            "Large": (4096, 512, 64, 8, 6, 10),
        }
        if model_size not in sizes:
            raise ValueError(f"unknown {model_size}")
        d_ff, d_model, d_kv, num_heads, num_decoder_layers, num_layers = (
            sizes[model_size]
        )

        config = T5Config(
            d_ff=d_ff,
            d_model=d_model,
            d_kv=d_kv,
            num_heads=num_heads,
            num_decoder_layers=num_decoder_layers,
            num_layers=num_layers,
            decoder_start_token_id=0,
            eos_token_id=2,
            vocab_size=self.vocab_size,
        )

        return config