import time

import streamlit as st
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
from stqdm import stqdm

MODEL_REPO = 'OFA-Sys/small-stable-diffusion-v0'
LoRa_DIR = 'weights'
DATASET_REPO = 'VESSL/Bored_Ape_NFT_text'
SAMPLE_IMAGE = 'weights/Sample.png'
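# MODEL_REPO is the base checkpoint the LoRA weights were trained against.
# LoRa_DIR is assumed to hold attention-processor weights saved with
# unet.save_attn_procs(...) (e.g. pytorch_lora_weights.bin), alongside the
# bundled sample image shown on startup.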
# Cache the pipeline so Streamlit reruns (triggered by every widget change)
# don't reload the model from scratch each time.
@st.cache_resource
def load_pipeline_w_lora():
    # Load the pretrained UNet from the Hugging Face Hub
    unet = UNet2DConditionModel.from_pretrained(
        MODEL_REPO,
        subfolder="unet",
        revision=None,
    )
    # Load the LoRA weights into the UNet's attention layers
    unet.load_attn_procs(LoRa_DIR)
    # Build the full pipeline around the LoRA-patched UNet
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_REPO,
        unet=unet,
        revision=None,
        torch_dtype=torch.float32,
    )
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
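# Note: load_attn_procs() is the older diffusers LoRA entry point; newer
# releases also expose pipeline-level loading via
# pipeline.load_lora_weights(LoRa_DIR), which patches the same attention
# processors.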
def elapsed_time(fn, *args):
    # Time a callable and return (formatted seconds, result)
    start = time.time()
    output = fn(*args)
    end = time.time()
    elapsed = f'{end - start:.2f}'
    return elapsed, output
def main():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    st.title("BAYC Text-to-Image Generator")
    st.write(f"A Stable Diffusion model fine-tuned with LoRA on the {DATASET_REPO} dataset.")
    sample = Image.open(SAMPLE_IMAGE)
    st.image(sample, caption="An ape with solid gold fur and beanie")

    elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
    # Move the pipeline to the GPU when one is available; the generator below
    # must live on the same device as the pipeline.
    pipeline = pipeline.to(device)
    st.write(f"Model loaded in {elapsed} seconds!")

    prompt = st.text_input(
        label="Write a prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
    num_images = st.slider("Number of images to generate", 1, 10, 1)
    seed = st.slider("Seed for images", 1, 10000, 1)

    # The sliders always return at least 1, so only the prompt needs checking.
    if prompt:
        st.write(f"Generating {num_images} BAYC image(s) with prompt '{prompt}'...")
        generator = torch.Generator(device=device).manual_seed(seed)
        images = []
        for _ in stqdm(range(num_images)):
            generated_image = pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
            images.append(generated_image)
        st.write("Done!")
        # st.image expects one caption per image when given a list of images.
        st.image(images, width=150, caption=[f"Generated with '{prompt}'"] * len(images))


if __name__ == '__main__':
    main()
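# To run locally (assuming this file is saved as app.py and streamlit,
# diffusers, and stqdm are installed):  streamlit run app.py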