multimodalart HF staff commited on
Commit
4bdeba0
·
verified ·
1 Parent(s): a992e53

Activate live previews

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -3,32 +3,39 @@ import numpy as np
3
  import random
4
  import spaces
5
  import torch
6
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
7
  from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
 
8
 
9
  dtype = torch.bfloat16
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
 
 
13
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 2048
16
 
17
- @spaces.GPU(duration=190)
18
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
 
 
19
  if randomize_seed:
20
  seed = random.randint(0, MAX_SEED)
21
  generator = torch.Generator().manual_seed(seed)
22
- image = pipe(
23
- prompt = prompt,
24
- width = width,
25
- height = height,
26
- num_inference_steps = num_inference_steps,
27
- generator = generator,
28
- guidance_scale=guidance_scale
29
- ).images[0]
30
- return image, seed
31
-
 
 
32
  examples = [
33
  "a tiny astronaut hatching from an egg on the moon",
34
  "a cat holding a sign that says hello world",
 
3
  import random
4
  import spaces
5
  import torch
6
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny
7
  from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
 
10
  dtype = torch.bfloat16
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
14
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
15
+ torch.cuda.empty_cache()
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  MAX_IMAGE_SIZE = 2048
19
 
20
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
21
+
22
+ @spaces.GPU(duration=75)
23
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
24
  if randomize_seed:
25
  seed = random.randint(0, MAX_SEED)
26
  generator = torch.Generator().manual_seed(seed)
27
+
28
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
29
+ prompt=prompt,
30
+ guidance_scale=guidance_scale,
31
+ num_inference_steps=num_inference_steps,
32
+ width=width,
33
+ height=height,
34
+ generator=generator,
35
+ output_type="pil",
36
+ ):
37
+ yield img, seed
38
+
39
  examples = [
40
  "a tiny astronaut hatching from an egg on the moon",
41
  "a cat holding a sign that says hello world",