Commit: test

Changed files:
- app.py +1 -1
- diffrhythm/g2p/g2p/mandarin.py +4 -1
- diffrhythm/model/cfm.py +16 -13
- diffrhythm/model/dit.py +23 -27
app.py CHANGED

@@ -13,7 +13,7 @@ from tqdm import tqdm
 import random
 import numpy as np
 import sys
-from diffrhythm.infer.infer_utils import (
+from huggface_diffrhythm.space.DiffRhythm.diffrhythm.infer.infer_utils import (
     get_reference_latent,
     get_lrc_token,
     get_style_prompt,
diffrhythm/g2p/g2p/mandarin.py CHANGED

@@ -187,7 +187,10 @@ with open(
 ) as fread:
     txt_list = fread.readlines()
     for txt in txt_list:
-        word, pinyin = txt.strip().split("\t")
+        try:
+            word, pinyin = txt.strip().split("\t")
+        except:
+            print(txt.strip())
         word_pinyin_dict[word] = pinyin
 fread.close()
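Below is a minimal, self-contained sketch of the guarded lexicon parsing this change introduces; the file name pinyin_dict.txt is illustrative, not the repo's actual path. Adding continue after a failed split avoids falling through to the dictionary update with word/pinyin left over from a previous line, which the bare except in the patch would otherwise allow:

word_pinyin_dict = {}

# Parse a tab-separated "word<TAB>pinyin" lexicon, skipping malformed lines.
with open("pinyin_dict.txt", encoding="utf-8") as fread:
    for txt in fread.readlines():
        try:
            word, pinyin = txt.strip().split("\t")
        except ValueError:
            # Line does not contain exactly one tab: report it and move on.
            print(f"skipping malformed lexicon line: {txt.strip()!r}")
            continue
        word_pinyin_dict[word] = pinyin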
diffrhythm/model/cfm.py CHANGED

@@ -193,25 +193,28 @@ class CFM(nn.Module):
         # test for no ref audio
         if no_ref_audio:
             cond = torch.zeros_like(cond)
+
+        start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
+        _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
+
+        text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
+        text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
+        step_cond = torch.cat([step_cond, step_cond], 0)
+        style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
+        start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
 
 
         def fn(t, x):
-            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
-            # predict flow
+            x = torch.cat([x, x], 0)
             pred = self.transformer(
-                x=x,
+                x=x, text_embed=text_embed, text_residuals=text_residuals, cond=step_cond, time=t,
+                drop_audio_cond=True, drop_prompt=False, style_prompt=style_prompt, start_time=start_time_embed
             )
-            if cfg_strength < 1e-5:
-                return pred
 
-            return pred + (pred - null_pred) * cfg_strength
+            positive_pred, negative_pred = pred.chunk(2, 0)
+            cfg_pred = positive_pred + (positive_pred - negative_pred) * cfg_strength
+
+            return cfg_pred
 
         # noise input
         # to make sure batch inference result is same with different batch size, and for sure single inference
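The cfm.py change folds classifier-free guidance into a single transformer call per ODE step: the conditional and unconditional inputs are concatenated along the batch dimension, the network runs once, and the two halves are recombined with cfg_strength. Here is a minimal standalone sketch of that pattern; TinyNet, the tensor shapes, and the names below are illustrative only, not DiffRhythm's API:

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Stand-in for the flow-prediction transformer: maps (x, cond) -> velocity.
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim * 2, dim)

    def forward(self, x, cond):
        return self.proj(torch.cat([x, cond], dim=-1))

def cfg_step(net, x, pos_cond, neg_cond, cfg_strength=4.0):
    # Batch the positive and negative conditions so CFG costs one forward pass.
    x2 = torch.cat([x, x], dim=0)
    cond2 = torch.cat([pos_cond, neg_cond], dim=0)
    pred = net(x2, cond2)
    positive_pred, negative_pred = pred.chunk(2, dim=0)
    return positive_pred + (positive_pred - negative_pred) * cfg_strength

net = TinyNet()
x = torch.randn(2, 16, 8)    # noised latent [batch, frames, dim]
pos = torch.randn(2, 16, 8)  # conditional embedding
neg = torch.zeros_like(pos)  # "dropped" / negative embedding
v = cfg_step(net, x, pos, neg)
print(v.shape)               # torch.Size([2, 16, 8])

Precomputing the timestep-invariant text embeddings once (via forward_timestep_invariant) and concatenating them up front means only the x-dependent work is repeated at every solver step.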
diffrhythm/model/dit.py CHANGED

@@ -15,7 +15,7 @@ import torch
 import torch.nn.functional as F
 
 from x_transformers.x_transformers import RotaryEmbedding
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
 from transformers.models.llama import LlamaConfig
 from torch.utils.checkpoint import checkpoint
 
@@ -28,7 +28,8 @@ from diffrhythm.model.modules import (
     precompute_freqs_cis,
     get_pos_embed_indices,
 )
+from liger_kernel.transformers import apply_liger_kernel_to_llama
+apply_liger_kernel_to_llama()
 
 # Text embedding
 
@@ -134,9 +135,11 @@ class DiT(nn.Module):
         #)
         llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
         llama_config._attn_implementation = 'sdpa'
+        #llama_config._attn_implementation = ''
         self.transformer_blocks = nn.ModuleList(
             [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
         )
+        self.rotary_emb = LlamaRotaryEmbedding(config=llama_config)
         self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
 
         self.text_fusion_linears = nn.ModuleList(
@@ -157,60 +160,53 @@ class DiT(nn.Module):
         # if use_style_prompt:
         # self.prompt_rnn = nn.LSTM(64, cond_dim, 1, batch_first=True)
 
+    def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
+        s_t = self.start_time_embed(start_time)
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        text_residuals = []
+        for layer in self.text_fusion_linears:
+            text_residual = layer(text_embed)
+            text_residuals.append(text_residual)
+        return s_t, text_embed, text_residuals
+
 
     def forward(
         self,
        x: float["b n d"],  # nosied input audio  # noqa: F722
+        text_embed: int["b nt"],  # text  # noqa: F722
+        text_residuals,
         cond: float["b n d"],  # masked cond audio  # noqa: F722
-        text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         drop_audio_cond,  # cfg for cond audio
-        drop_text,  # cfg for text
         drop_prompt=False,
         style_prompt=None,  # [b d t]
-        style_prompt_lens=None,
-        mask: bool["b n"] | None = None,  # noqa: F722
-        grad_ckpt=False,
         start_time=None,
     ):
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
             time = time.repeat(batch)
 
-        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        c = t + s_t
-        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        c = t + start_time
 
-        # import pdb; pdb.set_trace()
         if drop_prompt:
             style_prompt = torch.zeros_like(style_prompt)
-        # if self.training:
-        #     packed_style_prompt = torch.nn.utils.rnn.pack_padded_sequence(style_prompt.transpose(1, 2), style_prompt_lens.cpu(), batch_first=True, enforce_sorted=False)
-        # else:
-        #     packed_style_prompt = style_prompt.transpose(1, 2)
-        #print(packed_style_prompt.shape)
-        # _, style_emb = self.prompt_rnn.forward(packed_style_prompt)
-        # _, (h_n, c_n) = self.prompt_rnn.forward(packed_style_prompt)
-        # style_emb = h_n.squeeze(0) # 1, B, dim -> B, dim
 
+        style_embed = style_prompt # [b, 512]
 
-        x = self.input_embed(x, cond, text_embed,
+        x = self.input_embed(x, cond, text_embed, style_embed, c, drop_audio_cond=drop_audio_cond)
 
         if self.long_skip_connection is not None:
             residual = x
 
         pos_ids = torch.arange(x.shape[1], device=x.device)
         pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
+        rotary_embed = self.rotary_emb(x, pos_ids)
+
         for i, block in enumerate(self.transformer_blocks):
-                x, *_ = block(x, position_ids=pos_ids)
-            else:
-                x, *_ = checkpoint(block, x, position_ids=pos_ids, use_reentrant=False)
+            x, *_ = block(x, position_embeddings=rotary_embed)
             if i < self.depth // 2:
-                x = x +
+                x = x + text_residuals[i]
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
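The dit.py change computes rotary position embeddings once with LlamaRotaryEmbedding and passes the same (cos, sin) pair to every LlamaDecoderLayer via position_embeddings, instead of handing each block position_ids. A rough usage sketch of that pattern, assuming a recent transformers release in which LlamaRotaryEmbedding accepts a config and LlamaDecoderLayer accepts precomputed position_embeddings; the dimensions below are illustrative:

import torch
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding

dim, depth = 512, 4
config = LlamaConfig(hidden_size=dim, intermediate_size=dim * 4,
                     num_attention_heads=8, hidden_act="silu")
config._attn_implementation = "sdpa"

blocks = torch.nn.ModuleList(LlamaDecoderLayer(config, layer_idx=i) for i in range(depth))
rotary = LlamaRotaryEmbedding(config=config)

x = torch.randn(2, 32, dim)                           # [batch, seq, dim]
pos_ids = torch.arange(32).unsqueeze(0).repeat(2, 1)  # [batch, seq]
position_embeddings = rotary(x, pos_ids)              # (cos, sin), shared by all layers

for block in blocks:
    out = block(x, position_embeddings=position_embeddings)
    x = out[0] if isinstance(out, tuple) else out     # some versions return a tuple
print(x.shape)

The new liger_kernel import patches the Llama modules in place: apply_liger_kernel_to_llama() swaps in Liger's fused GPU kernels (RMSNorm, RoPE, SwiGLU, etc.), and it requires the liger-kernel package and a CUDA-capable environment.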