Surn committed on
Commit 8e7b045 · 1 Parent(s): eb4b77d

Fall Back from update

Files changed (1)
  1. app.py +188 -214
app.py CHANGED
@@ -8,7 +8,6 @@ from typing import Optional, Union, List, Tuple
 
 from PIL import Image, ImageFilter
 import cv2
-
 import utils.constants as constants
 
 from haishoku.haishoku import Haishoku
@@ -92,7 +91,6 @@ from utils.version_info import (
     #release_torch_resources,
     #get_torch_info
 )
-from src.condition import Condition
 import spaces
 
 input_image_palette = []
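Note on this hunk: the module-level `from src.condition import Condition` is dropped here and re-added inside `generate_image_lowmem` in a later hunk. A plausible reading is the deferred-import pattern common on ZeroGPU Spaces, where CUDA must not be touched while the main process imports the module. A minimal sketch, under the assumption that `src.condition` transitively imports torch:

```python
# Deferred-import sketch (assumption: src.condition pulls in torch, so a
# module-scope import could initialize CUDA in the main process).
def run_in_gpu_worker():
    from src.condition import Condition  # resolved only when this executes
    return Condition
```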
@@ -201,24 +199,11 @@ condition_dict = {
     "fill": 9,
 }
 
-@spaces.GPU(duration=120, progress=gr.Progress(track_tqdm=True))
-def generate_image(pipe, conditions, generate_params, progress=gr.Progress(track_tqdm=True)):
-    gr.Info("Generating AI image...",duration=5)
-    result = pipe(**generate_params)
-    image = result.images[0]
-    # Clean up
-    del result
-    del conditions
-    # Delete the pipeline and clear cache
-    del pipe
-    torch.cuda.empty_cache()
-    torch.cuda.ipc_collect()
-    print(torch.cuda.memory_summary(device=None, abbreviated=False))
-    return image
-
+# @spaces.GPU(duration=140, progress=gr.Progress(track_tqdm=True))
+# def generate_image(pipe, generate_params, progress=gr.Progress(track_tqdm=True)):
+#     return pipe(**generate_params)
 
-@spaces.GPU(duration=90)
-@torch.no_grad()
+@spaces.GPU(duration=200, progress=gr.Progress(track_tqdm=True))
 def generate_image_lowmem(
     text,
     neg_prompt=None,
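For context: this hunk retires the separate `@spaces.GPU`-decorated `generate_image` and moves the entire load-plus-inference path under a single `@spaces.GPU(duration=200, ...)` entry point. A minimal sketch of that single-entry-point shape, not the app's actual code; `load_pipeline` is a hypothetical stand-in for `pipeline_class.from_pretrained(...).to("cuda")`:

```python
import spaces
import torch

@spaces.GPU(duration=200)  # one GPU window spans model load and inference
def generate(prompt: str):
    pipe = load_pipeline()  # hypothetical loader, not from app.py
    with torch.no_grad():
        image = pipe(prompt=prompt).images[0]
    del pipe                  # release VRAM before the GPU window closes
    torch.cuda.empty_cache()
    return image
```

Nothing CUDA-backed has to enter or leave the decorated function, which is the point of the consolidation.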
@@ -244,205 +229,195 @@ def generate_image_lowmem(
                          f"Available options: {list(PIPELINE_CLASSES.keys())}")
 
     #initialize_cuda()
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    from src.condition import Condition
 
     print(f"device:{device}\nmodel_name:{model_name}\nlora_weights:{lora_weights}\n")
     #print(f"\n {get_torch_info()}\n")
     # Disable gradient calculations
-    #with torch.no_grad():
-    gr.Info("Initialize the pipeline inside the context manager",duration=5)
-    # Initialize the pipeline inside the context manager
-    pipe = pipeline_class.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
-    ).to(device)
-    # Optionally, don't use CPU offload if not necessary
-
-    # alternative version that may be more efficient
-    # pipe.enable_sequential_cpu_offload()
-    if pipeline_name == "FluxPipeline":
-        pipe.enable_model_cpu_offload()
-        pipe.vae.enable_slicing()
-        #pipe.vae.enable_tiling()
-    else:
-        pipe.enable_model_cpu_offload()
-
-    # Access the tokenizer from the pipeline
-    tokenizer = pipe.tokenizer
-
-    # Check if add_prefix_space is set and convert to slow tokenizer if necessary
-    if getattr(tokenizer, 'add_prefix_space', False):
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, device_map = 'cpu')
-        # Update the pipeline's tokenizer
-        pipe.tokenizer = tokenizer
-    pipe.to(device)
-
-    flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
-    if flash_attention_enabled == False:
-        #Enable xFormers memory-efficient attention (optional)
-        #pipe.enable_xformers_memory_efficient_attention()
-        print("\nEnabled xFormers memory-efficient attention.\n")
-    else:
-        pipe.attn_implementation="flash_attention_2"
-        print("\nEnabled flash_attention_2.\n")
-
-    condition_type = "subject"
-    # Load LoRA weights
-    # note: does not yet handle multiple LoRA weights with different names, needs .set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
-    if lora_weights:
-        for lora_weight in lora_weights:
-            lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
-            lora_weight_set = False
-            if lora_configs:
-                for config in lora_configs:
-                    # Load LoRA weights with optional weight_name and adapter_name
-                    if 'weight_name' in config:
-                        weight_name = config.get("weight_name")
-                        adapter_name = config.get("adapter_name")
-                        lora_collection = config.get("lora_collection")
-                        if weight_name and adapter_name and lora_collection and lora_weight_set == False:
-                            pipe.load_lora_weights(
-                                lora_collection,
-                                weight_name=weight_name,
-                                adapter_name=adapter_name,
-                                token=constants.HF_API_TOKEN
-                            )
-                            lora_weight_set = True
-                            print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}, lora_collection={lora_collection}\n")
-                        elif weight_name and adapter_name==None and lora_collection and lora_weight_set == False:
-                            pipe.load_lora_weights(
-                                lora_collection,
-                                weight_name=weight_name,
-                                token=constants.HF_API_TOKEN
-                            )
-                            lora_weight_set = True
-                            print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}, lora_collection={lora_collection}\n")
-                        elif weight_name and adapter_name and lora_weight_set == False:
-                            pipe.load_lora_weights(
-                                lora_weight,
-                                weight_name=weight_name,
-                                adapter_name=adapter_name,
-                                token=constants.HF_API_TOKEN
-                            )
-                            lora_weight_set = True
-                            print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
-                        elif weight_name and adapter_name==None and lora_weight_set == False:
-                            pipe.load_lora_weights(
-                                lora_weight,
-                                weight_name=weight_name,
-                                token=constants.HF_API_TOKEN
-                            )
-                            lora_weight_set = True
-                            print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
-                        elif lora_weight_set == False:
-                            pipe.load_lora_weights(
-                                lora_weight,
-                                token=constants.HF_API_TOKEN
-                            )
-                            lora_weight_set = True
-                            print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
-                    # Apply 'pipe' configurations if present
-                    if 'pipe' in config:
-                        pipe_config = config['pipe']
-                        for method_name, params in pipe_config.items():
-                            method = getattr(pipe, method_name, None)
-                            if method:
-                                print(f"Applying pipe method: {method_name} with params: {params}")
-                                method(**params)
-                            else:
-                                print(f"Method {method_name} not found in pipe.")
-                    if 'condition_type' in config:
-                        condition_type = config['condition_type']
-                        if condition_type == "coloring":
-                            #pipe.enable_coloring()
-                            print("\nEnabled coloring.\n")
-                        elif condition_type == "deblurring":
-                            #pipe.enable_deblurring()
-                            print("\nEnabled deblurring.\n")
-                        elif condition_type == "fill":
-                            #pipe.enable_fill()
-                            print("\nEnabled fill.\n")
-                        elif condition_type == "depth":
-                            #pipe.enable_depth()
-                            print("\nEnabled depth.\n")
-                        elif condition_type == "canny":
-                            #pipe.enable_canny()
-                            print("\nEnabled canny.\n")
-                        elif condition_type == "subject":
-                            #pipe.enable_subject()
-                            print("\nEnabled subject.\n")
-                        else:
-                            print(f"Condition type {condition_type} not implemented.")
-            else:
-                pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
-    gr.Info("lora_weights are loaded",duration=5)
-    # Set the random seed for reproducibility
-    generator = torch.Generator(device=device).manual_seed(seed)
-    conditions = []
-    if conditioned_image is not None:
-        conditioned_image = crop_and_resize_image(conditioned_image, image_width, image_height)
-        condition = Condition(condition_type, conditioned_image)
-        conditions.append(condition)
-        print(f"\nAdded conditioned image.\n {conditioned_image.size}")
-        # Prepare the parameters for image generation
-        additional_parameters ={
-            "strength": strength,
-            "image": conditioned_image,
-        }
-    else:
-        print("\nNo conditioned image provided.")
-        if neg_prompt!=None:
-            true_cfg_scale=1.1
-            additional_parameters ={
-                "negative_prompt": neg_prompt,
-                "true_cfg_scale": true_cfg_scale,
-            }
-    # handle long prompts by splitting them
-    if approximate_token_count(text) > 76:
-        prompt, prompt2 = split_prompt_precisely(text)
-        prompt_parameters = {
-            "prompt" : prompt,
-            "prompt_2": prompt2
-        }
-    else:
-        prompt_parameters = {
-            "prompt" :text
-        }
-    additional_parameters.update(prompt_parameters)
-    # Combine all parameters
-    generate_params = {
-        "height": image_height,
-        "width": image_width,
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": num_inference_steps,
-        "generator": generator,
-    }
-    if additional_parameters:
-        generate_params.update(additional_parameters)
-    generate_params = {k: v for k, v in generate_params.items() if v is not None}
-    print(f"generate_params: {generate_params}")
-    import pickle
-
-    try:
-        pickle.dumps(pipe)
-        print("pipe is picklable.\n")
-    except pickle.PicklingError:
-        print("pipe is not picklable\n.")
-
-    try:
-        pickle.dumps(conditions)
-        print("conditions is picklable.\n")
-    except pickle.PicklingError:
-        print("conditions is not picklable.\n")
-
-    try:
-        pickle.dumps(generator)
-        print("generator is picklable.\n")
-    except pickle.PicklingError:
-        print("generator is not picklable.\n")
-
-    return pipe, conditions, generate_params
-
+    with torch.no_grad():
+        # Initialize the pipeline inside the context manager
+        pipe = pipeline_class.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+        ).to(device)
+        # Optionally, don't use CPU offload if not necessary
+
+        # alternative version that may be more efficient
+        # pipe.enable_sequential_cpu_offload()
+        if pipeline_name == "FluxPipeline":
+            pipe.enable_model_cpu_offload()
+            pipe.vae.enable_slicing()
+            #pipe.vae.enable_tiling()
+        else:
+            pipe.enable_model_cpu_offload()
+
+        # Access the tokenizer from the pipeline
+        tokenizer = pipe.tokenizer
+
+        # Check if add_prefix_space is set and convert to slow tokenizer if necessary
+        if getattr(tokenizer, 'add_prefix_space', False):
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, device_map = 'cpu')
+            # Update the pipeline's tokenizer
+            pipe.tokenizer = tokenizer
+        pipe.to(device)
+
+        flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
+        if flash_attention_enabled == False:
+            #Enable xFormers memory-efficient attention (optional)
+            #pipe.enable_xformers_memory_efficient_attention()
+            print("\nEnabled xFormers memory-efficient attention.\n")
+        else:
+            pipe.attn_implementation="flash_attention_2"
+            print("\nEnabled flash_attention_2.\n")
+
+        condition_type = "subject"
+        # Load LoRA weights
+        # note: does not yet handle multiple LoRA weights with different names, needs .set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
+        if lora_weights:
+            for lora_weight in lora_weights:
+                lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
+                lora_weight_set = False
+                if lora_configs:
+                    for config in lora_configs:
+                        # Load LoRA weights with optional weight_name and adapter_name
+                        if 'weight_name' in config:
+                            weight_name = config.get("weight_name")
+                            adapter_name = config.get("adapter_name")
+                            lora_collection = config.get("lora_collection")
+                            if weight_name and adapter_name and lora_collection and lora_weight_set == False:
+                                pipe.load_lora_weights(
+                                    lora_collection,
+                                    weight_name=weight_name,
+                                    adapter_name=adapter_name,
+                                    token=constants.HF_API_TOKEN
+                                )
+                                lora_weight_set = True
+                                print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}, lora_collection={lora_collection}\n")
+                            elif weight_name and adapter_name==None and lora_collection and lora_weight_set == False:
+                                pipe.load_lora_weights(
+                                    lora_collection,
+                                    weight_name=weight_name,
+                                    token=constants.HF_API_TOKEN
+                                )
+                                lora_weight_set = True
+                                print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}, lora_collection={lora_collection}\n")
+                            elif weight_name and adapter_name and lora_weight_set == False:
+                                pipe.load_lora_weights(
+                                    lora_weight,
+                                    weight_name=weight_name,
+                                    adapter_name=adapter_name,
+                                    token=constants.HF_API_TOKEN
+                                )
+                                lora_weight_set = True
+                                print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
+                            elif weight_name and adapter_name==None and lora_weight_set == False:
+                                pipe.load_lora_weights(
+                                    lora_weight,
+                                    weight_name=weight_name,
+                                    token=constants.HF_API_TOKEN
+                                )
+                                lora_weight_set = True
+                                print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
+                            elif lora_weight_set == False:
+                                pipe.load_lora_weights(
+                                    lora_weight,
+                                    token=constants.HF_API_TOKEN
+                                )
+                                lora_weight_set = True
+                                print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")
+                        # Apply 'pipe' configurations if present
+                        if 'pipe' in config:
+                            pipe_config = config['pipe']
+                            for method_name, params in pipe_config.items():
+                                method = getattr(pipe, method_name, None)
+                                if method:
+                                    print(f"Applying pipe method: {method_name} with params: {params}")
+                                    method(**params)
+                                else:
+                                    print(f"Method {method_name} not found in pipe.")
+                        if 'condition_type' in config:
+                            condition_type = config['condition_type']
+                            if condition_type == "coloring":
+                                #pipe.enable_coloring()
+                                print("\nEnabled coloring.\n")
+                            elif condition_type == "deblurring":
+                                #pipe.enable_deblurring()
+                                print("\nEnabled deblurring.\n")
+                            elif condition_type == "fill":
+                                #pipe.enable_fill()
+                                print("\nEnabled fill.\n")
+                            elif condition_type == "depth":
+                                #pipe.enable_depth()
+                                print("\nEnabled depth.\n")
+                            elif condition_type == "canny":
+                                #pipe.enable_canny()
+                                print("\nEnabled canny.\n")
+                            elif condition_type == "subject":
+                                #pipe.enable_subject()
+                                print("\nEnabled subject.\n")
+                            else:
+                                print(f"Condition type {condition_type} not implemented.")
+                else:
+                    pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
+        # Set the random seed for reproducibility
+        generator = torch.Generator(device=device).manual_seed(seed)
+        conditions = []
+        if conditioned_image is not None:
+            conditioned_image = crop_and_resize_image(conditioned_image, image_width, image_height)
+            condition = Condition(condition_type, conditioned_image)
+            conditions.append(condition)
+            print(f"\nAdded conditioned image.\n {conditioned_image.size}")
+            # Prepare the parameters for image generation
+            additional_parameters ={
+                "strength": strength,
+                "image": conditioned_image,
+            }
+        else:
+            print("\nNo conditioned image provided.")
+            if neg_prompt!=None:
+                true_cfg_scale=1.1
+                additional_parameters ={
+                    "negative_prompt": neg_prompt,
+                    "true_cfg_scale": true_cfg_scale,
+                }
+        # handle long prompts by splitting them
+        if approximate_token_count(text) > 76:
+            prompt, prompt2 = split_prompt_precisely(text)
+            prompt_parameters = {
+                "prompt" : prompt,
+                "prompt_2": prompt2
+            }
+        else:
+            prompt_parameters = {
+                "prompt" :text
+            }
+        additional_parameters.update(prompt_parameters)
+        # Combine all parameters
+        generate_params = {
+            "height": image_height,
+            "width": image_width,
+            "guidance_scale": guidance_scale,
+            "num_inference_steps": num_inference_steps,
+            "generator": generator, }
+        if additional_parameters:
+            generate_params.update(additional_parameters)
+        generate_params = {k: v for k, v in generate_params.items() if v is not None}
+        print(f"generate_params: {generate_params}")
+        # Generate the image
+        result = pipe(**generate_params) #generate_image(pipe,generate_params)
+        image = result.images[0]
+        # Clean up
+        del result
+        del conditions
+        del generator
+        # Delete the pipeline and clear cache
+        del pipe
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        print(torch.cuda.memory_summary(device=None, abbreviated=False))
+
+    return image
 
 def generate_ai_image_local (
     map_option,
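The deleted `pickle.dumps` probes in this hunk are the likely motivation for the fallback: values handed to a `@spaces.GPU` function cross a process boundary, and a live diffusers pipeline or a CUDA `torch.Generator` generally cannot be pickled, so returning them from `generate_image_lowmem` to a second GPU call could not work. A standalone version of the removed diagnostic:

```python
import pickle

def is_picklable(obj) -> bool:
    # Mirrors the removed checks on pipe, conditions, and generator.
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, TypeError, AttributeError):
        return False

# Plain parameter dicts pass; CUDA-backed objects typically do not.
```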
@@ -501,8 +476,8 @@ def generate_ai_image_local (
     print(f"Additional Parameters: {additional_parameters}")
     print(f"Conditioned Image: {conditioned_image}")
     print(f"Conditioned Image Strength: {strength}")
-    print(f"pipeline: {pipeline_name}\n")
-    pipe, conditions, generate_params = generate_image_lowmem(
+    print(f"pipeline: {pipeline_name}")
+    image = generate_image_lowmem(
         text=prompt,
         model_name=model,
         neg_prompt=negative_prompt,
@@ -517,7 +492,6 @@ def generate_ai_image_local (
         strength=strength,
         additional_parameters=additional_parameters
     )
-    image = generate_image(pipe, conditions, **generate_params)
     with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
         image.save(tmp.name, format="PNG")
         constants.temp_files.append(tmp.name)
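The surviving context shows the hand-off that stays unchanged: `NamedTemporaryFile(delete=False, ...)` keeps the PNG on disk after the `with` block exits (a reasonable assumption is that the UI still needs to read the path), while `constants.temp_files` records it for later cleanup. A self-contained sketch of the same pattern, with `temp_files` standing in for `constants.temp_files`:

```python
import os
from tempfile import NamedTemporaryFile
from PIL import Image

temp_files: list[str] = []  # stand-in for constants.temp_files

def save_for_ui(image: Image.Image) -> str:
    # delete=False keeps the file on disk after the context manager exits,
    # so callers can serve it; the path is tracked for explicit cleanup.
    with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        image.save(tmp.name, format="PNG")
        temp_files.append(tmp.name)
    return tmp.name

def cleanup_temp_files() -> None:
    for path in temp_files:
        try:
            os.remove(path)
        except OSError:
            pass
    temp_files.clear()
```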
 