# PolyAI-pheme: data/collation.py
"""Collators for T2S and S2A.
Copyright PolyAI Limited.
"""
from pathlib import Path
from typing import List, Tuple, Union
import numpy as np
import torch
from utils.symbol_table import SymbolTable


class GlobalCollater:
    """Batches variable-length (speaker, acoustic, semantic) items by
    right-padding each sequence to the longest acoustic sequence in the
    batch: acoustic codes are padded with ``n_codes``, semantic tokens with
    ``n_semantic_codes`` and speaker embeddings with zeros."""

    def __init__(self, n_codes, n_semantic_codes):
        self.n_codes = n_codes
        self.sem_mask_id = n_semantic_codes

    def collate(self, batch):
        output = {
            'speaker': [],
            'tts_quantize_input': [],
            'tts_quantize_output': [],
            'quantize_mask': [],
            'f_names': [],
            'semantic_tokens': [],
            'quantization_lengths': [],
        }

        # Get the max length of everything
        max_len_q = 0
        for _, q_s, q_e, _, _ in batch:
            if len(q_s) > max_len_q:
                max_len_q = len(q_s)
            output['quantization_lengths'].append(len(q_s))

        # Pad each element, create mask
        for spkr, qs, qe, itm_name, s_tokens in batch:
            # Deal with quantizations
            q_mask = np.array(
                [False] * len(qs) + [True] * (max_len_q - len(qs)))
            qs = np.pad(
                qs,
                [[0, max_len_q - len(qs)], [0, 0]],
                constant_values=self.n_codes
            )
            qe = np.pad(
                qe,
                [[0, max_len_q - len(qe)], [0, 0]],
                constant_values=self.n_codes
            )

            # Deal with semantics
            s_tokens = s_tokens.flatten()
            s_tokens = np.pad(
                s_tokens,
                (0, max_len_q - len(s_tokens)),
                constant_values=self.sem_mask_id
            )

            # Speaker padding
            spkr = np.concatenate(
                (spkr, np.zeros((max_len_q - len(spkr), 512))))

            # Aggregate
            output['speaker'].append(spkr)
            output['tts_quantize_input'].append(qs)
            output['tts_quantize_output'].append(qe)
            output['quantize_mask'].append(q_mask)
            output['f_names'].append(itm_name)
            output["semantic_tokens"].append(s_tokens)

        # Convert everything except file names to tensors of the right dtype
        for k in output.keys():
            if k == 'f_names':
                continue
            output[k] = np.array(output[k])
            if 'mask' in k:
                output[k] = torch.BoolTensor(output[k])
            elif k in [
                'tts_quantize_input', 'tts_quantize_output',
                'semantic_tokens', 'quantization_lengths'
            ]:
                output[k] = torch.LongTensor(output[k])
            else:
                output[k] = torch.FloatTensor(output[k])

        return output
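

def _demo_global_collater():
    """Illustrative sketch (not part of the original pipeline): shows the
    per-item format that GlobalCollater.collate expects. The sizes used here
    (7 acoustic codebooks, one semantic token per frame, 1024 codes) are
    assumptions for demonstration; only the 512-dim speaker embedding width
    is fixed by the padding code above."""
    collater = GlobalCollater(n_codes=1024, n_semantic_codes=1024)

    def make_item(n_frames, name):
        speaker = np.random.randn(n_frames, 512)                   # [T, 512]
        acoustic = np.random.randint(0, 1024, size=(n_frames, 7))  # [T, n_q]
        semantic = np.random.randint(0, 1024, size=(n_frames,))    # [T]
        return speaker, acoustic, acoustic.copy(), name, semantic

    batch = collater.collate([make_item(8, 'a.wav'), make_item(5, 'b.wav')])
    # Everything is right-padded to the longest item (8 frames here).
    print(batch['tts_quantize_input'].shape)  # torch.Size([2, 8, 7])
    print(batch['quantize_mask'].shape)       # torch.Size([2, 8])
    print(batch['speaker'].shape)             # torch.Size([2, 8, 512])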


class TextTokenCollater:
    """Maps sequences of text tokens to a LongTensor of token ids, wrapping
    each sequence with optional <bos>/<eos> markers and speaker-turn symbols
    (spkr_2 separates an optional second utterance)."""

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        spkr_1_symbol: str = "spkr_1",
        spkr_2_symbol: str = "spkr_2",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.spkr_1_symbol = spkr_1_symbol
        self.spkr_2_symbol = spkr_2_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + [spkr_1_symbol]
            + [spkr_2_symbol]
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def __call__(
        self, texts: List[str], texts_2: Union[None, List[str]] = None
    ) -> torch.Tensor:
        tokens_seqs = [[p for p in text] for text in texts]

        if texts_2 is None:
            seqs = [
                ([self.bos_symbol] if self.add_bos else [])
                + [self.spkr_1_symbol]
                + list(seq)
                + ([self.eos_symbol] if self.add_eos else [])
                for seq in tokens_seqs
            ]
        else:
            tokens_seqs_2 = [[p for p in text] for text in texts_2]
            seqs = [
                ([self.bos_symbol] if self.add_bos else [])
                + [self.spkr_1_symbol]
                + list(seq)
                + [self.spkr_2_symbol]
                + list(seq_2)
                + ([self.eos_symbol] if self.add_eos else [])
                for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2)
            ]

        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        return tokens_batch
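

def _demo_text_token_collater():
    """Illustrative sketch only; the three-symbol vocabulary below is made up.
    Each sequence is wrapped as <bos>, spkr_1, tokens, <eos> and mapped to
    ids. Note that __call__ does not pad, so every sequence passed in one
    call must have the same length."""
    collater = TextTokenCollater(text_tokens=['a', 'b', 'c'])
    ids = collater([['a', 'b'], ['c', 'a']])
    print(ids)  # LongTensor of shape (2, 5)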


def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
    """Builds a TextTokenCollater from a SymbolTable file of text tokens."""
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)

    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


def get_text_semantic_token_collater(
        text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater:
    """Like get_text_token_collater, but also adds the string ids
    "0" .. str(n_semantic_tokens - 1) for semantic tokens to the vocabulary."""
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    for semantic_idx in range(n_semantic_tokens):
        unique_tokens.add(str(semantic_idx))

    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


if __name__ == '__main__':
    text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
    collater = get_text_semantic_token_collater(text_tokens_file)
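    # Minimal smoke test. Semantic indices such as "1", "2", "3" are always in
    # the vocabulary because get_text_semantic_token_collater adds
    # "0" .. str(n_semantic_tokens - 1); the sequence below is illustrative only.
    token_ids = collater([['1', '2', '3']])
    print(token_ids)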