OpenSound commited on
Commit
1047b0c
·
verified ·
1 Parent(s): a09029c

Delete src/test.py

Browse files
Files changed (1) hide show
  1. src/test.py +0 -97
src/test.py DELETED
@@ -1,97 +0,0 @@
1
- import random
2
- import argparse
3
- import os
4
- import time
5
- import soundfile as sf
6
- import numpy as np
7
- import pandas as pd
8
- from tqdm import tqdm
9
-
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
-
14
- from diffusers import DDIMScheduler
15
- from models.conditioners import MaskDiT
16
- from modules.autoencoder_wrapper import Autoencoder
17
- from transformers import T5Tokenizer, T5EncoderModel
18
- from inference import inference
19
- from utils import scale_shift, get_lr_scheduler, compute_snr, load_yaml_with_includes
20
-
21
-
22
- parser = argparse.ArgumentParser()
23
- # config settings
24
- parser.add_argument('--config-name', type=str, default='configs/udit_ada.yml')
25
- parser.add_argument('--ckpt-path', type=str, default='../ckpts/')
26
- parser.add_argument('--ckpt-id', type=str, default='120')
27
- parser.add_argument('--save_path', type=str, default='../output/')
28
- parser.add_argument('--test-df', type=str, default='audiocaps_test.csv')
29
- # parser.add_argument('--test-split', type=str, default='test')
30
-
31
- parser.add_argument('--device', type=str, default='cuda')
32
- parser.add_argument('--guidance-scale', type=float, default=3)
33
- parser.add_argument('--guidance-rescale', type=float, default=0)
34
- parser.add_argument('--ddim-steps', type=int, default=50)
35
- parser.add_argument('--eta', type=float, default=1)
36
- parser.add_argument('--random-seed', type=int, default=None)
37
-
38
- args = parser.parse_args()
39
- params = load_yaml_with_includes(args.config_name)
40
-
41
- # args.ckpt_path = f"{args.ckpt_path}/{params['model_name']}/{args.ckpt_id}.pt"
42
- args.save_path = f"{args.save_path}/{params['model_name']}/{args.ckpt_id}_{args.ddim_steps}_{args.guidance_scale}_{args.guidance_rescale}/"
43
- args.ckpt_path = f"{args.ckpt_path}/{args.ckpt_id}.pt"
44
-
45
- if __name__ == '__main__':
46
- # Codec Model
47
- autoencoder = Autoencoder(ckpt_path=params['autoencoder']['path'],
48
- model_type=params['autoencoder']['name'],
49
- quantization_first=params['autoencoder']['q_first'])
50
- autoencoder.to(args.device)
51
- autoencoder.eval()
52
-
53
- # text encoder
54
- tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
55
- text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model'],
56
- device_map='cpu').to(args.device)
57
- text_encoder.eval()
58
-
59
- # main U-Net
60
- unet = MaskDiT(**params['model']).to(args.device)
61
- unet.eval()
62
- unet.load_state_dict(torch.load(args.ckpt_path)['model'])
63
-
64
- total_params = sum([param.nelement() for param in unet.parameters()])
65
- print("Number of parameter: %.2fM" % (total_params / 1e6))
66
-
67
- noise_scheduler = DDIMScheduler(**params['diff'])
68
- # these steps reset dtype of noise_scheduler params
69
- latents = torch.randn((1, 128, 128), device=args.device)
70
- noise = torch.randn_like(latents)
71
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=args.device)
72
- _ = noise_scheduler.add_noise(latents, noise, timesteps)
73
-
74
- df = pd.read_csv(args.test_df)
75
- # Wdf = df[df['split'] == args.test_split]
76
- df = df[df['audio_length'] != 0]
77
- # df = df.sample(10)
78
- os.makedirs(args.save_path, exist_ok=True)
79
- audio_frames = params['data']['train_frames']
80
-
81
- for i in tqdm(range(len(df))):
82
- row = df.iloc[i]
83
- text = row['caption']
84
- audio_id = row['audiocap_id']
85
-
86
- pred = inference(autoencoder, unet, None, None,
87
- tokenizer, text_encoder,
88
- params, noise_scheduler,
89
- text, None,
90
- audio_frames,
91
- args.guidance_scale, args.guidance_rescale,
92
- args.ddim_steps, args.eta, args.random_seed,
93
- args.device)
94
- pred = pred.cpu().numpy().squeeze(0).squeeze(0)
95
-
96
- sf.write(f"{args.save_path}/{audio_id}.wav",
97
- pred, samplerate=params['data']['sr'])