mrfakename committed on
Commit
648ac03
1 Parent(s): 5dc7366

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there rather than to this Space.

Files changed (1)
  1. scripts/prepare_csv_wavs.py +132 -0
scripts/prepare_csv_wavs.py ADDED
@@ -0,0 +1,132 @@
+ import sys, os
+ sys.path.append(os.getcwd())
+
+ from pathlib import Path
+ import json
+ import shutil
+ import argparse
+
+ import csv
+ import torchaudio
+ from tqdm import tqdm
+ from datasets.arrow_writer import ArrowWriter
+
+ from model.utils import (
+     convert_char_to_pinyin,
+ )
+
+ PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
+
+
+ def is_csv_wavs_format(input_dataset_dir):
+     fpath = Path(input_dataset_dir)
+     metadata = fpath / "metadata.csv"
+     wavs = fpath / "wavs"
+     return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
+
+
+ def prepare_csv_wavs_dir(input_dir):
+     assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
+     input_dir = Path(input_dir)
+     metadata_path = input_dir / "metadata.csv"
+     audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
+
+     sub_result, durations = [], []
+     vocab_set = set()
+     polyphone = True
+     for audio_path, text in audio_path_text_pairs:
+         if not Path(audio_path).exists():
+             print(f"audio {audio_path} not found, skipping")
+             continue
+         audio_duration = get_audio_duration(audio_path)
+         # assume tokenizer = "pinyin" ("pinyin" | "char")
+         text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
+         sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
+         durations.append(audio_duration)
+         vocab_set.update(list(text))
+
+     return sub_result, durations, vocab_set
+
+
+ def get_audio_duration(audio_path):
+     # torchaudio.load returns a (channels, frames) tensor, so duration is
+     # frames / sample_rate; dividing by the channel count would be wrong
+     audio, sample_rate = torchaudio.load(audio_path)
+     return audio.shape[1] / sample_rate
+
+
+ def read_audio_text_pairs(csv_file_path):
+     audio_text_pairs = []
+
+     parent = Path(csv_file_path).parent
+     with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
+         reader = csv.reader(csvfile, delimiter="|")
+         next(reader)  # skip the header row
+         for row in reader:
+             if len(row) >= 2:
+                 audio_file = row[0].strip()  # first column: audio file path
+                 text = row[1].strip()  # second column: text
+                 audio_file_path = parent / audio_file
+                 audio_text_pairs.append((audio_file_path.as_posix(), text))
+
+     return audio_text_pairs
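+
+ # For reference, metadata.csv is read as pipe-delimited with a header row (skipped
+ # above), and audio paths are resolved relative to the CSV's directory. A
+ # hypothetical example:
+ #
+ #   audio_file|text
+ #   wavs/0001.wav|Hello, world.
+ #   wavs/0002.wav|A second utterance.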
+
+
+ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
+     out_dir = Path(out_dir)
+     # save preprocessed dataset to disk
+     out_dir.mkdir(exist_ok=True, parents=True)
+     print(f"\nSaving to {out_dir} ...")
+
+     # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
+     # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
+     raw_arrow_path = out_dir / "raw.arrow"
+     with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
+         for line in tqdm(result, desc="Writing to raw.arrow ..."):
+             writer.write(line)
+
+     # also dump durations to a separate json, for DynamicBatchSampler convenience
+     dur_json_path = out_dir / "duration.json"
+     with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
+         json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+     # vocab map, i.e. tokenizer
+     # add alphabets and symbols (optional, if planning to fine-tune on de/fr etc.)
+     # if tokenizer == "pinyin":
+     #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+     voca_out_path = out_dir / "vocab.txt"
+     if is_finetune:
+         # reuse the pretrained vocab so token ids stay aligned with the base model
+         shutil.copy2(PRETRAINED_VOCAB_PATH.as_posix(), voca_out_path)
+     else:
+         with open(voca_out_path, "w") as f:
+             for vocab in sorted(text_vocab_set):
+                 f.write(vocab + "\n")
+
+     dataset_name = out_dir.stem
+     print(f"\nFor {dataset_name}, sample count: {len(result)}")
+     print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+     print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
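+
+ # Per the code above, out_dir should end up containing:
+ #   raw.arrow      - one record per sample: {"audio_path", "text", "duration"}
+ #   duration.json  - {"duration": [...]}, duplicated for DynamicBatchSampler convenience
+ #   vocab.txt      - the tokenizer vocab (copied from the pretrained one when fine-tuning)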
+
+
+ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
+     if is_finetune:
+         assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
+     sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
+     save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
+
+
+ def cli():
+     # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
+     # pretrain: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin --pretrain
+     parser = argparse.ArgumentParser(description="Prepare and save dataset.")
+     parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
+     parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
+     parser.add_argument("--pretrain", action="store_true", help="Enable for a new pretrain; otherwise treated as a fine-tune.")
+
+     args = parser.parse_args()
+
+     prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
+
+
+ if __name__ == "__main__":
+     cli()
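
Example usage (paths are placeholders; the input directory must contain metadata.csv and a wavs/ folder):

    # prepare a fine-tuning set (reuses the pretrained pinyin vocab):
    python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin

    # prepare a pretraining set (builds vocab.txt from the dataset itself):
    python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin --pretrain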