mfrashad's picture
Init code
8f87579
raw
history blame contribute delete
No virus
1.66 kB
import argparse
import inspect
import os

import torch
from torchvision import utils
from tqdm import tqdm

from model import Generator
def generate(args, g_ema, device, mean_latent, out_dir='sample'):
    """Sample random latents from ``g_ema`` and save each batch as a PNG.

    Args:
        args: parsed CLI namespace. Reads ``pics`` (number of PNG files to
            write), ``sample`` (number of latents drawn per file),
            ``latent`` (latent vector length) and ``truncation``.
        g_ema: generator module. It is called as
            ``g_ema([z], truncation=..., truncation_latent=...)`` and must
            return a 2-tuple whose first item is the image batch.
        device: device on which the random latents are created.
        mean_latent: latent passed as ``truncation_latent``, or ``None``.
        out_dir: directory the images are written to. It is created if it
            does not exist. Files are named ``000000.png``, ``000001.png``, ...
    """
    # save_image fails with FileNotFoundError if the target directory is missing.
    os.makedirs(out_dir, exist_ok=True)

    # torchvision 0.10 renamed make_grid/save_image's ``range`` keyword to
    # ``value_range``, and later releases dropped the old name entirely.
    # Use whichever keyword this install accepts so both old and new work.
    grid_params = inspect.signature(utils.make_grid).parameters
    range_kw = 'value_range' if 'value_range' in grid_params else 'range'

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            sample_z = torch.randn(args.sample, args.latent, device=device)
            sample, _ = g_ema(
                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
            )
            utils.save_image(
                sample,
                os.path.join(out_dir, f'{i:06d}.png'),
                nrow=1,
                normalize=True,
                # Generator output is treated as lying in [-1, 1].
                **{range_kw: (-1, 1)},
            )
if __name__ == '__main__':
    # Fall back to the CPU when CUDA is unavailable instead of crashing on .to('cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--sample', type=int, default=1)  # latents per saved file
    parser.add_argument('--pics', type=int, default=20)  # number of files to write
    parser.add_argument('--truncation', type=float, default=1)
    # Passed to g_ema.mean_latent(); presumably the number of latents averaged.
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt")
    parser.add_argument('--channel_multiplier', type=int, default=2)
    args = parser.parse_args()

    # Architecture constants are fixed here rather than exposed as flags.
    # NOTE(review): these presumably have to match the checkpoint, since
    # load_state_dict below is strict by default.
    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)

    # map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(args.ckpt, map_location=device)
    g_ema.load_state_dict(checkpoint['g_ema'])

    # The mean latent is only needed when truncation is actually applied.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)