markury commited on
Commit
bb1ca36
1 Parent(s): a4e4096

add lora weight

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -17,8 +17,8 @@ MAX_IMAGE_SIZE = 2048
17
  # Initialize the pipeline globally
18
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
19
 
20
- @spaces.GPU(duration=190)
21
- def infer(prompt, lora_model, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
22
  global pipe
23
 
24
  # Load LoRA if specified
@@ -39,7 +39,8 @@ def infer(prompt, lora_model, seed=42, randomize_seed=False, width=1024, height=
39
  height=height,
40
  num_inference_steps=num_inference_steps,
41
  generator=generator,
42
- guidance_scale=guidance_scale
 
43
  ).images[0]
44
 
45
  # Unload LoRA weights after generation
@@ -127,6 +128,13 @@ with gr.Blocks(css=css) as demo:
127
  step=1,
128
  value=28,
129
  )
 
 
 
 
 
 
 
130
 
131
  gr.Examples(
132
  examples=examples,
@@ -139,8 +147,8 @@ with gr.Blocks(css=css) as demo:
139
  gr.on(
140
  triggers=[run_button.click, prompt.submit],
141
  fn=infer,
142
- inputs=[prompt, lora_model, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
143
  outputs=[result, seed, output_message]
144
  )
145
 
146
- demo.launch()
 
17
  # Initialize the pipeline globally
18
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
19
 
20
+ @spaces.GPU(duration=300)
21
+ def infer(prompt, lora_model, lora_weight=1.0, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
22
  global pipe
23
 
24
  # Load LoRA if specified
 
39
  height=height,
40
  num_inference_steps=num_inference_steps,
41
  generator=generator,
42
+ guidance_scale=guidance_scale,
43
+ cross_attention_kwargs={"scale": lora_weight}
44
  ).images[0]
45
 
46
  # Unload LoRA weights after generation
 
128
  step=1,
129
  value=28,
130
  )
131
+ lora_weight = gr.Slider(
132
+ label="LoRA Weight",
133
+ minimum=0,
134
+ maximum=2,
135
+ step=0.01,
136
+ value=1.0,
137
+ )
138
 
139
  gr.Examples(
140
  examples=examples,
 
147
  gr.on(
148
  triggers=[run_button.click, prompt.submit],
149
  fn=infer,
150
+ inputs=[prompt, lora_model, lora_weight, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
151
  outputs=[result, seed, output_message]
152
  )
153
 
154
+ demo.launch()