# Code adapted from the following sources:
# https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life
# https://huggingface.co/spaces/huggan/FastGan/
import torch
from PIL import Image

from models import Generator


def load_img_generator(model_name_or_path):
    """Load a pretrained FastGAN ``Generator`` and switch it to eval mode.

    Args:
        model_name_or_path: Hub repo id or local path, forwarded to
            ``Generator.from_pretrained``.

    Returns:
        The loaded generator, in evaluation mode.
    """
    generator = Generator(in_channels=256, out_channels=3)
    # NOTE(review): the result of ``from_pretrained`` replaces ``generator``,
    # so the instance built above looks like a throwaway — confirm against
    # ``models.Generator`` before removing it.
    generator = generator.from_pretrained(
        model_name_or_path, in_channels=256, out_channels=3
    )
    generator.eval()
    return generator


def _denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Linearly map values via ``x * 127.5 + 127.5`` ([-1, 1] -> [0, 255])."""
    return (tensor * 127.5) + 127.5


def generate_img(device, gan_model):
    """Sample a single image from a pretrained few-shot FastGAN model.

    Args:
        device: Torch device (or device string) to run generation on.
        gan_model: Suffix of the model repo; weights are loaded from
            ``"huggan/fastgan-few-shot-" + gan_model``.

    Returns:
        A ``PIL.Image.Image`` built from the generated sample.
    """
    img_generator = load_img_generator("huggan/fastgan-few-shot-" + gan_model)
    # The noise lives on ``device``; the model must too, otherwise any
    # non-CPU device fails the forward pass with a device-mismatch error.
    img_generator = img_generator.to(device)

    # One latent vector; 256 channels matches ``in_channels=256`` above.
    noise = torch.randn(1, 256, 1, 1, device=device)
    with torch.no_grad():
        # The generator returns a pair; only the first element is used.
        gan_images, _ = img_generator(noise)

    # Drop only the batch dimension (a bare squeeze() would also collapse
    # any other size-1 axis), then clamp before the uint8 cast so values
    # outside [0, 255] saturate instead of wrapping around.
    gan_image = _denormalize(gan_images.squeeze(0)).clamp(0, 255)
    # CHW -> HWC, which is the layout PIL expects for an RGB array.
    gan_image = gan_image.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(gan_image)