treasuraid's picture
Update app.py
d1ca92a
raw
history blame
2.3 kB
import diffusers
import torch
import os
import time
import streamlit as st
from stqdm import stqdm
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
# Hub repo id of the base Stable Diffusion checkpoint (UNet + rest of pipeline).
MODEL_REPO: str = "OFA-Sys/small-stable-diffusion-v0"
# Directory holding the LoRA attention-processor weights.
LoRa_DIR: str = "weights"
# Dataset the LoRA was fine-tuned on; only displayed in the page text.
DATASET_REPO: str = "VESSL/Bored_Ape_NFT_text"
# Example output image shown at the top of the page.
SAMPLE_IMAGE: str = "weights/Sample.png"
def load_pipeline_w_lora():
    """Assemble the diffusion pipeline with the LoRA attention weights applied.

    The UNet is loaded from the ``unet`` subfolder of ``MODEL_REPO``. The LoRA
    attention-processor weights stored in ``LoRa_DIR`` are loaded into it. The
    patched UNet is then passed to ``DiffusionPipeline.from_pretrained`` for
    the same repo, in float32. The pipeline's own progress bar is disabled.

    This function does not move the pipeline to any particular device.

    Returns:
        The assembled ``DiffusionPipeline``.
    """
    lora_unet = UNet2DConditionModel.from_pretrained(
        MODEL_REPO, subfolder="unet", revision=None
    )
    # Replace the UNet's attention-layer weights with the LoRA-trained ones.
    lora_unet.load_attn_procs(LoRa_DIR)

    pipe = DiffusionPipeline.from_pretrained(
        MODEL_REPO,
        unet=lora_unet,
        revision=None,
        torch_dtype=torch.float32,
    )
    pipe.set_progress_bar_config(disable=True)
    return pipe
def elapsed_time(fn, *args, **kwargs):
    """Call ``fn`` and measure how long the call takes.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        An ``(elapsed, output)`` tuple. ``elapsed`` is the duration in
        seconds, formatted as a string with two decimals. ``output`` is
        whatever ``fn`` returned.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock and can jump if it is adjusted mid-measurement.
    start = time.perf_counter()
    output = fn(*args, **kwargs)
    elapsed = f"{time.perf_counter() - start:.2f}"
    return elapsed, output
def main():
    """Render the BAYC text-to-image Streamlit page and generate images on demand.

    Shows a title, description and sample image, then loads the LoRA-patched
    pipeline and places it on the GPU when one is available. Once a prompt
    is entered, it generates ``num_images`` images from a generator seeded
    with the "Seed" slider value and displays them.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    st.title("BAYC Text to IMAGE generator")
    st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")

    sample = Image.open(SAMPLE_IMAGE)
    st.image(sample, caption="An ape with solid gold fur and beanie")

    elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
    # load_pipeline_w_lora never moves the pipeline off its default device.
    # Put it on `device` so it matches the generator created below. Otherwise,
    # on a CUDA machine, latents would be sampled on the CPU from a CUDA
    # generator, which errors out.
    pipeline = pipeline.to(device)
    st.write(f"Model is loaded in {elapsed} seconds!")

    prompt = st.text_input(
        label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
    num_images = st.slider("Number of images to generate", 1, 10, 1)
    seed = st.slider("Seed for images", 1, 10000, 1)

    if prompt and num_images and seed:
        st.write(f"Generating {num_images} BAYC image with prompt {prompt}...")
        # One generator is shared by the whole loop, so each image differs, but
        # re-seeding on every rerun makes the same seed reproduce the same sequence.
        generator = torch.Generator(device=device).manual_seed(seed)
        images = []
        for _ in stqdm(range(num_images)):
            result = pipeline(prompt, num_inference_steps=30, generator=generator)
            images.append(result.images[0])
        st.write("Done!")
        st.image(images, width=150, caption=f"Generated Images with {prompt}")
if __name__ == "__main__":
    # Launch the app when this file is executed as a script (e.g. `streamlit run`).
    main()