ing0 committed
Commit 97f0d6e · 1 Parent(s): ea04001
app.py CHANGED
@@ -13,7 +13,7 @@ from tqdm import tqdm
 import random
 import numpy as np
 import sys
-from diffrhythm.infer.infer_utils import (
+from huggface_diffrhythm.space.DiffRhythm.diffrhythm.infer.infer_utils import (
     get_reference_latent,
     get_lrc_token,
     get_style_prompt,
diffrhythm/g2p/g2p/mandarin.py CHANGED
@@ -187,7 +187,10 @@ with open(
 ) as fread:
     txt_list = fread.readlines()
     for txt in txt_list:
-        word, pinyin = txt.strip().split("\t")
+        try:
+            word, pinyin = txt.strip().split("\t")
+        except:
+            print(txt.strip())
         word_pinyin_dict[word] = pinyin
 fread.close()

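Note on the hunk above: the bare except prints the malformed line, but execution then still falls through to word_pinyin_dict[word] = pinyin with word/pinyin left over from the previous iteration. Below is a minimal sketch of a stricter variant of the same pattern; the "pinyin_table.txt" path is a placeholder, not the repo's actual file, and only ValueError is caught so the bad line is skipped rather than re-inserted.

# Sketch only: a more defensive version of the dictionary load above.
word_pinyin_dict = {}
with open("pinyin_table.txt", encoding="utf-8") as fread:  # placeholder path
    for txt in fread:
        try:
            # each line is expected to be "<word>\t<pinyin>"
            word, pinyin = txt.strip().split("\t")
        except ValueError:
            # malformed line: report it and skip instead of reusing stale values
            print("skipping malformed line:", txt.strip())
            continue
        word_pinyin_dict[word] = pinyin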
diffrhythm/model/cfm.py CHANGED
@@ -193,25 +193,28 @@ class CFM(nn.Module):
         # test for no ref audio
         if no_ref_audio:
             cond = torch.zeros_like(cond)
+
+        start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
+        _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
+
+        text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
+        text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
+        step_cond = torch.cat([step_cond, step_cond], 0)
+        style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
+        start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)


         def fn(t, x):
-            # at each step, conditioning is fixed
-            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
-
-            # predict flow
+            x = torch.cat([x, x], 0)
             pred = self.transformer(
-                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, drop_prompt=False,
-                style_prompt=style_prompt, style_prompt_lens=style_prompt_lens, start_time=start_time
+                x=x, text_embed=text_embed, text_residuals=text_residuals, cond=step_cond, time=t,
+                drop_audio_cond=True, drop_prompt=False, style_prompt=style_prompt, start_time=start_time_embed
             )
-            if cfg_strength < 1e-5:
-                return pred

-            null_pred = self.transformer(
-                x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, drop_prompt=False,
-                style_prompt=negative_style_prompt, style_prompt_lens=style_prompt_lens, start_time=start_time
-            )
-            return pred + (pred - null_pred) * cfg_strength
+            positive_pred, negative_pred = pred.chunk(2, 0)
+            cfg_pred = positive_pred + (positive_pred - negative_pred) * cfg_strength
+
+            return cfg_pred

         # noise input
         # to make sure batch inference result is same with different batch size, and for sure single inference
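What the cfm.py hunk does: instead of calling the transformer twice per ODE step (once with the positive text/style conditioning, once with the dropped/negative conditioning), the two branches are concatenated along the batch dimension, one forward pass produces both predictions, and chunk(2, 0) splits them back out for the classifier-free-guidance extrapolation. A minimal sketch of the batching trick follows; guided_step and the toy lambda model are illustrative stand-ins, not repo code.

import torch

def guided_step(model, x, pos_cond, neg_cond, cfg_strength):
    # Duplicate the input and stack positive/negative conditioning on the batch
    # axis so a single forward pass covers both CFG branches.
    x2 = torch.cat([x, x], dim=0)
    cond2 = torch.cat([pos_cond, neg_cond], dim=0)
    pred = model(x2, cond2)
    positive_pred, negative_pred = pred.chunk(2, dim=0)
    # Classifier-free guidance: extrapolate away from the unconditional prediction.
    return positive_pred + (positive_pred - negative_pred) * cfg_strength

# Toy usage: a stand-in "model" that just adds its conditioning to the input.
model = lambda x, c: x + c
x = torch.randn(2, 8)
out = guided_step(model, x, pos_cond=torch.randn(2, 8), neg_cond=torch.zeros(2, 8), cfg_strength=2.0)
print(out.shape)  # torch.Size([2, 8])

One behavioral difference worth noting from the diff itself: the old positive branch ran with drop_audio_cond=False and there was an early return when cfg_strength < 1e-5; the batched version drops the audio condition for both halves and always applies the guidance formula.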
diffrhythm/model/dit.py CHANGED
@@ -15,7 +15,7 @@ import torch
 import torch.nn.functional as F

 from x_transformers.x_transformers import RotaryEmbedding
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
 from transformers.models.llama import LlamaConfig
 from torch.utils.checkpoint import checkpoint

@@ -28,7 +28,8 @@ from diffrhythm.model.modules import (
     precompute_freqs_cis,
     get_pos_embed_indices,
 )
-
+from liger_kernel.transformers import apply_liger_kernel_to_llama
+apply_liger_kernel_to_llama()

 # Text embedding

@@ -134,9 +135,11 @@
         #)
         llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
         llama_config._attn_implementation = 'sdpa'
+        #llama_config._attn_implementation = ''
         self.transformer_blocks = nn.ModuleList(
             [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
         )
+        self.rotary_emb = LlamaRotaryEmbedding(config=llama_config)
         self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

         self.text_fusion_linears = nn.ModuleList(
@@ -157,60 +160,53 @@
         # if use_style_prompt:
         #     self.prompt_rnn = nn.LSTM(64, cond_dim, 1, batch_first=True)

+    def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
+        s_t = self.start_time_embed(start_time)
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        text_residuals = []
+        for layer in self.text_fusion_linears:
+            text_residual = layer(text_embed)
+            text_residuals.append(text_residual)
+        return s_t, text_embed, text_residuals
+
     def forward(
         self,
         x: float["b n d"],  # nosied input audio  # noqa: F722
+        text_embed: int["b nt"],  # text  # noqa: F722
+        text_residuals,
         cond: float["b n d"],  # masked cond audio  # noqa: F722
-        text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         drop_audio_cond,  # cfg for cond audio
-        drop_text,  # cfg for text
        drop_prompt=False,
         style_prompt=None,  # [b d t]
-        style_prompt_lens=None,
-        mask: bool["b n"] | None = None,  # noqa: F722
-        grad_ckpt=False,
         start_time=None,
     ):
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
             time = time.repeat(batch)

-        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        s_t = self.start_time_embed(start_time)
-        c = t + s_t
-        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        c = t + start_time

-        # import pdb; pdb.set_trace()
         if drop_prompt:
             style_prompt = torch.zeros_like(style_prompt)
-        # if self.training:
-        #     packed_style_prompt = torch.nn.utils.rnn.pack_padded_sequence(style_prompt.transpose(1, 2), style_prompt_lens.cpu(), batch_first=True, enforce_sorted=False)
-        # else:
-        #     packed_style_prompt = style_prompt.transpose(1, 2)
-        #print(packed_style_prompt.shape)
-        # _, style_emb = self.prompt_rnn.forward(packed_style_prompt)
-        # _, (h_n, c_n) = self.prompt_rnn.forward(packed_style_prompt)
-        # style_emb = h_n.squeeze(0) # 1, B, dim -> B, dim

-        style_emb = style_prompt  # [b, 512]
+        style_embed = style_prompt  # [b, 512]

-        x = self.input_embed(x, cond, text_embed, style_emb, c, drop_audio_cond=drop_audio_cond)
+        x = self.input_embed(x, cond, text_embed, style_embed, c, drop_audio_cond=drop_audio_cond)

         if self.long_skip_connection is not None:
             residual = x

         pos_ids = torch.arange(x.shape[1], device=x.device)
         pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
+        rotary_embed = self.rotary_emb(x, pos_ids)
+
         for i, block in enumerate(self.transformer_blocks):
-            if not grad_ckpt:
-                x, *_ = block(x, position_ids=pos_ids)
-            else:
-                x, *_ = checkpoint(block, x, position_ids=pos_ids, use_reentrant=False)
+            x, *_ = block(x, position_embeddings=rotary_embed)
             if i < self.depth // 2:
-                x = x + self.text_fusion_linears[i](text_embed)
+                x = x + text_residuals[i]

         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
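The dit.py changes split the DiT forward pass into a timestep-invariant part and a per-step part: forward_timestep_invariant() computes the start-time embedding, the text embedding, and the per-layer text-fusion residuals once, while forward() now receives text_embed/text_residuals plus a precomputed start_time embedding, applies rotary position embeddings from a shared LlamaRotaryEmbedding via position_embeddings=, and drops the grad_ckpt/mask/style_prompt_lens arguments. Below is a minimal sketch of the caching idea only; TinyDiT and all of its names are illustrative, not the repo's modules.

import torch
import torch.nn as nn

class TinyDiT(nn.Module):
    # Illustrative module: shows hoisting timestep-invariant work out of the
    # ODE-solver loop, mirroring forward_timestep_invariant() in the hunk above.
    def __init__(self, dim=32, depth=4, vocab=100):
        super().__init__()
        self.text_embed = nn.Embedding(vocab, dim)
        self.text_fusion_linears = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth // 2)])
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward_timestep_invariant(self, text):
        # The text embedding and its fusion residuals do not depend on the ODE
        # time t, so compute them once per sample rather than once per step.
        text_embed = self.text_embed(text)
        residuals = [layer(text_embed) for layer in self.text_fusion_linears]
        return text_embed, residuals

    def forward(self, x, text_residuals):
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i < len(self.text_fusion_linears):
                x = x + text_residuals[i]  # reuse the cached residuals every step
        return x

model = TinyDiT()
text = torch.randint(0, 100, (1, 16))
_, residuals = model.forward_timestep_invariant(text)
x = torch.randn(1, 16, 32)
for _ in range(8):  # stand-in for the ODE solver steps
    x = model(x, residuals)
print(x.shape)  # torch.Size([1, 16, 32])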