import numpy as np
import os
import random

from PIL import Image

from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# Model references.
# The "mega" checkpoint is too large for this setup:
# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
# Can be a wandb artifact, 🤗 Hub id, local folder, or Google bucket.
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN decoder used to turn generated image tokens back into pixels.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Loaded once at import time; these downloads can be large and slow.
model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)


def get_image(text):
    """Generate a single 256x256 PIL image from a text prompt with DALL-E mini.

    Args:
        text: Natural-language prompt describing the desired image.

    Returns:
        PIL.Image.Image: The first decoded image, as uint8 RGB.
    """
    tokenized_prompt = processor([text])

    # Sampling hyperparameters: no top-k/top-p filtering, mild temperature,
    # classifier-free-guidance scale of 3.0.
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    # BUG FIX: random.randint requires integer bounds; passing the float 1e7
    # raises ValueError on Python 3.11+ (deprecated since 3.10).
    # NOTE(review): dalle-mini's generate() normally expects a
    # jax.random.PRNGKey, not a bare int seed — confirm against the pinned
    # dalle_mini version before relying on this call signature.
    seed = random.randint(0, 10_000_000)
    encoded_images = model.generate(
        tokenized_prompt,
        seed,
        model.params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )

    # Drop the leading BOS token from the generated token sequence.
    encoded_images = encoded_images.sequences[..., 1:]

    # NOTE(review): the reference pipeline decodes via
    # vqgan.decode_code(..., params=vqgan.params); verify that model.decode
    # exists on this DalleBart version and proxies to the VQGAN decoder.
    decoded_images = model.decode(encoded_images, vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    # Return only the first image, scaled from [0, 1] floats to uint8 pixels.
    img = decoded_images[0]
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))