|
import logging |
|
import sys |
|
import time |
|
from dataclasses import field |
|
from pathlib import Path |
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Pre-training the library's Flax Wav2Vec2 models on speech data with the contrastive pre-training objective.
|
""" |
|
|
|
|
|
import numpy as np |
|
from datasets import DatasetDict, load_dataset |
|
from tqdm import tqdm |
|
|
|
import flax |
|
import jax |
|
import jax.numpy as jnp |
|
import librosa |
|
import optax |
|
from flax import jax_utils, traverse_util |
|
from flax.training import train_state |
|
from flax.training.common_utils import get_metrics, onehot, shard |
|
from transformers import ( |
|
FlaxWav2Vec2ForPreTraining, |
|
HfArgumentParser, |
|
TrainingArguments, |
|
Wav2Vec2Config, |
|
Wav2Vec2FeatureExtractor, |
|
is_tensorboard_available, |
|
) |
|
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices, _sample_negative_indices |
|
|
|
from normalizer import normalizer |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/feature extractor we are going to pre-train,
    plus model-level options (logging verbosity, gumbel-softmax temperature schedule, dtype).

    Each field's ``help`` metadata is the user-facing description of the option.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    freeze_feature_extractor: Optional[bool] = field(
        default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    # The help text used to be a copy of the `freeze_feature_extractor` one, which
    # described a different option altogether.
    gradient_checkpointing: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of a slower backward pass."
        },
    )
    verbose_logging: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to log verbose messages or not."},
    )
    max_gumbel_temperature: Optional[float] = field(
        default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
    )
    min_gumbel_temperature: Optional[float] = field(
        default=0.1, metadata={"help": "Minimum temperature for gumbel softmax."}
    )
    gumbel_temperature_decay: Optional[float] = field(
        default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
    )
    # Resolved by name on `jax.numpy` when the model is instantiated.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
|
|
|
|
|
@flax.struct.dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    # Annotated as Optional because the default is None (a local `train_file`
    # can be supplied instead of a hub dataset).
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
    )
    speech_file_column: Optional[str] = field(
        default="file",
        metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_duration_in_seconds: Optional[float] = field(
        default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
    )
    pad_to_multiple_of: Optional[int] = field(
        default=1024,
        metadata={
            "help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
        },
    )
|
|
|
|
|
@flax.struct.dataclass
class FlaxDataCollatorForWav2Vec2Pretraining:
    """
    Collate raw features into a padded batch and attach the index arrays that the
    pre-training loss needs (masked time steps and sampled negatives).

    Args:
        model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
            Model being pre-trained. Its ``config`` and ``_get_feat_extract_output_lengths``
            are used to size the mask and the negative samples.
        feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            Feature extractor whose ``pad`` method pads the incoming features.
        padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`"longest"`):
            Padding strategy, forwarded unchanged to ``feature_extractor.pad``.
        pad_to_multiple_of (:obj:`int`, `optional`):
            Forwarded unchanged to ``feature_extractor.pad``.
        max_length (:obj:`int`, `optional`):
            Forwarded unchanged to ``feature_extractor.pad``.
    """

    model: FlaxWav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        """Pad ``features`` and add ``mask_time_indices`` / ``sampled_negative_indices`` to the batch."""
        model_config = self.model.config

        # Pad everything to a common length and get NumPy arrays back.
        batch = self.feature_extractor.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="np",
        )

        input_shape = batch["input_values"].shape
        num_examples = input_shape[0]
        # Length of the sequence after the model's feature encoder, derived from the padded input length.
        encoded_length = self.model._get_feat_extract_output_lengths(input_shape[-1])

        # Choose which encoded time steps are masked for each example.
        mask_shape = (num_examples, encoded_length)
        batch["mask_time_indices"] = _compute_mask_indices(
            mask_shape,
            model_config.mask_time_prob,
            model_config.mask_time_length,
            min_masks=2,
        )

        # Sample the negative indices used by the contrastive loss.
        negatives_shape = batch["mask_time_indices"].shape + (model_config.proj_codevector_dim,)
        batch["sampled_negative_indices"] = _sample_negative_indices(
            negatives_shape,
            model_config.num_negatives,
        )

        return batch
|
|
|
|
|
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
    """
    Set up stdout logging and choose this module's log level.

    The module logger is set to DEBUG when ``model_args.verbose_logging`` is truthy
    and to WARNING otherwise. ``training_args`` is accepted but not read here.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[stdout_handler],
        datefmt="%m/%d/%Y %H:%M:%S",
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    )
    logger.setLevel(logging.DEBUG if model_args.verbose_logging else logging.WARNING)
|
|
|
|
|
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Write the elapsed training time and the buffered training metrics to ``summary_writer``.

    ``train_metrics`` is the list of per-step metrics collected since the last flush.
    After stacking with ``get_metrics``, the values of each metric are written at
    consecutive steps, the last of which is ``step``.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for name, values in stacked_metrics.items():
        tag = f"train_{name}"
        # The last buffered value belongs to `step`; earlier ones to the steps just before it.
        first_step = step - len(values) + 1
        for offset, value in enumerate(values):
            summary_writer.scalar(tag, value, first_step + offset)
|
|
|
|
|
def write_eval_metric(summary_writer, eval_metrics, step):
    """Write every entry of ``eval_metrics`` to ``summary_writer`` at ``step`` under an ``eval_`` prefixed tag."""
    for name, metric_value in eval_metrics.items():
        tag = f"eval_{name}"
        summary_writer.scalar(tag, metric_value, step)
|
|
|
|
|
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> List[jnp.ndarray]:
    """
    Split sample indices into consecutive batches of exactly ``batch_size`` indices.

    Trailing indices that do not fill a complete batch are dropped.

    Args:
        samples_idx: 1-D array of sample indices.
        batch_size: Number of indices per batch; must be positive.

    Returns:
        A list with ``len(samples_idx) // batch_size`` index arrays. The list is empty
        when there are fewer than ``batch_size`` indices.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")

    num_batches = len(samples_idx) // batch_size
    if num_batches == 0:
        # `np.split` cannot split into zero sections; there is simply no full batch.
        return []

    # Keep only the indices that fill complete batches.
    usable_idx = samples_idx[: num_batches * batch_size]
    return np.split(usable_idx, num_batches)
|
|
|
|
|
def compute_contrastive_loss(
    quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
    """
    Compute the summed contrastive loss between transformer outputs and quantized targets.

    For every position the candidate set is the position's own quantized vector
    (candidate 0, the positive) followed by ``num_negatives`` quantized vectors gathered
    through ``negative_indices``. Cosine similarities to ``transformer_features`` are
    scaled by ``1 / logits_temp`` and scored with a softmax cross-entropy; only positions
    where ``mask_time_indices`` equals 1 contribute to the returned sum.

    ``quantized_features`` is unpacked as (batch, sequence, hidden).
    NOTE(review): the shapes of the other array arguments are not checked here -
    confirm against the caller.
    """
    batch_size, sequence_length, hidden_size = quantized_features.shape

    # Gather the negatives from the flattened (batch * sequence, hidden) view and
    # move the negatives axis to the front.
    flat_quantized = quantized_features.reshape(-1, hidden_size)
    gathered = flat_quantized[negative_indices.reshape(-1)]
    quantized_negatives = gathered.reshape(batch_size, sequence_length, num_negatives, hidden_size)
    quantized_negatives = quantized_negatives.transpose(2, 0, 1, 3)

    # Candidate 0 is the positive, candidates 1.. are the negatives.
    positives = quantized_features[None, :]
    candidates = jnp.concatenate([positives, quantized_negatives], axis=0)
    logits = optax.cosine_similarity(transformer_features, candidates) / logits_temp

    # A sampled negative that is identical to the positive must not be scored:
    # push its logit to a very large negative value.
    negative_equals_positive = (quantized_features == quantized_negatives).all(-1)
    positive_row = jnp.full((1,) + logits.shape[1:], False)
    invalid = jnp.concatenate([positive_row, negative_equals_positive], axis=0)
    logits = jnp.where(invalid, -1e9, logits)

    num_candidates = logits.shape[0]
    predictions = logits.transpose(2, 1, 0).reshape(-1, num_candidates)

    # Target is 0 (the positive) where the mask is 1 and -100 elsewhere;
    # the -100 entries are zeroed out of the loss by `target_mask`.
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
    target_mask = jnp.where(targets >= 0, 1.0, 0.0)

    per_position_loss = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1]))
    return (per_position_loss * target_mask).sum()
|
|
|
|
|
def _load_raw_datasets(model_args, data_args):
    """Return a dataset dict with a "train" split (and "validation" when available) from the hub or local files."""
    if data_args.dataset_name:
        # Load once without a split only to discover which splits exist.
        available_splits = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )

        if data_args.validation_split_name in available_splits:
            validation_split = data_args.validation_split_name
            train_split = f"{data_args.train_split_name}"
        else:
            # No validation split: carve the first N% out of the training split.
            percentage = data_args.validation_split_percentage
            validation_split = f"{data_args.train_split_name}[:{percentage}%]"
            train_split = f"{data_args.train_split_name}[{percentage}%:]"

        # Rebuild the dict so that only the "validation" and "train" keys remain.
        raw_datasets = DatasetDict()
        raw_datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=validation_split,
            cache_dir=model_args.cache_dir,
        )
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=train_split,
            cache_dir=model_args.cache_dir,
        )
        return raw_datasets

    if data_args.train_file is None:
        raise ValueError("Either `dataset_name` or `train_file` must be provided.")

    data_files = {"train": data_args.train_file}
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
    # The file extension selects the loader.
    extension = data_args.train_file.split(".")[-1]
    return load_dataset(extension, data_files=data_files, delimiter="\t")


def _create_summary_writer(output_dir):
    """Return a TensorBoard ``SummaryWriter`` on process 0, or ``None`` when metrics cannot or should not be written."""
    if not is_tensorboard_available():
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
        return None

    if jax.process_index() != 0:
        # Only the first process writes metrics; the others stay silent.
        return None

    try:
        from flax.metrics.tensorboard import SummaryWriter

        return SummaryWriter(log_dir=Path(output_dir))
    except ImportError as ie:
        logger.warning(
            f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
        )
        return None


def _create_learning_rate_schedule(learning_rate, warmup_steps, num_train_steps):
    """Return a schedule with linear warmup to ``learning_rate`` followed by linear decay to 0."""
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=num_train_steps - warmup_steps,
    )
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])


def _decay_mask_fn(params):
    """Return a pytree of booleans: ``False`` for biases and layer-norm scales (no weight decay), ``True`` otherwise."""
    no_decay_suffixes = [("layer_norm", "scale"), ("final_layer_norm", "scale")]
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in no_decay_suffixes) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)


def main():
    """
    Pre-train a Flax Wav2Vec2 model.

    Parses the command line into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``, loads and preprocesses the audio data, builds the model and
    the AdamW optimizer, then for every epoch runs the data-parallel training loop,
    evaluates on the validation split and saves a checkpoint to
    ``training_args.output_dir``.

    Raises:
        ValueError: If no data source is given, if the model configuration is not
            supported for pre-training, or if the training set holds fewer samples
            than one global batch.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # ---- Data ----------------------------------------------------------------
    raw_datasets = _load_raw_datasets(model_args, data_args)

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
    )

    def prepare_dataset(batch):
        # Load the audio file, resampled to the feature extractor's sampling rate.
        batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
        return batch

    vectorized_datasets = raw_datasets.map(
        prepare_dataset,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=raw_datasets["train"].column_names,
    )

    # Drop audio clips that are too long.
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    )

    def normalize(batch):
        return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)

    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # ---- Model ---------------------------------------------------------------
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=model_args.gradient_checkpointing,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
        )

    model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
    )

    # `summary_writer` is None unless this is process 0 and TensorBoard is usable.
    summary_writer = _create_summary_writer(training_args.output_dir)

    # ---- Randomness ----------------------------------------------------------
    # Derive independent keys for the dropout and gumbel streams. Splitting the same
    # key twice would hand both streams identical keys.
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_key, gumbel_key = jax.random.split(rng, 3)
    dropout_rngs = jax.random.split(dropout_key, jax.local_device_count())
    gumbel_rngs = jax.random.split(gumbel_key, jax.local_device_count())

    # ---- Schedule and optimizer ---------------------------------------------
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_samples = len(vectorized_datasets["train"])
    steps_per_epoch = num_train_samples // train_batch_size
    if steps_per_epoch == 0:
        raise ValueError(
            f"The training set has {num_train_samples} samples, fewer than one batch of {train_batch_size}."
        )
    num_train_steps = steps_per_epoch * num_epochs

    linear_decay_lr_schedule_fn = _create_learning_rate_schedule(
        training_args.learning_rate, training_args.warmup_steps, num_train_steps
    )

    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=_decay_mask_fn,
    )

    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
    diversity_loss_weight = model.config.diversity_loss_weight

    # ---- Step functions ------------------------------------------------------
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        # The negatives are consumed by the loss only, not by the model call.
        negative_indices = batch.pop("sampled_negative_indices")

        def loss_fn(params):
            # Exponentially decayed temperature, bounded below by the minimum.
            # The bound is passed positionally: the keyword name differs across JAX versions.
            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
                model_args.min_gumbel_temperature,
            )

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
            return contrastive_loss + diversity_loss_weight * diversity_loss

        loss, grad = jax.value_and_grad(loss_fn)(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # The old state is replaced by the returned one, so its buffers can be donated.
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity}
        return jax.lax.pmean(metrics, axis_name="batch")

    # No buffer donation here: the parameters passed in are the live training
    # parameters, which are reused by later evaluation and training steps.
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device.
    state = jax_utils.replicate(state)

    # ---- Training loop -------------------------------------------------------
    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        train_start = time.time()

        # Fresh key for shuffling the training set this epoch.
        rng, input_rng = jax.random.split(rng)

        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs
            )
            train_metrics.append(train_metric)

            cur_step = epoch * steps_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                train_metric = jax_utils.unreplicate(train_metric)

                # Accumulate only the time since the previous measurement. Without
                # resetting the start mark, earlier intervals are counted again.
                now = time.time()
                train_time += now - train_start
                train_start = now

                if summary_writer is not None:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

        # ---- Evaluation ------------------------------------------------------
        num_eval_samples = len(vectorized_datasets["validation"])
        if num_eval_samples < eval_batch_size:
            # Not even one full batch: there is nothing to evaluate on.
            logger.warning(
                "Skipping evaluation: the validation set has %d samples, fewer than one batch of %d.",
                num_eval_samples,
                eval_batch_size,
            )
        else:
            eval_samples_idx = jnp.arange(num_eval_samples)
            eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

            eval_metrics = []
            for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
                samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx]
                model_inputs = data_collator(samples)
                model_inputs = shard(model_inputs.data)

                eval_metrics.append(p_eval_step(state.params, model_inputs))

            # Average every metric over all evaluation batches.
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
            )

            if summary_writer is not None:
                cur_step = epoch * steps_per_epoch
                write_eval_metric(summary_writer, eval_metrics, cur_step)

        # ---- Checkpoint ------------------------------------------------------
        # Save after each epoch, from the first process only.
        if jax.process_index() == 0:
            # Take the first replica of every parameter and bring it to the host.
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
|
|
|
|
|
# Run the pre-training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|