Spaces:
Running
on
Zero
Running
on
Zero
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union | |
from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available | |
class PeftAdapterMixin:
    """
    A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
    more details about adapters and injecting them in a transformer-based model, check out the PEFT
    [documentation](https://huggingface.co/docs/peft/index).

    Install the latest version of PEFT, and use this mixin to:

    - Attach new adapters in the model.
    - Attach multiple adapters and iteratively activate/deactivate them.
    - Activate/deactivate all adapters from the model.
    - Get a list of the active adapters.
    """

    # Class-level default. `add_adapter` shadows it with an instance attribute set to True once a valid config has
    # been accepted; every other public method refuses to run while it is still False.
    _hf_peft_config_loaded = False

    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
        r"""
        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
        to the adapter to follow the convention of the PEFT library.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_config (`[~peft.PeftConfig]`):
                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
                methods. Note that its `base_model_name_or_path` attribute is overwritten with `None`.
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.

        Raises:
            ImportError: If PEFT is not installed.
            ValueError: If `adapter_config` is not a `PeftConfig` instance, or if an adapter named `adapter_name`
                has already been added to this model.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        from peft import PeftConfig, inject_adapter_in_model

        # Validate the config before mutating any state. Once `_hf_peft_config_loaded` is True, this method and
        # `set_adapter` read `self.peft_config`, so a rejected config must not be allowed to flip the flag.
        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        elif adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
        adapter_config.base_model_name_or_path = None
        inject_adapter_in_model(adapter_config, self, adapter_name)
        # Newly added adapters become the (only) active adapter.
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
        """
        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_name (`Union[str, List[str]]`):
                The list of adapters to set or the adapter name in the case of a single adapter.

        Raises:
            ValueError: If no adapter has been added yet, if any requested name is not a loaded adapter, if several
                adapters are requested but the installed PEFT layers cannot activate more than one, or if the model
                contains no PEFT tuner layers.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        # Normalise to a list so single- and multi-adapter calls share one code path.
        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]

        missing = set(adapter_name) - set(self.peft_config)
        if len(missing) > 0:
            # Sorted so the error message is deterministic (set iteration order is not).
            raise ValueError(
                f"Following adapter(s) could not be found: {', '.join(sorted(missing))}. Make sure you are passing the correct adapter name(s)."
                f" current loaded adapters are: {list(self.peft_config.keys())}"
            )

        from peft.tuners.tuners_utils import BaseTunerLayer

        _adapters_has_been_set = False

        for _, module in self.named_modules():
            if not isinstance(module, BaseTunerLayer):
                continue

            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_name)
            elif len(adapter_name) != 1:
                # Previous versions of PEFT does not support multi-adapter inference
                raise ValueError(
                    "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
                    " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
                )
            else:
                # Layers without `set_adapter` only support a single active adapter (see branch above), so store
                # that one name rather than a one-element list.
                # NOTE(review): assumes such legacy layers compare `active_adapter` against a str key; confirm
                # against the oldest PEFT version still reachable past `check_peft_version`.
                module.active_adapter = adapter_name[0]
            _adapters_has_been_set = True

        if not _adapters_has_been_set:
            raise ValueError(
                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
            )

    def disable_adapters(self) -> None:
        r"""
        Disable all adapters attached to the model and fallback to inference with the base model only.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Raises:
            ValueError: If no adapter has been added yet.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=False)
                else:
                    # support for older PEFT versions
                    module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
        adapters to enable.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Raises:
            ValueError: If no adapter has been added yet.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=True)
                else:
                    # support for older PEFT versions
                    module.disable_adapters = False

    def active_adapters(self) -> List[str]:
        """
        Gets the current list of active adapters of the model.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Returns:
            `List[str]`: The active adapter names, read from the first PEFT tuner layer found in the model. A single
            adapter name is wrapped in a one-element list. An empty list is returned if the model contains no tuner
            layers.

        Raises:
            ImportError: If PEFT is not installed.
            ValueError: If no adapter has been added yet.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                active_adapter = module.active_adapter
                # A layer may report a bare adapter name instead of a list; normalise so the return value always
                # matches the `List[str]` contract.
                if isinstance(active_adapter, str):
                    return [active_adapter]
                return active_adapter

        # No tuner layer found: return an explicit empty list rather than an implicit `None`.
        return []