jbilcke-hf HF staff commited on
Commit
be29b01
·
verified ·
1 Parent(s): 1ad507b

Upload 31 files

Browse files
hyvideo/config.py CHANGED
@@ -346,6 +346,12 @@ def add_inference_args(parser: argparse.ArgumentParser):
346
  help="Embeded classifier free guidance scale.",
347
  )
348
 
 
 
 
 
 
 
349
  group.add_argument(
350
  "--reproduce",
351
  action="store_true",
 
346
  help="Embeded classifier free guidance scale.",
347
  )
348
 
349
+ group.add_argument(
350
+ "--use-fp8",
351
+ action="store_true",
352
+ help="Enable use fp8 for inference acceleration."
353
+ )
354
+
355
  group.add_argument(
356
  "--reproduce",
357
  action="store_true",
hyvideo/inference.py CHANGED
@@ -15,6 +15,7 @@ from hyvideo.modules import load_model
15
  from hyvideo.text_encoder import TextEncoder
16
  from hyvideo.utils.data_utils import align_to
17
  from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
 
18
  from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
19
  from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
20
 
@@ -196,6 +197,8 @@ class Inference(object):
196
  out_channels=out_channels,
197
  factor_kwargs=factor_kwargs,
198
  )
 
 
199
  model = model.to(device)
200
  model = Inference.load_state_dict(args, model, pretrained_model_path)
201
  model.eval()
@@ -402,6 +405,8 @@ class HunyuanVideoSampler(Inference):
402
  )
403
 
404
  self.default_negative_prompt = NEGATIVE_PROMPT
 
 
405
 
406
  def load_diffusion_pipeline(
407
  self,
@@ -521,12 +526,6 @@ class HunyuanVideoSampler(Inference):
521
  num_images_per_prompt (int): The number of images per prompt. Default is 1.
522
  infer_steps (int): The number of inference steps. Default is 100.
523
  """
524
- if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
525
- assert seed is not None, \
526
- "You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."
527
-
528
- parallelize_transformer(self.pipeline)
529
-
530
  out_dict = dict()
531
 
532
  # ========================================================================
 
15
  from hyvideo.text_encoder import TextEncoder
16
  from hyvideo.utils.data_utils import align_to
17
  from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
18
+ from hyvideo.modules.fp8_optimization import convert_fp8_linear
19
  from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
20
  from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
21
 
 
197
  out_channels=out_channels,
198
  factor_kwargs=factor_kwargs,
199
  )
200
+ if args.use_fp8:
201
+ convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
202
  model = model.to(device)
203
  model = Inference.load_state_dict(args, model, pretrained_model_path)
204
  model.eval()
 
405
  )
406
 
407
  self.default_negative_prompt = NEGATIVE_PROMPT
408
+ if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
409
+ parallelize_transformer(self.pipeline)
410
 
411
  def load_diffusion_pipeline(
412
  self,
 
526
  num_images_per_prompt (int): The number of images per prompt. Default is 1.
527
  infer_steps (int): The number of inference steps. Default is 100.
528
  """
 
 
 
 
 
 
529
  out_dict = dict()
530
 
531
  # ========================================================================
hyvideo/modules/fp8_optimization.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
8
+ _bits = torch.tensor(bits)
9
+ _mantissa_bit = torch.tensor(mantissa_bit)
10
+ _sign_bits = torch.tensor(sign_bits)
11
+ M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
12
+ E = _bits - _sign_bits - M
13
+ bias = 2 ** (E - 1) - 1
14
+ mantissa = 1
15
+ for i in range(mantissa_bit - 1):
16
+ mantissa += 1 / (2 ** (i+1))
17
+ maxval = mantissa * 2 ** (2**E - 1 - bias)
18
+ return maxval
19
+
20
+ def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
21
+ """
22
+ Default is E4M3.
23
+ """
24
+ bits = torch.tensor(bits)
25
+ mantissa_bit = torch.tensor(mantissa_bit)
26
+ sign_bits = torch.tensor(sign_bits)
27
+ M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
28
+ E = bits - sign_bits - M
29
+ bias = 2 ** (E - 1) - 1
30
+ mantissa = 1
31
+ for i in range(mantissa_bit - 1):
32
+ mantissa += 1 / (2 ** (i+1))
33
+ maxval = mantissa * 2 ** (2**E - 1 - bias)
34
+ minval = - maxval
35
+ minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
36
+ input_clamp = torch.min(torch.max(x, minval), maxval)
37
+ log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
38
+ log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
39
+ # dequant
40
+ qdq_out = torch.round(input_clamp / log_scales) * log_scales
41
+ return qdq_out, log_scales
42
+
43
+ def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
44
+ for i in range(len(x.shape) - 1):
45
+ scale = scale.unsqueeze(-1)
46
+ new_x = x / scale
47
+ quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
48
+ return quant_dequant_x, scale, log_scales
49
+
50
+ def fp8_activation_dequant(qdq_out, scale, dtype):
51
+ qdq_out = qdq_out.type(dtype)
52
+ quant_dequant_x = qdq_out * scale.to(dtype)
53
+ return quant_dequant_x
54
+
55
+ def fp8_linear_forward(cls, original_dtype, input):
56
+ weight_dtype = cls.weight.dtype
57
+ #####
58
+ if cls.weight.dtype != torch.float8_e4m3fn:
59
+ maxval = get_fp_maxval()
60
+ scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
61
+ linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
62
+ linear_weight = linear_weight.to(torch.float8_e4m3fn)
63
+ weight_dtype = linear_weight.dtype
64
+ else:
65
+ scale = cls.fp8_scale.to(cls.weight.device)
66
+ linear_weight = cls.weight
67
+ #####
68
+
69
+ if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
70
+ if True or len(input.shape) == 3:
71
+ cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
72
+ if cls.bias != None:
73
+ output = F.linear(input, cls_dequant, cls.bias)
74
+ else:
75
+ output = F.linear(input, cls_dequant)
76
+ return output
77
+ else:
78
+ return cls.original_forward(input.to(original_dtype))
79
+ else:
80
+ return cls.original_forward(input)
81
+
82
+ def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
83
+ setattr(module, "fp8_matmul_enabled", True)
84
+
85
+ # loading fp8 mapping file
86
+ fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
87
+ if os.path.exists(fp8_map_path):
88
+ fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
89
+ else:
90
+ raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
91
+
92
+ fp8_layers = []
93
+ for key, layer in module.named_modules():
94
+ if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
95
+ fp8_layers.append(key)
96
+ original_forward = layer.forward
97
+ layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
98
+ setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
99
+ setattr(layer, "original_forward", original_forward)
100
+ setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
101
+
102
+