Upload 31 files
Browse files- hyvideo/config.py +6 -0
- hyvideo/inference.py +5 -6
- hyvideo/modules/fp8_optimization.py +102 -0
hyvideo/config.py
CHANGED
@@ -346,6 +346,12 @@ def add_inference_args(parser: argparse.ArgumentParser):
|
|
346 |
help="Embeded classifier free guidance scale.",
|
347 |
)
|
348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
group.add_argument(
|
350 |
"--reproduce",
|
351 |
action="store_true",
|
|
|
346 |
help="Embeded classifier free guidance scale.",
|
347 |
)
|
348 |
|
349 |
+
group.add_argument(
|
350 |
+
"--use-fp8",
|
351 |
+
action="store_true",
|
352 |
+
help="Enable use fp8 for inference acceleration."
|
353 |
+
)
|
354 |
+
|
355 |
group.add_argument(
|
356 |
"--reproduce",
|
357 |
action="store_true",
|
hyvideo/inference.py
CHANGED
@@ -15,6 +15,7 @@ from hyvideo.modules import load_model
|
|
15 |
from hyvideo.text_encoder import TextEncoder
|
16 |
from hyvideo.utils.data_utils import align_to
|
17 |
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
|
|
|
18 |
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
|
19 |
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
|
20 |
|
@@ -196,6 +197,8 @@ class Inference(object):
|
|
196 |
out_channels=out_channels,
|
197 |
factor_kwargs=factor_kwargs,
|
198 |
)
|
|
|
|
|
199 |
model = model.to(device)
|
200 |
model = Inference.load_state_dict(args, model, pretrained_model_path)
|
201 |
model.eval()
|
@@ -402,6 +405,8 @@ class HunyuanVideoSampler(Inference):
|
|
402 |
)
|
403 |
|
404 |
self.default_negative_prompt = NEGATIVE_PROMPT
|
|
|
|
|
405 |
|
406 |
def load_diffusion_pipeline(
|
407 |
self,
|
@@ -521,12 +526,6 @@ class HunyuanVideoSampler(Inference):
|
|
521 |
num_images_per_prompt (int): The number of images per prompt. Default is 1.
|
522 |
infer_steps (int): The number of inference steps. Default is 100.
|
523 |
"""
|
524 |
-
if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
|
525 |
-
assert seed is not None, \
|
526 |
-
"You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."
|
527 |
-
|
528 |
-
parallelize_transformer(self.pipeline)
|
529 |
-
|
530 |
out_dict = dict()
|
531 |
|
532 |
# ========================================================================
|
|
|
15 |
from hyvideo.text_encoder import TextEncoder
|
16 |
from hyvideo.utils.data_utils import align_to
|
17 |
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
|
18 |
+
from hyvideo.modules.fp8_optimization import convert_fp8_linear
|
19 |
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
|
20 |
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
|
21 |
|
|
|
197 |
out_channels=out_channels,
|
198 |
factor_kwargs=factor_kwargs,
|
199 |
)
|
200 |
+
if args.use_fp8:
|
201 |
+
convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
|
202 |
model = model.to(device)
|
203 |
model = Inference.load_state_dict(args, model, pretrained_model_path)
|
204 |
model.eval()
|
|
|
405 |
)
|
406 |
|
407 |
self.default_negative_prompt = NEGATIVE_PROMPT
|
408 |
+
if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
|
409 |
+
parallelize_transformer(self.pipeline)
|
410 |
|
411 |
def load_diffusion_pipeline(
|
412 |
self,
|
|
|
526 |
num_images_per_prompt (int): The number of images per prompt. Default is 1.
|
527 |
infer_steps (int): The number of inference steps. Default is 100.
|
528 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
out_dict = dict()
|
530 |
|
531 |
# ========================================================================
|
hyvideo/modules/fp8_optimization.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
|
8 |
+
_bits = torch.tensor(bits)
|
9 |
+
_mantissa_bit = torch.tensor(mantissa_bit)
|
10 |
+
_sign_bits = torch.tensor(sign_bits)
|
11 |
+
M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
|
12 |
+
E = _bits - _sign_bits - M
|
13 |
+
bias = 2 ** (E - 1) - 1
|
14 |
+
mantissa = 1
|
15 |
+
for i in range(mantissa_bit - 1):
|
16 |
+
mantissa += 1 / (2 ** (i+1))
|
17 |
+
maxval = mantissa * 2 ** (2**E - 1 - bias)
|
18 |
+
return maxval
|
19 |
+
|
20 |
+
def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
|
21 |
+
"""
|
22 |
+
Default is E4M3.
|
23 |
+
"""
|
24 |
+
bits = torch.tensor(bits)
|
25 |
+
mantissa_bit = torch.tensor(mantissa_bit)
|
26 |
+
sign_bits = torch.tensor(sign_bits)
|
27 |
+
M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
|
28 |
+
E = bits - sign_bits - M
|
29 |
+
bias = 2 ** (E - 1) - 1
|
30 |
+
mantissa = 1
|
31 |
+
for i in range(mantissa_bit - 1):
|
32 |
+
mantissa += 1 / (2 ** (i+1))
|
33 |
+
maxval = mantissa * 2 ** (2**E - 1 - bias)
|
34 |
+
minval = - maxval
|
35 |
+
minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
|
36 |
+
input_clamp = torch.min(torch.max(x, minval), maxval)
|
37 |
+
log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
|
38 |
+
log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
|
39 |
+
# dequant
|
40 |
+
qdq_out = torch.round(input_clamp / log_scales) * log_scales
|
41 |
+
return qdq_out, log_scales
|
42 |
+
|
43 |
+
def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
|
44 |
+
for i in range(len(x.shape) - 1):
|
45 |
+
scale = scale.unsqueeze(-1)
|
46 |
+
new_x = x / scale
|
47 |
+
quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
|
48 |
+
return quant_dequant_x, scale, log_scales
|
49 |
+
|
50 |
+
def fp8_activation_dequant(qdq_out, scale, dtype):
|
51 |
+
qdq_out = qdq_out.type(dtype)
|
52 |
+
quant_dequant_x = qdq_out * scale.to(dtype)
|
53 |
+
return quant_dequant_x
|
54 |
+
|
55 |
+
def fp8_linear_forward(cls, original_dtype, input):
|
56 |
+
weight_dtype = cls.weight.dtype
|
57 |
+
#####
|
58 |
+
if cls.weight.dtype != torch.float8_e4m3fn:
|
59 |
+
maxval = get_fp_maxval()
|
60 |
+
scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
|
61 |
+
linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
|
62 |
+
linear_weight = linear_weight.to(torch.float8_e4m3fn)
|
63 |
+
weight_dtype = linear_weight.dtype
|
64 |
+
else:
|
65 |
+
scale = cls.fp8_scale.to(cls.weight.device)
|
66 |
+
linear_weight = cls.weight
|
67 |
+
#####
|
68 |
+
|
69 |
+
if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
|
70 |
+
if True or len(input.shape) == 3:
|
71 |
+
cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
|
72 |
+
if cls.bias != None:
|
73 |
+
output = F.linear(input, cls_dequant, cls.bias)
|
74 |
+
else:
|
75 |
+
output = F.linear(input, cls_dequant)
|
76 |
+
return output
|
77 |
+
else:
|
78 |
+
return cls.original_forward(input.to(original_dtype))
|
79 |
+
else:
|
80 |
+
return cls.original_forward(input)
|
81 |
+
|
82 |
+
def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
|
83 |
+
setattr(module, "fp8_matmul_enabled", True)
|
84 |
+
|
85 |
+
# loading fp8 mapping file
|
86 |
+
fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
|
87 |
+
if os.path.exists(fp8_map_path):
|
88 |
+
fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
|
89 |
+
else:
|
90 |
+
raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
|
91 |
+
|
92 |
+
fp8_layers = []
|
93 |
+
for key, layer in module.named_modules():
|
94 |
+
if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
|
95 |
+
fp8_layers.append(key)
|
96 |
+
original_forward = layer.forward
|
97 |
+
layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
|
98 |
+
setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
|
99 |
+
setattr(layer, "original_forward", original_forward)
|
100 |
+
setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
|
101 |
+
|
102 |
+
|