Spaces:
Running
Running
jason-on-salt-a40
commited on
Commit
·
579d79b
1
Parent(s):
78774ba
better hf integration
Browse files- app.py +1 -1
- models/voicecraft.py +37 -10
app.py
CHANGED
@@ -93,7 +93,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
|
|
93 |
transcribe_model = WhisperxModel(whisper_model_name, align_model)
|
94 |
|
95 |
voicecraft_name = f"{voicecraft_model_name}.pth"
|
96 |
-
model = voicecraft.
|
97 |
phn2num = model.args.phn2num
|
98 |
config = model.args
|
99 |
model.to(device)
|
|
|
93 |
transcribe_model = WhisperxModel(whisper_model_name, align_model)
|
94 |
|
95 |
voicecraft_name = f"{voicecraft_model_name}.pth"
|
96 |
+
model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
|
97 |
phn2num = model.args.phn2num
|
98 |
config = model.args
|
99 |
model.to(device)
|
models/voicecraft.py
CHANGED
@@ -3,6 +3,7 @@ import random
|
|
3 |
import numpy as np
|
4 |
import logging
|
5 |
import argparse, copy
|
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
@@ -17,8 +18,11 @@ from .modules.transformer import (
|
|
17 |
TransformerEncoderLayer,
|
18 |
)
|
19 |
from .codebooks_patterns import DelayedPatternProvider
|
20 |
-
|
21 |
from argparse import Namespace
|
|
|
|
|
|
|
22 |
def top_k_top_p_filtering(
|
23 |
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
24 |
):
|
@@ -83,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
|
83 |
|
84 |
|
85 |
|
86 |
-
class VoiceCraft(
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
self.args = copy.copy(args)
|
90 |
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
|
91 |
if not getattr(self.args, "special_first", False):
|
@@ -97,7 +123,7 @@ class VoiceCraft(nn.Module):
|
|
97 |
if self.args.eos > 0:
|
98 |
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
|
99 |
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
|
100 |
-
if
|
101 |
self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
|
102 |
|
103 |
self.n_text_tokens = self.args.text_vocab_size + 1
|
@@ -410,6 +436,10 @@ class VoiceCraft(nn.Module):
|
|
410 |
.expand(-1, self.args.nhead, -1, -1)
|
411 |
.reshape(bsz * self.args.nhead, 1, src_len)
|
412 |
)
|
|
|
|
|
|
|
|
|
413 |
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
414 |
|
415 |
new_attn_mask = torch.zeros_like(xy_attn_mask)
|
@@ -455,8 +485,10 @@ class VoiceCraft(nn.Module):
|
|
455 |
before padding.
|
456 |
"""
|
457 |
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
|
|
|
|
458 |
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
|
459 |
-
y = y[:, :y_lens.max()]
|
460 |
assert x.ndim == 2, x.shape
|
461 |
assert x_lens.ndim == 1, x_lens.shape
|
462 |
assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
|
@@ -1405,8 +1437,3 @@ class VoiceCraft(nn.Module):
|
|
1405 |
flatten_gen = flatten_gen - int(self.args.n_special)
|
1406 |
|
1407 |
return res, flatten_gen[0].unsqueeze(0)
|
1408 |
-
|
1409 |
-
class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]):
|
1410 |
-
def __init__(self, config: dict):
|
1411 |
-
args = Namespace(**config)
|
1412 |
-
super().__init__(args)
|
|
|
3 |
import numpy as np
|
4 |
import logging
|
5 |
import argparse, copy
|
6 |
+
from typing import Dict, Optional
|
7 |
import torch
|
8 |
import torch.nn as nn
|
9 |
import torch.nn.functional as F
|
|
|
18 |
TransformerEncoderLayer,
|
19 |
)
|
20 |
from .codebooks_patterns import DelayedPatternProvider
|
21 |
+
|
22 |
from argparse import Namespace
|
23 |
+
from huggingface_hub import PyTorchModelHubMixin
|
24 |
+
|
25 |
+
|
26 |
def top_k_top_p_filtering(
|
27 |
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
28 |
):
|
|
|
87 |
|
88 |
|
89 |
|
90 |
+
class VoiceCraft(
|
91 |
+
nn.Module,
|
92 |
+
PyTorchModelHubMixin,
|
93 |
+
library_name="voicecraft",
|
94 |
+
repo_url="https://github.com/jasonppy/VoiceCraft",
|
95 |
+
tags=["text-to-speech"],
|
96 |
+
):
|
97 |
+
def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
|
98 |
+
# If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
|
99 |
+
# Won't affect instance initialization
|
100 |
+
if args is not None:
|
101 |
+
if config is not None:
|
102 |
+
raise ValueError("Cannot provide both `args` and `config`.")
|
103 |
+
config = vars(args)
|
104 |
+
return super().__new__(cls, args=args, config=config, **kwargs)
|
105 |
+
|
106 |
+
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
|
107 |
super().__init__()
|
108 |
+
|
109 |
+
# If loaded from HF Hub => convert config.json to Namespace args before initializing
|
110 |
+
if args is None:
|
111 |
+
if config is None:
|
112 |
+
raise ValueError("Either `args` or `config` must be provided.")
|
113 |
+
args = Namespace(**config)
|
114 |
+
|
115 |
self.args = copy.copy(args)
|
116 |
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
|
117 |
if not getattr(self.args, "special_first", False):
|
|
|
123 |
if self.args.eos > 0:
|
124 |
assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
|
125 |
self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
|
126 |
+
if isinstance(self.args.audio_vocab_size, str):
|
127 |
self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
|
128 |
|
129 |
self.n_text_tokens = self.args.text_vocab_size + 1
|
|
|
436 |
.expand(-1, self.args.nhead, -1, -1)
|
437 |
.reshape(bsz * self.args.nhead, 1, src_len)
|
438 |
)
|
439 |
+
# Check shapes and resize+broadcast as necessary
|
440 |
+
if xy_attn_mask.shape != _xy_padding_mask.shape:
|
441 |
+
assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim, f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
|
442 |
+
xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1) # Example approach
|
443 |
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
444 |
|
445 |
new_attn_mask = torch.zeros_like(xy_attn_mask)
|
|
|
485 |
before padding.
|
486 |
"""
|
487 |
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
488 |
+
if len(x) == 0:
|
489 |
+
return None
|
490 |
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
|
491 |
+
y = y[:, :, :y_lens.max()]
|
492 |
assert x.ndim == 2, x.shape
|
493 |
assert x_lens.ndim == 1, x_lens.shape
|
494 |
assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
|
|
|
1437 |
flatten_gen = flatten_gen - int(self.args.n_special)
|
1438 |
|
1439 |
return res, flatten_gen[0].unsqueeze(0)
|
|
|
|
|
|
|
|
|
|