# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from ..utils import logging


if TYPE_CHECKING:
    # import here to avoid circular imports
    from ..models import UNet2DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _translate_into_actual_layer_name(name):
    """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
    if name == "mid":
        return "mid_block.attentions.0"

    updown, block, attn = name.split(".")

    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    block = block.replace("block_", "")
    attn = "attentions." + attn

    return ".".join((updown, block, attn))


def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
    """Expand each entry of `weight_scales` (one per adapter) into per-attention-layer scales for `unet`."""
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            unet.state_dict(),
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales
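
# Usage sketch (illustrative only; the checkpoint id and the scale dict are assumptions,
# not something this module prescribes). For an SDXL-style UNet, `blocks_with_transformer`
# is {"down": [1, 2], "up": [0, 1]} and `transformer_per_block` is {"down": 2, "up": 3}.
#
#     >>> from diffusers import UNet2DConditionModel
#     >>> unet = UNet2DConditionModel.from_pretrained(
#     ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
#     ... )
#     >>> _maybe_expand_lora_scales(unet, [{"down": 0.5, "mid": 1.0, "up": 1.0}])
#
# This returns one dict per adapter, mapping every attention layer name
# (e.g. "down_blocks.1.attentions.0") to its scale.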


def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
    state_dict: Dict,
    default_scale: float = 1.0,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.

    Parameters:
        scales (`Union[float, Dict]`):
            Scales dict to expand.
        blocks_with_transformer (`Dict[str, List[int]]`):
            Dict with keys 'up' and 'down', showing which blocks have transformer layers
        transformer_per_block (`Dict[str, int]`):
            Dict with keys 'up' and 'down', showing how many transformer layers each block has
        state_dict (`Dict`):
            State dict of the UNet, used to validate that every expanded layer name actually exists.
        default_scale (`float`, defaults to 1.0):
            Scale applied to blocks that are not listed in `scales`.

    E.g. turns
    ```python
    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    transformer_per_block = {"down": 2, "up": 3}
    ```
    into
    ```python
    {
        "down.block_1.0": 2,
        "down.block_1.1": 2,
        "down.block_2.0": 2,
        "down.block_2.1": 2,
        "mid": 3,
        "up.block_0.0": 4,
        "up.block_0.1": 4,
        "up.block_0.2": 4,
        "up.block_1.0": 5,
        "up.block_1.1": 6,
        "up.block_1.2": 7,
    }
    ```
    """
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down'` and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down'` and `'up'`")

    if not isinstance(scales, dict):
        # don't expand if scales is a single number
        return scales

    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = default_scale
    elif isinstance(scales["mid"], list):
        if len(scales["mid"]) == 1:
            scales["mid"] = scales["mid"][0]
        else:
            raise ValueError(f"Expected 1 scale for mid, got {len(scales['mid'])}.")

    for updown in ["up", "down"]:
        if updown not in scales:
            scales[updown] = default_scale

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}

        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            # set not assigned blocks to default scale
            if block not in scales[updown]:
                scales[updown][block] = default_scale
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
            elif len(scales[updown][block]) == 1:
                # expand a single-element list to one scale per transformer layer in the block
                scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
            elif len(scales[updown][block]) != transformer_per_block[updown]:
                raise ValueError(
                    f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
                )

        # eg {"down": {"block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            for tf_idx, value in enumerate(scales[updown][block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value

        del scales[updown]

    for layer in scales.keys():
        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )

    return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
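

# Worked example (illustrative; the toy `state_dict` keys below merely mimic the layer
# names a real UNet would contain and are an assumption, not part of this module):
#
#     >>> blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
#     >>> transformer_per_block = {"down": 2, "up": 3}
#     >>> toy_state_dict = {
#     ...     f"{prefix}.transformer_blocks.0.attn1.to_q.weight": None
#     ...     for prefix in (
#     ...         ["mid_block.attentions.0"]
#     ...         + [f"down_blocks.{i}.attentions.{j}" for i in (1, 2) for j in (0, 1)]
#     ...         + [f"up_blocks.{i}.attentions.{j}" for i in (0, 1) for j in (0, 1, 2)]
#     ...     )
#     ... }
#     >>> scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
#     >>> _maybe_expand_lora_scales_for_one_adapter(
#     ...     scales, blocks_with_transformer, transformer_per_block, toy_state_dict
#     ... )
#
# The call returns the flat mapping shown in the docstring above, with each
# user-friendly key translated to its actual layer name, e.g.
# "down.block_1.0" -> "down_blocks.1.attentions.0".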