# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import functools

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
)

def get_size_policy(min_params=1e8):
    """Return a size-based FSDP auto-wrap policy that wraps any submodule
    holding at least `min_params` parameters into its own FSDP unit."""
    num_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )
    return num_wrap_policy
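
# Illustrative usage sketch (not part of the original file): with the default
# min_params=1e8, roughly only submodules with >= 100M parameters become their
# own FSDP unit; smaller submodules are folded into their parent's unit.
#
#     size_policy = get_size_policy(min_params=1e8)
#     # pass `size_policy` as FSDP's `auto_wrap_policy` argument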

def get_llama_wrapper():
    """Return an FSDP transformer auto-wrap policy registered on the model's
    main layer class (LlamaDecoderLayer).

    This ensures that each FSDP unit maps to one transformer layer and that
    the embedding layers stay in the root FSDP unit, where they remain
    available for shared access.
    """
    # use the transformer wrapping policy, keyed on LlamaDecoderLayer
    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )
    return llama_auto_wrap_policy
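

# --- Usage sketch (added for illustration; not part of the original module) ---
def _example_wrap_with_fsdp(model):
    """Minimal sketch of how a policy returned above would typically be passed
    to FSDP. Assumes torch.distributed is already initialized (e.g. via
    torchrun) and that `model` is a loaded Llama model; both the function and
    its argument are illustrative assumptions, not defined in this file.
    """
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Under this policy, each LlamaDecoderLayer becomes its own FSDP unit;
    # swap in get_size_policy() to shard by parameter count instead.
    wrapping_policy = get_llama_wrapper()
    return FSDP(model, auto_wrap_policy=wrapping_policy)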