# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
from huggingface_hub import model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version

from .. import __version__
from ..utils import (
    FLAX_WEIGHTS_NAME,
    ONNX_EXTERNAL_WEIGHTS_NAME,
    ONNX_WEIGHTS_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    get_class_from_dynamic_module,
    is_accelerate_available,
    is_peft_available,
    is_transformers_available,
    logging,
)
from ..utils.torch_utils import is_compiled_module


if is_transformers_available():
    import transformers
    from transformers import PreTrainedModel
    from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
    from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
    from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME

if is_accelerate_available():
    import accelerate
    from accelerate import dispatch_model
    from accelerate.hooks import remove_hook_from_module
    from accelerate.utils import compute_module_sizes, get_max_memory


INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
CONNECTED_PIPES_KEYS = ["prior"]

logger = logging.get_logger(__name__)

LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
    "onnxruntime.training": {
        "ORTModule": ["save_pretrained", "from_pretrained"],
    },
}

ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Checking for safetensors compatibility:
    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
      files to know which safetensors files are needed.
    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"
    """
    pt_filenames = []
    sf_filenames = set()

    passed_components = passed_components or []

    for filename in filenames:
        _, extension = os.path.splitext(filename)

        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
            continue

        if extension == ".bin":
            pt_filenames.append(os.path.normpath(filename))
        elif extension == ".safetensors":
            sf_filenames.add(os.path.normpath(filename))

    for filename in pt_filenames:
        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
        path, filename = os.path.split(filename)
        filename, extension = os.path.splitext(filename)

        if filename.startswith("pytorch_model"):
            filename = filename.replace("pytorch_model", "model")

        expected_sf_filename = os.path.normpath(os.path.join(path, filename))
        expected_sf_filename = f"{expected_sf_filename}.safetensors"
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
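
# Illustrative sketch (hypothetical filenames, not from any real repo): the check
# passes only when every ".bin" file has a safetensors twin under the conversion
# rules described in the docstring above:
#
#     filenames = [
#         "unet/diffusion_pytorch_model.bin",
#         "unet/diffusion_pytorch_model.safetensors",
#         "text_encoder/pytorch_model.bin",
#         "text_encoder/model.safetensors",  # transformers rename: pytorch_model -> model
#     ]
#     assert is_safetensors_compatible(filenames)
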
def variant_compatible_siblings(filenames, variant=None) -> Tuple[Set[str], Set[str]]:
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]

    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # pytorch_model, diffusion_pytorch_model, ...
    weight_prefixes = [w.split(".")[0] for w in weight_names]
    # bin, safetensors, ...
    weight_suffixes = [w.split(".")[-1] for w in weight_names]
    # -00001-of-00002
    transformers_index_format = r"\d{5}-of-\d{5}"

    if variant is not None:
        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
        variant_file_re = re.compile(
            rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixes)})$"
        )
        # `text_encoder/pytorch_model.bin.index.fp16.json`
        variant_index_re = re.compile(
            rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixes)})\.index\.{variant}\.json$"
        )

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(
        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixes)})$"
    )
    # `text_encoder/pytorch_model.bin.index.json`
    non_variant_index_re = re.compile(
        rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixes)})\.index\.json$"
    )

    if variant is not None:
        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
        variant_filenames = variant_weights | variant_indexes
    else:
        variant_filenames = set()

    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
    non_variant_filenames = non_variant_weights | non_variant_indexes

    # all variant filenames will be used by default
    usable_filenames = set(variant_filenames)

    def convert_to_variant(filename):
        if "index" in filename:
            variant_filename = filename.replace("index", f"index.{variant}")
        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
        else:
            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
        return variant_filename

    for f in non_variant_filenames:
        variant_filename = convert_to_variant(f)
        # only fall back to the non-variant file when no variant twin was found
        if variant_filename not in usable_filenames:
            usable_filenames.add(f)

    return usable_filenames, variant_filenames
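
# Illustrative sketch (hypothetical filenames): with variant="fp16", variant files
# are always kept, and a non-variant file survives only when it has no fp16 twin:
#
#     filenames = {
#         "unet/diffusion_pytorch_model.fp16.safetensors",
#         "unet/diffusion_pytorch_model.safetensors",
#         "vae/diffusion_pytorch_model.safetensors",
#     }
#     usable, variant_files = variant_compatible_siblings(filenames, variant="fp16")
#     # usable == {"unet/diffusion_pytorch_model.fp16.safetensors",
#     #            "vae/diffusion_pytorch_model.safetensors"}
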
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    info = model_info(
        pretrained_model_name_or_path,
        token=token,
        revision=None,
    )
    filenames = {sibling.rfilename for sibling in info.siblings}
    comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
    comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]

    if set(model_filenames).issubset(set(comp_model_filenames)):
        warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", | |
            FutureWarning,
        )
    else:
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
            FutureWarning,
        )
def _unwrap_model(model):
    """Unwraps a model."""
    if is_compiled_module(model):
        model = model._orig_mod

    if is_peft_available():
        from peft import PeftModel

        if isinstance(model, PeftModel):
            model = model.base_model.model

    return model
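
# Illustrative sketch: `torch.compile` wraps a module in a dynamo `OptimizedModule`
# that keeps the original on `_orig_mod`, so unwrapping recovers the plain module:
#
#     net = torch.nn.Linear(4, 4)
#     compiled = torch.compile(net)
#     assert _unwrap_model(compiled) is net
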
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """Simple helper method to raise or warn in case an incorrect module has been passed"""
    if not is_pipeline_module:
        library = importlib.import_module(library_name)
        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

        expected_class_obj = None
        for class_name, class_candidate in class_candidates.items():
            if class_candidate is not None and issubclass(class_obj, class_candidate):
                expected_class_obj = class_candidate

        # Dynamo wraps the original model in a private class.
        # I didn't find a public API to get the original class.
        sub_model = passed_class_obj[name]
        unwrapped_sub_model = _unwrap_model(sub_model)
        model_cls = unwrapped_sub_model.__class__

        if not issubclass(model_cls, expected_class_obj):
            raise ValueError(
                f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}"
            )
    else:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """Simple helper method to retrieve class object of module as well as potential parent class objects"""
    component_folder = os.path.join(cache_dir, component_name)

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)

        class_obj = getattr(pipeline_module, class_name)
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates
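
# Illustrative sketch (assumes diffusers is importable and no custom component
# file is present, so the plain-library branch runs):
#
#     class_obj, candidates = get_class_obj_and_candidates(
#         "diffusers",
#         "DDIMScheduler",
#         LOADABLE_CLASSES["diffusers"],
#         pipelines=None,
#         is_pipeline_module=False,
#         component_name="scheduler",
#         cache_dir=".",
#     )
#     # class_obj is diffusers.DDIMScheduler and candidates["SchedulerMixin"] is
#     # diffusers.SchedulerMixin, so the downstream issubclass checks succeed.
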
def _get_custom_pipeline_class(
    custom_pipeline,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    if custom_pipeline.endswith(".py"):
        path = Path(custom_pipeline)
        # decompose into folder & file
        file_name = path.name
        custom_pipeline = path.parent.absolute()
    elif repo_id is not None:
        file_name = f"{custom_pipeline}.py"
        custom_pipeline = repo_id
    else:
        file_name = CUSTOM_PIPELINE_FILE_NAME

    if repo_id is not None and hub_revision is not None:
        # if we load the pipeline code from the Hub
        # make sure to overwrite the `revision`
        revision = hub_revision

    return get_class_from_dynamic_module(
        custom_pipeline,
        module_file=file_name,
        class_name=class_name,
        cache_dir=cache_dir,
        revision=revision,
    )
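
# Illustrative sketch of the three resolution modes (hypothetical paths and repo
# names, assuming the referenced files exist):
#
#     # 1. local file: folder "." + file "my_pipeline.py"
#     _get_custom_pipeline_class("./my_pipeline.py")
#     # 2. named module inside a Hub repo: file "my_pipeline.py" from "user/repo"
#     _get_custom_pipeline_class("my_pipeline", repo_id="user/repo")
#     # 3. community pipeline repo: the default CUSTOM_PIPELINE_FILE_NAME
#     #    ("pipeline.py") is fetched from the repo named by `custom_pipeline`
#     _get_custom_pipeline_class("user/community-pipeline")
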
def _get_pipeline_class(
    class_obj,
    config=None,
    load_connected_pipeline=False,
    custom_pipeline=None,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    if custom_pipeline is not None:
        return _get_custom_pipeline_class(
            custom_pipeline,
            repo_id=repo_id,
            hub_revision=hub_revision,
            class_name=class_name,
            cache_dir=cache_dir,
            revision=revision,
        )

    if class_obj.__name__ != "DiffusionPipeline":
        return class_obj

    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
    class_name = class_name or config["_class_name"]
    if not class_name:
        raise ValueError(
            "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
        )

    class_name = class_name[4:] if class_name.startswith("Flax") else class_name
    pipeline_cls = getattr(diffusers_module, class_name)

    if load_connected_pipeline:
        from .auto_pipeline import _get_connected_pipeline

        connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
        if connected_pipeline_cls is not None:
            logger.info(
                f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
            )
        else:
            logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")

        pipeline_cls = connected_pipeline_cls or pipeline_cls

    return pipeline_cls
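
# Illustrative sketch: when the caller used the generic `DiffusionPipeline` entry
# point, the concrete class is resolved from the repo config, e.g. a
# model_index.json with `"_class_name": "StableDiffusionPipeline"`:
#
#     from diffusers import DiffusionPipeline
#     cls = _get_pipeline_class(
#         DiffusionPipeline, config={"_class_name": "StableDiffusionPipeline"}
#     )
#     # cls is diffusers.StableDiffusionPipeline (a "Flax" prefix, if present,
#     # would be stripped first)
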
def _load_empty_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    name: str,
    torch_dtype: Union[str, torch.dtype],
    cached_folder: Union[str, os.PathLike],
    **kwargs,
):
    # retrieve class objects.
    class_obj, _ = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    # Determine library.
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    model = None
    config_path = cached_folder
    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "pytorch",
    }

    if is_diffusers_model:
        # Load config and then the model on meta.
        config, unused_kwargs, commit_hash = class_obj.load_config(
            os.path.join(config_path, name),
            cache_dir=cached_folder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=kwargs.pop("force_download", False),
            resume_download=kwargs.pop("resume_download", None),
            proxies=kwargs.pop("proxies", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("token", None),
            revision=kwargs.pop("revision", None),
            subfolder=kwargs.pop("subfolder", None),
            user_agent=user_agent,
        )
        with accelerate.init_empty_weights():
            model = class_obj.from_config(config, **unused_kwargs)
    elif is_transformers_model:
        config_class = getattr(class_obj, "config_class", None)
        if config_class is None:
            raise ValueError("`config_class` cannot be None. Please double-check the model.")

        config = config_class.from_pretrained(
            cached_folder,
            subfolder=name,
            force_download=kwargs.pop("force_download", False),
            resume_download=kwargs.pop("resume_download", None),
            proxies=kwargs.pop("proxies", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("token", None),
            revision=kwargs.pop("revision", None),
            user_agent=user_agent,
        )
        with accelerate.init_empty_weights():
            model = class_obj(config)

    if model is not None:
        model = model.to(dtype=torch_dtype)
    return model
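
# Illustrative sketch (assumes a locally cached diffusers pipeline folder with a
# "unet" subfolder): the returned module lives on the meta device, so it holds no
# real weights and is only useful for size accounting:
#
#     unet = _load_empty_model(
#         library_name="diffusers",
#         class_name="UNet2DConditionModel",
#         importable_classes=ALL_IMPORTABLE_CLASSES,
#         pipelines=None,
#         is_pipeline_module=False,
#         name="unet",
#         torch_dtype=torch.float16,
#         cached_folder="./cached-pipeline",  # hypothetical local path
#     )
#     # all(p.device.type == "meta" for p in unet.parameters()) is True
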
def _assign_components_to_devices(
    module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
):
    device_ids = list(device_memory.keys())
    device_cycle = device_ids + device_ids[::-1]
    device_memory = device_memory.copy()

    device_id_component_mapping = {}
    current_device_index = 0
    for component in module_sizes:
        device_id = device_cycle[current_device_index % len(device_cycle)]
        component_memory = module_sizes[component]
        curr_device_memory = device_memory[device_id]

        # If the current component doesn't fit on the GPU, offload it to the CPU.
        if component_memory > curr_device_memory:
            device_id_component_mapping.setdefault("cpu", []).append(component)
        else:
            if device_id not in device_id_component_mapping:
                device_id_component_mapping[device_id] = [component]
            else:
                device_id_component_mapping[device_id].append(component)

            # Update the device memory.
            device_memory[device_id] -= component_memory
            current_device_index += 1

    return device_id_component_mapping
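
# Illustrative sketch with toy sizes (arbitrary units) and two GPUs: components
# arrive sorted by size and ping-pong across the device cycle [0, 1, 1, 0];
# anything that no longer fits is offloaded to the CPU:
#
#     mapping = _assign_components_to_devices(
#         {"unet": 6, "text_encoder": 3, "vae": 1},
#         {0: 8, 1: 4},
#     )
#     # mapping == {0: ["unet"], 1: ["text_encoder", "vae"]}
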
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
    # To avoid circular import problem.
    from diffusers import pipelines

    torch_dtype = kwargs.get("torch_dtype", torch.float32)

    # Load each module in the pipeline on a meta device so that we can derive the device map.
    init_empty_modules = {}
    for name, (library_name, class_name) in init_dict.items():
        if class_name.startswith("Flax"):
            raise ValueError("Flax pipelines are not supported with `device_map`.")

        # Define all importable classes
        is_pipeline_module = hasattr(pipelines, library_name)
        importable_classes = ALL_IMPORTABLE_CLASSES
        loaded_sub_model = None

        # Use passed sub model or load class_name from library_name
        if name in passed_class_obj:
            # if the model is in a pipeline module, then we load it from the pipeline
            # check that passed_class_obj has correct parent class
            maybe_raise_or_warn(
                library_name,
                library,
                class_name,
                importable_classes,
                passed_class_obj,
                name,
                is_pipeline_module,
            )
            with accelerate.init_empty_weights():
                loaded_sub_model = passed_class_obj[name]
        else:
            loaded_sub_model = _load_empty_model(
                library_name=library_name,
                class_name=class_name,
                importable_classes=importable_classes,
                pipelines=pipelines,
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                name=name,
                torch_dtype=torch_dtype,
                cached_folder=kwargs.get("cached_folder", None),
                force_download=kwargs.get("force_download", None),
                resume_download=kwargs.get("resume_download", None),
                proxies=kwargs.get("proxies", None),
                local_files_only=kwargs.get("local_files_only", None),
                token=kwargs.get("token", None),
                revision=kwargs.get("revision", None),
            )

        if loaded_sub_model is not None:
            init_empty_modules[name] = loaded_sub_model

    # determine device map
    # Obtain a sorted dictionary for mapping the model-level components
    # to their sizes.
    module_sizes = {
        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
        for module_name, module in init_empty_modules.items()
        if isinstance(module, torch.nn.Module)
    }
    module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))

    # Obtain maximum memory available per device (GPUs only).
    max_memory = get_max_memory(max_memory)
    max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
    max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}

    # Obtain a dictionary mapping the model-level components to the available
    # devices based on the maximum memory and the model sizes.
    final_device_map = None
    if len(max_memory) > 0:
        device_id_component_mapping = _assign_components_to_devices(
            module_sizes, max_memory, device_mapping_strategy=device_map
        )

        # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
        final_device_map = {}
        for device_id, components in device_id_component_mapping.items():
            for component in components:
                final_device_map[component] = device_id

    return final_device_map
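
# Illustrative sketch (hypothetical two-GPU host): this is the helper behind
# `DiffusionPipeline.from_pretrained(..., device_map="balanced")`. Its result is
# a component-level map such as
#
#     {"unet": 0, "text_encoder": 1, "vae": 1, "safety_checker": 0}
#
# which the pipeline loader then uses to place each component on its assigned
# device when calling `load_sub_model` below.
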
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """Helper method to load the module `name` from `library_name` and `class_name`"""
    # retrieve class candidates
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    load_method_name = None
    # retrieve load method name
    for class_name, class_candidate in class_candidates.items():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            load_method_name = importable_classes[class_name][1]

    # if load method name is None, then we have a dummy module -> raise Error
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if is_dummy_path and "dummy" in none_module:
            # call class_obj for nice error message of missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # add kwargs to loading method
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs["provider"] = provider
        loading_kwargs["sess_options"] = sess_options

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )
    # Unlike diffusers models, transformers models fully initialize their weights when
    # `device_map` is None. To make default loading faster we set the
    # `low_cpu_mem_usage=low_cpu_mem_usage` flag, which is `True` by default.
    # This skips weight initialization and significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        loading_kwargs["device_map"] = device_map
        loading_kwargs["max_memory"] = max_memory
        loading_kwargs["offload_folder"] = offload_folder
        loading_kwargs["offload_state_dict"] = offload_state_dict
        loading_kwargs["variant"] = model_variants.pop(name, None)

        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if (
            is_transformers_model
            and loading_kwargs["variant"] is not None
            and transformers_version < version.parse("4.27.0")
        ):
            raise ImportError(
                f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
            )
        elif is_transformers_model and loading_kwargs["variant"] is None:
            loading_kwargs.pop("variant")

        # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
        if not (from_flax and is_transformers_model):
            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        else:
            loading_kwargs["low_cpu_mem_usage"] = False

    # check if the module is in a subdirectory
    if os.path.isdir(os.path.join(cached_folder, name)):
        loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
    else:
        # else load from the root directory
        loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
        # remove hooks
        remove_hook_from_module(loaded_sub_model, recurse=True)
        needs_offloading_to_cpu = device_map[""] == "cpu"

        if needs_offloading_to_cpu:
            dispatch_model(
                loaded_sub_model,
                state_dict=loaded_sub_model.state_dict(),
                device_map=device_map,
                force_hooks=True,
                main_device=0,
            )
        else:
            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)

    return loaded_sub_model
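
# Illustrative sketch (assumes a locally cached Stable Diffusion checkpoint):
# this is roughly how the pipeline loader materializes the "unet" entry of
# `model_index.json`:
#
#     unet = load_sub_model(
#         library_name="diffusers",
#         class_name="UNet2DConditionModel",
#         importable_classes=ALL_IMPORTABLE_CLASSES,
#         pipelines=None,
#         is_pipeline_module=False,
#         pipeline_class=None,
#         torch_dtype=torch.float16,
#         provider=None,
#         sess_options=None,
#         device_map=None,
#         max_memory=None,
#         offload_folder=None,
#         offload_state_dict=False,
#         model_variants={},
#         name="unet",
#         from_flax=False,
#         variant=None,
#         low_cpu_mem_usage=True,
#         cached_folder="./stable-diffusion-cache",  # hypothetical local path
#     )
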
def _fetch_class_library_tuple(module):
    # import it here to avoid circular import
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(diffusers_module, "pipelines")

    # register the config from the original module, not the dynamo compiled one
    not_compiled_module = _unwrap_model(module)
    library = not_compiled_module.__module__.split(".")[0]

    # check if the module is a pipeline module
    module_path_items = not_compiled_module.__module__.split(".")
    pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
    is_pipeline_module = pipeline_dir in module_path_items and hasattr(pipelines, pipeline_dir)

    # if library is not in LOADABLE_CLASSES, then it is a custom module.
    # Or if it's a pipeline module, then the module is inside the pipeline
    # folder so we set the library to module name.
    if is_pipeline_module:
        library = pipeline_dir
    elif library not in LOADABLE_CLASSES:
        library = not_compiled_module.__module__

    # retrieve class_name
    class_name = not_compiled_module.__class__.__name__

    return (library, class_name)
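
# Illustrative sketch: the returned tuple is what pipeline registration writes
# into `model_index.json` for each component:
#
#     from diffusers import DDIMScheduler
#     _fetch_class_library_tuple(DDIMScheduler())  # ("diffusers", "DDIMScheduler")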