mlinmg commited on
Commit
09a868c
·
verified ·
1 Parent(s): cc7a4b1

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.json +11 -3
  2. gpt_config.py +84 -38
config.json CHANGED
@@ -23,7 +23,7 @@
23
  "sample_rate": 22050,
24
  "win_length": 1024
25
  },
26
- "batch_size": 1,
27
  "char_limits": {
28
  "ar": 166,
29
  "cs": 186,
@@ -43,13 +43,19 @@
43
  "zh": 82
44
  },
45
  "checkpointing": false,
 
46
  "code_stride_len": 1024,
47
  "cond_chunk_len": 4,
 
48
  "cond_len": 30,
 
 
 
49
  "duration_const": 102400,
50
  "embd_pdrop": 0.1,
51
  "enable_redaction": false,
52
  "hidden_size": 1024,
 
53
  "kv_cache": true,
54
  "label_smoothing": 0.0,
55
  "languages": [
@@ -80,10 +86,11 @@
80
  "model_type": "xtts_gpt",
81
  "n_inner": null,
82
  "num_attention_heads": 16,
83
- "num_audio_tokens": 1026,
84
  "num_chars": 255,
85
  "num_hidden_layers": 30,
86
  "number_text_tokens": 6681,
 
 
87
  "perceiver_cond_length_compression": 256,
88
  "reorder_and_upcast_attn": false,
89
  "resid_pdrop": 0.1,
@@ -93,9 +100,10 @@
93
  "start_text_token": null,
94
  "stop_audio_token": 1025,
95
  "stop_text_token": null,
 
96
  "train_solo_embeddings": false,
97
  "transformers_version": "4.46.0",
98
  "use_masking_gt_prompt_approach": true,
99
  "use_perceiver_resampler": true,
100
- "vocab_size": 256
101
  }
 
23
  "sample_rate": 22050,
24
  "win_length": 1024
25
  },
26
+ "batch_size": 32,
27
  "char_limits": {
28
  "ar": 166,
29
  "cs": 186,
 
43
  "zh": 82
44
  },
45
  "checkpointing": false,
46
+ "clvp_checkpoint": null,
47
  "code_stride_len": 1024,
48
  "cond_chunk_len": 4,
49
+ "cond_d_vector_in_each_upsampling_layer": true,
50
  "cond_len": 30,
51
+ "d_vector_dim": 512,
52
+ "decoder_checkpoint": null,
53
+ "decoder_input_dim": 1024,
54
  "duration_const": 102400,
55
  "embd_pdrop": 0.1,
56
  "enable_redaction": false,
57
  "hidden_size": 1024,
58
+ "input_sample_rate": 22050,
59
  "kv_cache": true,
60
  "label_smoothing": 0.0,
61
  "languages": [
 
86
  "model_type": "xtts_gpt",
87
  "n_inner": null,
88
  "num_attention_heads": 16,
 
89
  "num_chars": 255,
90
  "num_hidden_layers": 30,
91
  "number_text_tokens": 6681,
92
+ "output_hop_length": 256,
93
+ "output_sample_rate": 24000,
94
  "perceiver_cond_length_compression": 256,
95
  "reorder_and_upcast_attn": false,
96
  "resid_pdrop": 0.1,
 
100
  "start_text_token": null,
101
  "stop_audio_token": 1025,
102
  "stop_text_token": null,
103
+ "tokenizer_file": "",
104
  "train_solo_embeddings": false,
105
  "transformers_version": "4.46.0",
106
  "use_masking_gt_prompt_approach": true,
107
  "use_perceiver_resampler": true,
108
+ "vocab_size": 1026
109
  }
gpt_config.py CHANGED
@@ -1,11 +1,10 @@
1
- from dataclasses import asdict, dataclass, field
2
  from typing import Dict, Optional, List
3
  from transformers.configuration_utils import PretrainedConfig
4
  from transformers.utils import logging
5
 
6
  logger = logging.get_logger(__name__)
7
 
8
-
9
  @dataclass
10
  class XTTSAudioConfig:
11
  """Configuration for audio processing parameters"""
@@ -20,15 +19,14 @@ class XTTSAudioConfig:
20
  power: float = 1.0
21
  mel_norms_file: Optional[str] = None
22
 
23
-
24
  class XTTSGPTConfig(PretrainedConfig):
25
- """Configuration class for the GPT component of XTTS"""
26
  model_type = "xtts_gpt"
27
 
28
  def __init__(
29
  self,
30
  # Model architecture
31
- vocab_size: int = 256,
32
  hidden_size: int = 1024, # Changed from gpt_n_model_channels
33
  num_hidden_layers: int = 30, # Changed from gpt_layers
34
  num_attention_heads: int = 16, # Changed from gpt_n_heads
@@ -49,7 +47,6 @@ class XTTSGPTConfig(PretrainedConfig):
49
  number_text_tokens: int = 6681, # Changed from gpt_number_text_tokens
50
  start_text_token: Optional[int] = None, # Changed from gpt_start_text_token
51
  stop_text_token: Optional[int] = None, # Changed from gpt_stop_text_token
52
- num_audio_tokens: int = 1026, # Changed from gpt_num_audio_tokens
53
  start_audio_token: int = 1024, # Changed from gpt_start_audio_token
54
  stop_audio_token: int = 1025, # Changed from gpt_stop_audio_token
55
  code_stride_len: int = 1024, # Changed from gpt_code_stride_len
@@ -65,11 +62,6 @@ class XTTSGPTConfig(PretrainedConfig):
65
  label_smoothing: float = 0.0,
66
 
67
  # Generation parameters
68
- #temperature: float = 0.75,
69
- #length_penalty: float = 1.0,
70
- #repetition_penalty: float = 5.0,
71
- #top_k: int = 50,
72
- #top_p: float = 0.85,
73
  cond_len: int = 30, # Changed from gpt_cond_len
74
  cond_chunk_len: int = 4, # Changed from gpt_cond_chunk_len
75
  max_ref_len: int = 30,
@@ -82,17 +74,29 @@ class XTTSGPTConfig(PretrainedConfig):
82
  duration_const: int = 102400,
83
  char_limits: Optional[Dict[str, int]] = None,
84
  languages: Optional[List[str]] = None,
85
- pad_token_id: Optional[int] = None,
86
- bos_token_id: Optional[int] = None,
87
- eos_token_id: Optional[int] = None,
88
 
89
  # GPT-2 compatibility flags
90
  scale_attn_by_inverse_layer_idx: bool = False,
91
  reorder_and_upcast_attn: bool = False,
92
  add_cross_attention: bool = False,
93
  tie_word_embeddings: bool = True,
94
- **kwargs,
95
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  if char_limits is None:
97
  char_limits = {
98
  "en": 250, "de": 253, "fr": 273, "es": 239,
@@ -101,22 +105,21 @@ class XTTSGPTConfig(PretrainedConfig):
101
  "tr": 226, "ja": 71, "hu": 224, "ko": 95,
102
  }
103
 
 
104
  if languages is None:
105
  languages = [
106
  "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
107
  "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
108
  ]
109
 
110
- if audio_config is None:
111
- audio_config = XTTSAudioConfig()
112
-
113
  super().__init__(
114
- pad_token_id=pad_token_id,
115
- bos_token_id=bos_token_id,
116
- eos_token_id=eos_token_id,
117
  **kwargs
118
  )
119
 
 
120
  self.vocab_size = vocab_size
121
  self.hidden_size = hidden_size
122
  self.num_hidden_layers = num_hidden_layers
@@ -129,7 +132,7 @@ class XTTSGPTConfig(PretrainedConfig):
129
  self.embd_pdrop = embd_pdrop
130
  self.attn_pdrop = attn_pdrop
131
 
132
- # XTTS specific parameters
133
  self.num_chars = num_chars
134
  self.batch_size = batch_size
135
  self.max_audio_tokens = max_audio_tokens
@@ -138,7 +141,6 @@ class XTTSGPTConfig(PretrainedConfig):
138
  self.number_text_tokens = number_text_tokens
139
  self.start_text_token = start_text_token
140
  self.stop_text_token = stop_text_token
141
- self.num_audio_tokens = num_audio_tokens
142
  self.start_audio_token = start_audio_token
143
  self.stop_audio_token = stop_audio_token
144
  self.code_stride_len = code_stride_len
@@ -147,48 +149,92 @@ class XTTSGPTConfig(PretrainedConfig):
147
  self.checkpointing = checkpointing
148
  self.train_solo_embeddings = train_solo_embeddings
149
 
150
- # Training parameters
151
  self.enable_redaction = enable_redaction
152
  self.kv_cache = kv_cache
153
  self.perceiver_cond_length_compression = perceiver_cond_length_compression
154
  self.label_smoothing = label_smoothing
155
 
156
- # Generation parameters
157
- #self.temperature = temperature
158
- #self.length_penalty = length_penalty
159
- #self.repetition_penalty = repetition_penalty
160
- #self.top_k = top_k
161
- #self.top_p = top_p
162
  self.cond_len = cond_len
163
  self.cond_chunk_len = cond_chunk_len
164
  self.max_ref_len = max_ref_len
165
  self.sound_norm_refs = sound_norm_refs
166
 
167
- # Audio processing
168
  self.audio_config = audio_config
169
-
170
- # Constants and limits
171
  self.duration_const = duration_const
172
  self.char_limits = char_limits
173
  self.languages = languages
174
 
175
- # GPT-2 compatibility flags
176
  self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
177
  self.reorder_and_upcast_attn = reorder_and_upcast_attn
178
  self.add_cross_attention = add_cross_attention
179
  self.tie_word_embeddings = tie_word_embeddings
180
 
181
- def to_dict(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  """Convert config to dictionary"""
183
  config_dict = super().to_dict()
184
  config_dict["audio_config"] = asdict(self.audio_config)
185
  return config_dict
186
 
187
  @classmethod
188
- def from_dict(cls, config_dict, *args, **kwargs):
189
  """Create config from dictionary"""
190
- audio_config = XTTSAudioConfig(**config_dict.pop("audio_config", {}))
191
- return cls(audio_config=audio_config, **config_dict, **kwargs)
 
 
192
 
193
  def update_with_tokenizer(self, tokenizer=None):
194
  """Update configuration values based on tokenizer"""
 
1
+ from dataclasses import asdict, dataclass
2
  from typing import Dict, Optional, List
3
  from transformers.configuration_utils import PretrainedConfig
4
  from transformers.utils import logging
5
 
6
  logger = logging.get_logger(__name__)
7
 
 
8
  @dataclass
9
  class XTTSAudioConfig:
10
  """Configuration for audio processing parameters"""
 
19
  power: float = 1.0
20
  mel_norms_file: Optional[str] = None
21
 
 
22
  class XTTSGPTConfig(PretrainedConfig):
23
+ """Configuration class for the GPT component of XTTS with automatic legacy conversion"""
24
  model_type = "xtts_gpt"
25
 
26
  def __init__(
27
  self,
28
  # Model architecture
29
+ vocab_size: int = 1026, # num_audio_tokens
30
  hidden_size: int = 1024, # Changed from gpt_n_model_channels
31
  num_hidden_layers: int = 30, # Changed from gpt_layers
32
  num_attention_heads: int = 16, # Changed from gpt_n_heads
 
47
  number_text_tokens: int = 6681, # Changed from gpt_number_text_tokens
48
  start_text_token: Optional[int] = None, # Changed from gpt_start_text_token
49
  stop_text_token: Optional[int] = None, # Changed from gpt_stop_text_token
 
50
  start_audio_token: int = 1024, # Changed from gpt_start_audio_token
51
  stop_audio_token: int = 1025, # Changed from gpt_stop_audio_token
52
  code_stride_len: int = 1024, # Changed from gpt_code_stride_len
 
62
  label_smoothing: float = 0.0,
63
 
64
  # Generation parameters
 
 
 
 
 
65
  cond_len: int = 30, # Changed from gpt_cond_len
66
  cond_chunk_len: int = 4, # Changed from gpt_cond_chunk_len
67
  max_ref_len: int = 30,
 
74
  duration_const: int = 102400,
75
  char_limits: Optional[Dict[str, int]] = None,
76
  languages: Optional[List[str]] = None,
77
+
 
 
78
 
79
  # GPT-2 compatibility flags
80
  scale_attn_by_inverse_layer_idx: bool = False,
81
  reorder_and_upcast_attn: bool = False,
82
  add_cross_attention: bool = False,
83
  tie_word_embeddings: bool = True,
84
+ **kwargs
85
  ):
86
+ # Handle legacy config conversion
87
+ if any(k.startswith('gpt_') for k in kwargs):
88
+ kwargs = self._convert_legacy_config(kwargs)
89
+
90
+ if 'model_args' in kwargs:
91
+ kwargs = self._convert_legacy_config(kwargs['model_args'])
92
+
93
+ # Initialize audio config
94
+ if audio_config is None:
95
+ audio_config = XTTSAudioConfig()
96
+ elif isinstance(audio_config, dict):
97
+ audio_config = XTTSAudioConfig(**audio_config)
98
+
99
+ # Set default char limits
100
  if char_limits is None:
101
  char_limits = {
102
  "en": 250, "de": 253, "fr": 273, "es": 239,
 
105
  "tr": 226, "ja": 71, "hu": 224, "ko": 95,
106
  }
107
 
108
+ # Set default languages
109
  if languages is None:
110
  languages = [
111
  "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
112
  "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
113
  ]
114
 
 
 
 
115
  super().__init__(
116
+ pad_token_id=kwargs.pop('pad_token_id', None),
117
+ bos_token_id=kwargs.pop('bos_token_id', None),
118
+ eos_token_id=kwargs.pop('eos_token_id', None),
119
  **kwargs
120
  )
121
 
122
+ # Set all attributes
123
  self.vocab_size = vocab_size
124
  self.hidden_size = hidden_size
125
  self.num_hidden_layers = num_hidden_layers
 
132
  self.embd_pdrop = embd_pdrop
133
  self.attn_pdrop = attn_pdrop
134
 
135
+ # XTTS specific
136
  self.num_chars = num_chars
137
  self.batch_size = batch_size
138
  self.max_audio_tokens = max_audio_tokens
 
141
  self.number_text_tokens = number_text_tokens
142
  self.start_text_token = start_text_token
143
  self.stop_text_token = stop_text_token
 
144
  self.start_audio_token = start_audio_token
145
  self.stop_audio_token = stop_audio_token
146
  self.code_stride_len = code_stride_len
 
149
  self.checkpointing = checkpointing
150
  self.train_solo_embeddings = train_solo_embeddings
151
 
152
+ # Training
153
  self.enable_redaction = enable_redaction
154
  self.kv_cache = kv_cache
155
  self.perceiver_cond_length_compression = perceiver_cond_length_compression
156
  self.label_smoothing = label_smoothing
157
 
158
+ # Generation
 
 
 
 
 
159
  self.cond_len = cond_len
160
  self.cond_chunk_len = cond_chunk_len
161
  self.max_ref_len = max_ref_len
162
  self.sound_norm_refs = sound_norm_refs
163
 
164
+ # Audio and other
165
  self.audio_config = audio_config
 
 
166
  self.duration_const = duration_const
167
  self.char_limits = char_limits
168
  self.languages = languages
169
 
170
+ # GPT-2 flags
171
  self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
172
  self.reorder_and_upcast_attn = reorder_and_upcast_attn
173
  self.add_cross_attention = add_cross_attention
174
  self.tie_word_embeddings = tie_word_embeddings
175
 
176
+ @staticmethod
177
+ def _convert_legacy_config(config_dict: Dict) -> Dict:
178
+ """Converts legacy config format to new format."""
179
+ mapping = {
180
+ 'gpt_batch_size': 'batch_size',
181
+ 'gpt_max_audio_tokens': 'max_audio_tokens',
182
+ 'gpt_max_text_tokens': 'max_text_tokens',
183
+ 'gpt_max_prompt_tokens': 'max_prompt_tokens',
184
+ 'gpt_layers': 'num_hidden_layers',
185
+ 'gpt_n_model_channels': 'hidden_size',
186
+ 'gpt_n_heads': 'num_attention_heads',
187
+ 'gpt_number_text_tokens': 'number_text_tokens',
188
+ 'gpt_start_text_token': 'start_text_token',
189
+ 'gpt_stop_text_token': 'stop_text_token',
190
+ 'gpt_num_audio_tokens': 'vocab_size',
191
+ 'gpt_start_audio_token': 'start_audio_token',
192
+ 'gpt_stop_audio_token': 'stop_audio_token',
193
+ 'gpt_code_stride_len': 'code_stride_len',
194
+ 'gpt_use_masking_gt_prompt_approach': 'use_masking_gt_prompt_approach',
195
+ 'gpt_use_perceiver_resampler': 'use_perceiver_resampler',
196
+ 'gpt_checkpointing': 'checkpointing',
197
+ 'gpt_train_solo_embeddings': 'train_solo_embeddings',
198
+ 'gpt_cond_len': 'cond_len',
199
+ 'gpt_cond_chunk_len': 'cond_chunk_len'
200
+ }
201
+
202
+ new_config = {}
203
+
204
+ # Convert keys
205
+ for old_key, new_key in mapping.items():
206
+ if old_key in config_dict:
207
+ new_config[new_key] = config_dict[old_key]
208
+
209
+ # Copy non-mapped keys
210
+ for k, v in config_dict.items():
211
+ if not k.startswith('gpt_') and k not in new_config:
212
+ new_config[k] = v
213
+
214
+ # Handle audio config
215
+ if 'input_sample_rate' in config_dict or 'output_sample_rate' in config_dict:
216
+ audio_config = {
217
+ 'sample_rate': config_dict.get('input_sample_rate', 22050),
218
+ 'output_sample_rate': config_dict.get('output_sample_rate', 24000),
219
+ 'hop_length': config_dict.get('output_hop_length', 256)
220
+ }
221
+ new_config['audio_config'] = audio_config
222
+
223
+ return new_config
224
+
225
+ def to_dict(self) -> Dict:
226
  """Convert config to dictionary"""
227
  config_dict = super().to_dict()
228
  config_dict["audio_config"] = asdict(self.audio_config)
229
  return config_dict
230
 
231
  @classmethod
232
+ def from_dict(cls, config_dict: Dict, **kwargs) -> 'XTTSGPTConfig':
233
  """Create config from dictionary"""
234
+ if isinstance(config_dict.get("audio_config"), dict):
235
+ audio_config = XTTSAudioConfig(**config_dict["audio_config"])
236
+ config_dict["audio_config"] = audio_config
237
+ return cls(**config_dict, **kwargs)
238
 
239
  def update_with_tokenizer(self, tokenizer=None):
240
  """Update configuration values based on tokenizer"""