xcczach's picture
Upload 73 files
35c1cfd verified
raw
history blame
1.58 kB
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from fairseq.data import Dictionary
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
# Module-level logger obtained from the transformers logging utility and named
# after this module. It is not referenced in the visible part of this file.
logger = logging.get_logger(__name__)
class VallexConfig(PretrainedConfig):
    """Configuration object for the ``"vallex"`` model type.

    Each named constructor argument is stored unchanged on the instance under
    an attribute of the same name. Any additional keyword arguments are
    forwarded to ``PretrainedConfig.__init__`` after those attributes are set.

    NOTE(review): the meanings given below are inferred from parameter names
    and defaults only; this class does not interpret the values. Confirm
    against the model implementation that consumes this config.

    Args:
        n_layer: Defaults to 24. Presumably the number of layers.
        n_head: Defaults to 16. Presumably the number of attention heads.
        n_dim: Defaults to 1024. Presumably the model/embedding width.
        prefix_mode: Defaults to 1. Semantics not visible here.
        num_quantizers: Defaults to 8. Presumably the number of codec
            quantizer levels.
        sample_rate: Defaults to 24000. Presumably the audio rate in Hz.
        ar_at_dict: Defaults to "". Looks like a dictionary reference used by
            the AR stage -- TODO confirm format (path vs. identifier).
        ar_st_dict: Defaults to "". See ``ar_at_dict``.
        nar_at_dict: Defaults to "". Looks like the NAR-stage counterpart of
            ``ar_at_dict`` -- TODO confirm.
        nar_st_dict: Defaults to "". See ``nar_at_dict``.
        nar_scale_factor: Defaults to 1.0. Semantics not visible here.
        prepend_bos: Defaults to True. Presumably whether a BOS token is
            prepended.
        norm_first: Defaults to True. Presumably selects pre-norm layers.
        eps: Defaults to 0.0. Semantics not visible here.
        only_ar: Defaults to False. Presumably restricts the model to the AR
            stage.
        only_nar: Defaults to False. Presumably restricts the model to the
            NAR stage. No check against ``only_ar`` is performed here.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "vallex"

    def __init__(
        self,
        n_layer=24,
        n_head=16,
        n_dim=1024,
        prefix_mode=1,
        num_quantizers=8,
        sample_rate=24000,
        ar_at_dict="",
        ar_st_dict="",
        nar_at_dict="",
        nar_st_dict="",
        nar_scale_factor=1.0,
        prepend_bos=True,
        norm_first=True,
        eps=0.0,
        only_ar=False,
        only_nar=False,
        **kwargs,
    ):
        # Size-related hyperparameters.
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_dim = n_dim

        # Audio / codec related settings.
        self.prefix_mode = prefix_mode
        self.num_quantizers = num_quantizers
        self.sample_rate = sample_rate

        # Dictionary references for the AR and NAR stages.
        self.ar_at_dict = ar_at_dict
        self.ar_st_dict = ar_st_dict
        self.nar_at_dict = nar_at_dict
        self.nar_st_dict = nar_st_dict

        # Remaining scalar options and flags.
        self.nar_scale_factor = nar_scale_factor
        self.prepend_bos = prepend_bos
        self.norm_first = norm_first
        self.eps = eps
        self.only_ar = only_ar
        self.only_nar = only_nar

        # The base-class initializer runs last, as in the original code, and
        # receives only the extra keyword arguments.
        super().__init__(**kwargs)
# Import-time side effect: register VallexConfig with AutoConfig under the
# "vallex" model type.
AutoConfig.register(model_type="vallex", config=VallexConfig)