AlekseyCalvin committed
Commit
ba7b5c8
1 Parent(s): 988167d

Update app.py

Files changed (1)
  1. app.py +12 -2
app.py CHANGED
@@ -12,19 +12,29 @@ import random
 import time
 from huggingface_hub import hf_hub_download
 from diffusers import FluxTransformer2DModel, FluxPipeline
+from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
 import safetensors.torch
 from safetensors.torch import load_file
 import gc
+from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
+from tea_model import TeaDecoder
+from text_encoder import t5_config, T5EncoderModel, PretrainedTextEncoder
 
 cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
 os.environ["TRANSFORMERS_CACHE"] = cache_path
 os.environ["HF_HUB_CACHE"] = cache_path
 os.environ["HF_HOME"] = cache_path
 
-
 torch.backends.cuda.matmul.allow_tf32 = True
 
-pipe = FluxPipeline.from_pretrained("John6666/fastflux-unchained-t5f16-fp8-flux", torch_dtype=torch.bfloat16)
+class Flux2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+
+if __name__ == '__main__':
+    t5 = PretrainedTextEncoder(t5_config, T5EncoderModel(t5_config)).to(dtype=torch.float16)
+    t5.load_model('text_encoder_2.safetensors')
+
+pipe = FluxPipeline.from_pretrained("John6666/fastflux-unchained-t5f16-fp8-flux", torch_dtype=torch.bfloat16, text_encoder_2=t5)
 pipe.to(device="cuda", dtype=torch.bfloat16)
 
 # Load LoRAs from JSON file
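
For context on the Flux2DModel wrapper introduced above: this follows the optimum-quanto pattern for reloading a pre-quantized diffusers model, in which a QuantizedDiffusersModel subclass names the underlying diffusers class via base_class and then restores the quantized weights with from_pretrained. A minimal sketch of that pattern, assuming a local directory of weights previously quantized and saved with optimum-quanto (the "./flux-fp8-transformer" path is hypothetical, not from this commit):

import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto.models import QuantizedDiffusersModel

class Flux2DModel(QuantizedDiffusersModel):
    # Tells optimum-quanto which diffusers class the quantized
    # checkpoint should be deserialized into.
    base_class = FluxTransformer2DModel

if __name__ == '__main__':
    # Reload fp8 weights that were quantized and saved earlier with
    # optimum-quanto; "./flux-fp8-transformer" is a hypothetical path.
    transformer = Flux2DModel.from_pretrained("./flux-fp8-transformer")

The restored transformer can then stand in for the full-precision module in the pipeline, which is what makes fp8 checkpoints like the one this Space loads practical on limited VRAM.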
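The other change, passing text_encoder_2=t5 into FluxPipeline.from_pretrained, depends on repo-local modules (tea_model, text_encoder) that are not part of this diff. The underlying diffusers mechanism is standard, though: a compatible module passed as a component kwarg overrides the one bundled in the checkpoint. A sketch of the same idea using the stock transformers T5 encoder (model ID, subfolder, and dtypes here are illustrative assumptions, not taken from this commit):

import torch
from transformers import T5EncoderModel
from diffusers import FluxPipeline

# Load Flux's secondary (T5) text encoder separately, e.g. in fp16.
text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)

# Passing the module as a kwarg makes the pipeline use it instead of
# loading text_encoder_2 from the checkpoint itself.
pipe = FluxPipeline.from_pretrained(
    "John6666/fastflux-unchained-t5f16-fp8-flux",
    torch_dtype=torch.bfloat16,
    text_encoder_2=text_encoder_2,
)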