# Hugging Face Spaces status when this file was captured: Runtime error
import random

import jax
import numpy as np
from dalle_mini import DalleBart, DalleBartProcessor
from PIL import Image
from vqgan_jax.modeling_flax_vqgan import VQModel
# Model references

# dalle-mega
# DALLE_MODEL can be a wandb artifact, a Hugging Face Hub repo, a local folder
# or a Google bucket.
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
# Revision passed to ``from_pretrained``; None means no specific revision is requested.
DALLE_COMMIT_ID = None

# If the notebook crashes too often you can use dalle-mini instead by
# uncommenting the line below.
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"

# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
# Load the pretrained objects once, at import time; get_image() reads these
# module-level names on every call instead of reloading them.
# The text-to-image-token model and its prompt processor share the same
# reference and revision.
model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
def get_image(text):
    """Generate a single image for a text prompt.

    Uses the module-level ``processor``, ``model`` and ``vqgan`` objects.

    Args:
        text: The prompt to render.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample, after the
        decoded values are clipped to [0, 1], reshaped to 256x256x3 and
        scaled to 8-bit.
    """
    tokenized_prompt = processor([text])

    # Sampling settings: top-k / top-p are left unset.
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    # A fresh seed per call so repeated prompts give different images.
    # randint() requires integer bounds: the former ``1e7`` is a float, which
    # is deprecated since Python 3.10 and raises TypeError on 3.12+.
    seed = random.randint(0, 10**7)
    # generate() expects a JAX PRNG key, not a bare Python int.
    prng_key = jax.random.PRNGKey(seed)

    # Arguments are passed by keyword: generate()'s positional order is
    # (input_ids, attention_mask, max_length, ...), so passing the seed and
    # params positionally would bind them to the wrong parameters.
    encoded_images = model.generate(
        **tokenized_prompt,
        prng_key=prng_key,
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )
    # Drop the first position of every sequence
    # (presumably the BOS / decoder start token -- confirm).
    encoded_images = encoded_images.sequences[..., 1:]

    # Image tokens are turned into pixels by the VQGAN, not by the BART model.
    decoded_images = vqgan.decode_code(encoded_images, params=vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    # Only the first sample of the batch is returned.
    img = decoded_images[0]
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))