import argparse import functools import os import random from tqdm import tqdm import sys sys.path.append('../') import yaml import time import numpy as np import torch from data.data import get_audiotext_dataloader @torch.no_grad() def validation_losses(model, data_config, clap_config, tokenizer, batch_size, autocast, cast_dtype, device_id, verbose=True): model.eval() @torch.no_grad() def get_val_loss(validloader): loss_sum = 0.0 for idx, batch in tqdm(enumerate(validloader)): audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True) audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True) attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) labels = input_ids.clone() labels[labels == tokenizer.pad_token_id] = -100 labels[:, :1] = -100 labels[labels == tokenizer.encode("