treasuraid commited on
Commit
e83adaa
·
1 Parent(s): efbaec2

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +83 -0
  2. weights/Sample.png +0 -0
  3. weights/pytorch_lora_weights.bin +3 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import diffusers
2
+ import torch
3
+ import os
4
+ import time
5
+
6
+ import streamlit as st
7
+
8
+ from diffusers import DiffusionPipeline, UNet2DConditionModel
9
+ from PIL import Image
10
+
11
+
12
+ MODEL_REPO = 'OFA-Sys/small-stable-diffusion-v0'
13
+ LoRa_DIR = '/weights'
14
+ DATASET_REPO = 'VESSL/Bored_Ape_NFT_text'
15
+ SAMPLE_IMAGE = '/weights/Sample.png'
16
+ def load_pipeline_w_lora() :
17
+
18
+ # Load pretrained unet from huggingface
19
+ unet = UNet2DConditionModel.from_pretrained(
20
+ MODEL_REPO,
21
+ subfolder="unet",
22
+ revision=None
23
+ )
24
+
25
+ # Load LoRa attn layer weights to unet attn layers
26
+ unet.load_attn_procs(LoRa_DIR)
27
+
28
+ # Load pipeline
29
+ pipeline = DiffusionPipeline.from_pretrained(
30
+ MODEL_REPO,
31
+ unet=unet,
32
+ revision=None,
33
+ torch_dtype=torch.float32,
34
+ )
35
+
36
+ return pipeline
37
+
38
+
39
+ def elapsed_time(fn, *args):
40
+ start = time.time()
41
+ output = fn(*args)
42
+ end = time.time()
43
+
44
+ elapsed = f'{end - start:.2f}'
45
+
46
+ return elapsed, output
47
+
48
+
49
+ def main():
50
+
51
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
52
+
53
+ st.title("BAYC Text to IMAGE generator")
54
+ st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")
55
+
56
+ sample = Image.open(SAMPLE_IMAGE)
57
+ st.image(sample, caption="An ape with solid gold fur and beanie")
58
+
59
+ elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
60
+ st.write(f"Model is loaded in {elapsed} seconds!")
61
+
62
+ prompt = st.text_input(
63
+ label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
64
+
65
+ num_images = st.slider("Number of images to generate", 1, 10, 1)
66
+
67
+ seed = st.slider("Seed for images", 1, 10000, 1)
68
+
69
+ if prompt and num_images and seed:
70
+ st.write(f"Generating {num_images}BAYC image with prompt {prompt}...")
71
+
72
+ generator = torch.Generator(device=device).manual_seed(seed)
73
+ images = []
74
+ for img_idx in range(num_images):
75
+ generated_image = pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
76
+ images.append(generated_image)
77
+
78
+ st.write("Done!")
79
+
80
+ st.image(images, width=150, caption=f"Generated Images with {prompt}")
81
+
82
+ if __name__ == '__main__':
83
+ main()
weights/Sample.png ADDED
weights/pytorch_lora_weights.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4927a8337a888edff7d844cac082c141726ac60fe8ad065c9143ec0987a9297
3
+ size 1080571