# draw_me_a_sheep_heb / image_generator.py
# Author: Amir Zait
# Commit: "added dalle" (be37091)
# Size: 1.5 kB
# (Header captured from the repository web viewer: raw / history / blame.)
import random

import jax
import numpy as np
from dalle_mini import DalleBart, DalleBartProcessor
from PIL import Image
from vqgan_jax.modeling_flax_vqgan import VQModel
# Model references
#
# Everything below runs at import time: importing this module loads the
# DALL-E model, the VQGAN model and the text processor, and binds them to the
# module-level names that get_image() reads.

# dalle-mega
# DALLE_MODEL can be a wandb artifact, a Hugging Face Hub repo, a local
# folder or a Google bucket.
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
# No revision is pinned for the DALL-E model (passed as revision=None below).
DALLE_COMMIT_ID = None
# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
# The VQGAN is pinned to a fixed commit.
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
# Text-to-image-token model; its weights are read later via model.params.
model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
# Image model whose weights (vqgan.params) are used when decoding.
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
# Processor loaded from the same reference as the DALL-E model; it turns a
# list of prompt strings into the inputs passed to model.generate().
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
def get_image(
    text,
    *,
    seed=None,
    top_k=None,
    top_p=None,
    temperature=0.85,
    cond_scale=3.0,
):
    """Generate one image for a text prompt and return it as a PIL image.

    Uses the module-level ``processor``, ``model`` and ``vqgan`` objects.

    Args:
        text: The prompt to draw.
        seed: Integer seed for the sampling PRNG. When ``None`` a random seed
            in ``[0, 10**7]`` is drawn, so repeated calls give different
            images; pass a fixed value for reproducible output.
        top_k: Forwarded to ``model.generate`` as ``top_k``.
        top_p: Forwarded to ``model.generate`` as ``top_p``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        cond_scale: Forwarded to ``model.generate`` as ``condition_scale``.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded 256x256x3 image.
    """
    if seed is None:
        # The upper bound must be an int: random.randint() no longer accepts
        # float bounds such as 1e7 (TypeError on Python 3.12+).
        seed = random.randint(0, 10**7)

    tokenized_prompt = processor([text])

    # generate() takes the tokenized prompt as keyword arguments and its
    # sampling options by keyword; the positional form only applies to the
    # pmap wrapper used in the dalle-mini inference notebook.
    generated = model.generate(
        **tokenized_prompt,
        prng_key=jax.random.PRNGKey(seed),
        params=model.params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )
    # Drop the first token of every sequence (presumably the BOS token --
    # confirm against the model config).
    image_tokens = generated.sequences[..., 1:]

    # Image tokens are turned back into pixels by the VQGAN, not by the
    # DALL-E model itself.
    decoded_images = vqgan.decode_code(image_tokens, params=vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    # Only the first image of the batch is returned; scale [0, 1] floats to
    # 8-bit channel values for PIL.
    first_image = decoded_images[0]
    return Image.fromarray(np.asarray(first_image * 255, dtype=np.uint8))