m4r4k0s23 committed · Commit ff17886 · verified · 1 Parent(s): 287e1e1

Update app.py

Files changed (1): app.py (+93 -50)

app.py CHANGED
@@ -2,48 +2,94 @@ import gradio as gr
 import numpy as np
 import random
 
-# import spaces #[uncomment to use ZeroGPU]
 from diffusers import DiffusionPipeline
+from peft import PeftModel, PeftConfig
 import torch
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-# default parameter
-# model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-model_repo_id = "CompVis/stable-diffusion-v1-4"
-model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4']
+
+# Model list including your LoRA model
+MODEL_LIST = [
+    "CompVis/stable-diffusion-v1-4",
+    "stabilityai/sdxl-turbo",
+    "runwayml/stable-diffusion-v1-5",
+    "stabilityai/stable-diffusion-2-1",
+    "m4r4k0s23/hw5_lora_raccoon",
+]
 
 if torch.cuda.is_available():
     torch_dtype = torch.float16
 else:
     torch_dtype = torch.float32
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
+# Cache to avoid re-initializing pipelines repeatedly
+model_cache = {}
+
+def load_pipeline(model_id: str):
+    """
+    Loads or retrieves a cached DiffusionPipeline.
+
+    If the chosen model is your LoRA adapter, load the base model
+    (CompVis/stable-diffusion-v1-4) and apply the LoRA weights.
+    """
+    if model_id in model_cache:
+        return model_cache[model_id]
+
+    if model_id == "m4r4k0s23/hw5_lora_raccoon":
+        # Use the specified base model for your LoRA adapter.
+        base_model = "CompVis/stable-diffusion-v1-4"
+        pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
+        # Load the LoRA weights
+        pipe.unet = PeftModel.from_pretrained(
+            pipe.unet,
+            model_id,
+            subfolder="unet",
+            torch_dtype=torch_dtype,
+        )
+        pipe.text_encoder = PeftModel.from_pretrained(
+            pipe.text_encoder,
+            model_id,
+            subfolder="text_encoder",
+            torch_dtype=torch_dtype,
+        )
+    else:
+        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
+
+    pipe.to(device)
+    model_cache[model_id] = pipe
+    return pipe
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
+    model_id,
     prompt,
     negative_prompt,
+    seed,
     randomize_seed,
     width,
    height,
-    model_repo_id=model_repo_id,
-    seed=42,
-    guidance_scale=7,
-    num_inference_steps=20,
+    guidance_scale,
+    num_inference_steps,
+    lora_scale,  # New parameter for adjusting LoRA scale
     progress=gr.Progress(track_tqdm=True),
 ):
+    # Load the pipeline for the chosen model
+    pipe = load_pipeline(model_id)
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    generator = torch.Generator().manual_seed(seed)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
-    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-    pipe = pipe.to(device)
+    # If using the LoRA model, update the LoRA scale if supported.
+    if model_id == "m4r4k0s23/hw5_lora_raccoon":
+        # This assumes your pipeline's unet has a method to update the LoRA scale.
+        if hasattr(pipe.unet, "set_lora_scale"):
+            pipe.unet.set_lora_scale(lora_scale)
+        else:
+            print("Warning: LoRA scale adjustment method not found on UNet.")
 
     image = pipe(
         prompt=prompt,
@@ -57,12 +103,7 @@ def infer(
 
     return image, seed
 
-
 examples = [
-    "A young lady in a Russian embroidered kaftan is sitting on a beautiful carved veranda, holding a cup to her mouth and drinking tea from the cup. With her other hand, the girl holds a saucer. The cup and saucer are painted with gzhel. Next to the girl on the table stands a samovar, and steam can be seen above it.",
-    "Puss in Boots wearing a sombrero crosses the Grand Canyon on a tightrope with a guitar.",
-    "A cat is playing a song called ""About the Cat"" on an accordion by the sea at sunset. The sun is quickly setting behind the horizon, and the light is fading.",
-    "A cat walks through the grass on the streets of an abandoned city. The camera view is always focused on the cat's face.",
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
     "A delicious ceviche cheesecake slice",
@@ -77,7 +118,15 @@ css = """
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image SemaSci Template")
+        gr.Markdown(" # Text-to-Image Gradio Template")
+
+        with gr.Row():
+            # Dropdown to select the model from Hugging Face
+            model_id = gr.Dropdown(
+                label="Model",
+                choices=MODEL_LIST,
+                value=MODEL_LIST[0],  # Default model
+            )
 
         with gr.Row():
             prompt = gr.Text(
@@ -93,27 +142,10 @@ with gr.Blocks(css=css) as demo:
         result = gr.Image(label="Result", show_label=False)
 
         with gr.Accordion("Advanced Settings", open=False):
-            # model_repo_id = gr.Text(
-            #     label="Model Id",
-            #     max_lines=1,
-            #     placeholder="Choose model",
-            #     visible=True,
-            #     value=model_repo_id,
-            # )
-            model_repo_id = gr.Dropdown(
-                label="Model Id",
-                choices=model_dropdown,
-                info="Choose model",
-                visible=True,
-                allow_custom_value=True,
-                value=model_repo_id,
-            )
-
             negative_prompt = gr.Text(
                 label="Negative prompt",
                 max_lines=1,
                 placeholder="Enter a negative prompt",
-                visible=True,
             )
 
             seed = gr.Slider(
@@ -121,10 +153,10 @@ with gr.Blocks(css=css) as demo:
                 minimum=0,
                 maximum=MAX_SEED,
                 step=1,
-                value=42,
+                value=42,  # Default seed
             )
 
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
             with gr.Row():
                 width = gr.Slider(
@@ -132,7 +164,7 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=1024,
                 )
 
                 height = gr.Slider(
@@ -140,40 +172,51 @@ with gr.Blocks(css=css) as demo:
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=1024,
                 )
 
             with gr.Row():
                 guidance_scale = gr.Slider(
                     label="Guidance scale",
                     minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=7.0,  # Replace with defaults that work for your model
+                    maximum=20.0,
+                    step=0.5,
+                    value=7.0,
                 )
 
                 num_inference_steps = gr.Slider(
                     label="Number of inference steps",
                     minimum=1,
-                    maximum=50,
+                    maximum=100,
                     step=1,
-                    value=20,  # Replace with defaults that work for your model
+                    value=20,
                 )
 
+            # New slider for LoRA scale.
+            lora_scale = gr.Slider(
+                label="LoRA Scale",
+                minimum=0.0,
+                maximum=2.0,
+                step=0.1,
+                value=1.0,
+                info="Adjust the influence of the LoRA weights",
+            )
+
         gr.Examples(examples=examples, inputs=[prompt])
     gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
+            model_id,
            prompt,
            negative_prompt,
+            seed,
            randomize_seed,
            width,
            height,
-            model_repo_id,
-            seed,
            guidance_scale,
            num_inference_steps,
+            lora_scale,  # Pass the new slider value
        ],
        outputs=[result, seed],
    )
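A note for reviewers: peft's PeftModel may not expose a set_lora_scale method, which is presumably why the commit guards the call with hasattr and falls back to a warning. diffusers also ships a native LoRA loader that applies scaling at generation time instead. A minimal sketch of that alternative, assuming the adapter repo also ships diffusers-format LoRA weights (e.g. pytorch_lora_weights.safetensors); the PEFT unet/text_encoder subfolder layout used in this commit would not load through this path:

    # Sketch, not the committed implementation: diffusers' built-in LoRA API.
    import torch
    from diffusers import DiffusionPipeline

    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype
    )
    # Assumption: the repo provides diffusers-format LoRA weights.
    pipe.load_lora_weights("m4r4k0s23/hw5_lora_raccoon")

    image = pipe(
        "a raccoon in a spacesuit",              # illustrative prompt
        num_inference_steps=20,
        guidance_scale=7.0,
        cross_attention_kwargs={"scale": 0.8},   # plays the role of lora_scale
    ).images[0]

Scaling via cross_attention_kwargs avoids mutating the UNet between requests, which also plays nicer with the pipeline cache introduced in this commit.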
 
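The switch from torch.Generator() to torch.Generator(device=device) in infer matters on GPU: the generator must live on the same device as the tensors it seeds for a fixed seed to reproduce the same initial latents. A self-contained sketch (latent shape is illustrative, not tied to any particular model):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def sample_latents(seed: int) -> torch.Tensor:
        # A device-bound generator seeds the RNG that draws the initial latents.
        generator = torch.Generator(device=device).manual_seed(seed)
        return torch.randn(1, 4, 64, 64, generator=generator, device=device)

    # Same seed, same latents -> reproducible images for a fixed prompt/config.
    assert torch.equal(sample_latents(42), sample_latents(42))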
 
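Finally, the model_cache dictionary means each checkpoint is built at most once per process, so switching back to a model in the dropdown reuses the already-loaded pipeline instead of re-downloading it. A rough usage sketch against the updated module, using the function names from this diff (prompt and sizes are illustrative; infer's progress argument keeps its Gradio default):

    # Illustrative only: exercising the updated app.py functions directly.
    pipe_a = load_pipeline("CompVis/stable-diffusion-v1-4")  # cold: builds the pipeline
    pipe_b = load_pipeline("CompVis/stable-diffusion-v1-4")  # warm: served from model_cache
    assert pipe_a is pipe_b

    image, seed = infer(
        model_id="m4r4k0s23/hw5_lora_raccoon",
        prompt="a raccoon reading a newspaper",  # illustrative prompt
        negative_prompt="",
        seed=42,
        randomize_seed=False,
        width=512,
        height=512,
        guidance_scale=7.0,
        num_inference_steps=20,
        lora_scale=1.0,
    )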