Prime Cai committed on
Commit c3187f1 · 1 Parent(s): c148614

dynamic gpu

Files changed (1)
  1. app.py +6 -3
app.py CHANGED
@@ -13,12 +13,15 @@ pipe = None
 
 CHECKPOINT = "primecai/dsd_model"
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
 def init_pipeline():
     global pipe
     transformer = FluxTransformer2DConditionalModel.from_pretrained(
         CHECKPOINT,
         subfolder="transformer",
-        torch_dtype=torch.bfloat16,
+        torch_dtype=dtype,
         low_cpu_mem_usage=False,
         ignore_mismatched_sizes=True,
         use_auth_token=os.getenv("HF_TOKEN"),
@@ -26,7 +29,7 @@ def init_pipeline():
     pipe = FluxConditionalPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-dev",
         transformer=transformer,
-        torch_dtype=torch.bfloat16,
+        torch_dtype=dtype,
         use_auth_token=os.getenv("HF_TOKEN"),
     )
     pipe.load_lora_weights(
@@ -34,7 +37,7 @@ def init_pipeline():
         weight_name="pytorch_lora_weights.safetensors",
         use_auth_token=os.getenv("HF_TOKEN"),
     )
-    pipe.to("cuda")
+    pipe.to(device, dtype=dtype)
 
 
 def process_image_and_text(image, text, gemini_prompt, guidance, i_guidance, t_guidance):
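
Note: a minimal sketch of the device/dtype selection pattern this commit introduces, assuming only a standard PyTorch install; the pipeline call in the trailing comment is illustrative of how the two variables are reused, not the repo's exact API:

import torch

# Prefer the GPU with bfloat16 when CUDA is available; otherwise fall back
# to CPU with float32, since bfloat16 is slow or unsupported on many CPUs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# The same two objects are then passed wherever the model is loaded or moved,
# e.g. from_pretrained(..., torch_dtype=dtype) and pipe.to(device, dtype=dtype),
# so the Space runs on CPU hardware instead of failing on pipe.to("cuda").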