import math

import torch
import torch.nn.functional as F
from torch import nn


def convert_linear_to_moe(
name: str,
config: dict,
layer_idx: int,
in_features: int,
out_features: int,
bias: bool = True,
show_debug: bool = False,
):
"""Converts nn.Linear to MoeLayer
Args:
name (str): Layer Name
config (dict): Composer config
layer_idx (int): Transformer block id.
in_features (int): Input features of Default nn.Linear layer.
out_features (int): Output features of Default nn.Linear layer.
bias (bool, optional): Defaults to True.
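
    Example:
        An illustrative sketch, not from the original code: ``cfg`` is a
        hypothetical config object exposing only the attributes read above
        (``from types import SimpleNamespace`` is assumed)::

            cfg = SimpleNamespace(
                router_layers_index=[0, 1],
                router_layers=["q_proj"],
                num_experts=4,
                num_experts_per_tok=2,
            )
            layer = convert_linear_to_moe(
                "q_proj", cfg, layer_idx=0, in_features=512, out_features=512
            )  # -> MoeLayer (cfg has no adapter_configs)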
"""
if (layer_idx in config.router_layers_index) and (name in config.router_layers):
if hasattr(config, "adapter_configs"):
return LoRAMoeLayer(
config=config,
in_features=in_features,
out_features=out_features,
bias=bias,
name=name,
layer_idx=layer_idx,
show_debug=show_debug
)
else:
return MoeLayer(
in_features=in_features,
out_features=out_features,
bias=bias,
num_experts=config.num_experts,
num_experts_per_tok=config.num_experts_per_tok,
)
return nn.Linear(in_features, out_features, bias=bias)


class MoeLayer(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
num_experts: int,
num_experts_per_tok: int = 2,
):
"""Mixture of Expert Layer
Args:
in_features (int): Input Features
out_features (int): Output Features
bias (bool): bias
num_experts (int): Total numbers of experts that Router Layer would handle
num_experts_per_tok (int, optional): Number of Active Experts per token(step). Defaults to 2.
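
        Example:
            An illustrative sketch (shapes are assumptions, not part of the API;
            forward expects a 3-D input)::

                layer = MoeLayer(in_features=512, out_features=512, bias=True,
                                 num_experts=4, num_experts_per_tok=2)
                x = torch.randn(2, 16, 512)   # (batch, seq_len, in_features)
                y = layer(x)                  # (2, 16, 512)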
"""
super().__init__()
self.gate = nn.Linear(in_features, num_experts, bias=False)
self.experts = nn.ModuleList(
[nn.Linear(in_features, out_features, bias) for _ in range(num_experts)]
)
self.num_experts_per_tok = num_experts_per_tok
self.in_features = in_features
self.out_features = out_features

    def forward(self, inputs: torch.Tensor):
        # inputs: (batch, seq_len, in_features)
        gate_logits = self.gate(inputs)  # (b, s, num_experts)
        weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok)
        # Renormalize the top-k gate scores per token.
        weights = F.softmax(weights, dim=2, dtype=torch.float).to(inputs.dtype)  # (b, s, k)
        results = torch.zeros(
            (inputs.shape[0], inputs.shape[1], self.out_features),
            device=inputs.device,
            dtype=inputs.dtype,
        )
        for ix, expert in enumerate(self.experts):
            # Gather the (batch, token) positions routed to expert `ix` and add its
            # output, weighted by the corresponding gate score.
            batch_idx, tok_idx, expert_idx = torch.where(selected_experts == ix)
            results[batch_idx, tok_idx] += expert(inputs[batch_idx, tok_idx]) * weights[
                batch_idx, tok_idx, expert_idx
            ].unsqueeze(-1)
        return results


class LoRAMoeLayer(nn.Module):
    def __init__(
        self, config, in_features, out_features, bias,
        name: str = "", layer_idx: int = -1, show_debug: bool = False,
    ) -> None:
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.in_features = in_features
        self.out_features = out_features
        self._name = name
        self._layer_idx = layer_idx
        self._show_debug = show_debug
self.r = {}
self.lora_alpha = {}
self.scaling = {}
self.use_dora = {}
self.lora_dropout = nn.ModuleDict({})
self.lora_A = nn.ModuleDict({})
self.lora_B = nn.ModuleDict({})
self.base_layer = nn.Linear(self.in_features, self.out_features, bias=bias)
        # Create one virtual (extra) expert so the layout matches BTX; no LoRA
        # adapter is attached to it, so tokens routed there keep only the base
        # layer output.
        self.num_experts = config.num_experts + 1
        # TODO FIXME: device placement (e.g. "mps:0") is not handled here.
        # (An earlier variant sized the gate from config.hidden_size instead of
        # in_features.)
        self.gate = nn.Linear(in_features, self.num_experts, bias=False)
        self.active_adapters = []
for ix, adapter_config in enumerate(self.config.adapter_configs):
self.update_layer(
adapter_name=str(ix),
r=adapter_config["r"],
lora_alpha=adapter_config["lora_alpha"],
lora_dropout=adapter_config["lora_dropout"],
init_lora_weights=adapter_config["init_lora_weights"],
use_rslora=adapter_config["use_rslora"],
use_dora=adapter_config["use_dora"],
)

    def forward(self, x, *args, **kwargs):
        """
        Drop-in replacement for the peft LoRA layers' ``.forward``: the base layer
        output is computed once, then the LoRA adapters selected by the gate are
        added on top of it per token.
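
        Expected input (an assumption based on how the gate and routing indices are
        used below): ``x`` of shape ``(batch, seq_len, in_features)``. Adapter ``ix``
        plays the role of expert ``ix``; the extra virtual expert created in
        ``__init__`` has no adapter, so tokens routed to it keep only the base layer
        output for that routing slot.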
"""
previous_dtype = x.dtype
        gate_logits = self.gate(x)  # (b, s, num_experts)
        weights, selected_experts = torch.topk(
            gate_logits, self.num_experts_per_tok
        )  # (b, s, k)
        if self._show_debug:
            print(f"{self._name}_{self._layer_idx}: {selected_experts}")
        weights = F.softmax(weights, dim=2, dtype=torch.float).to(
            previous_dtype
        )  # (b, s, k)
        result = self.base_layer(x, *args, **kwargs)
"""TODO MAYBE
- tensorize this loop add learnable weights here
- These are in my mind ( sigle embedding, each lora layer with a gate, lora gating loss similar to iclr )
"""
for ix, active_adapter in enumerate(self.active_adapters):
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype) # type: ignore
batch_idx, tok_idx, expert_idx = torch.where(selected_experts == ix)
            # Advanced indexing copies the selected rows (slicing would share the
            # same storage; compare tensor.storage().data_ptr() to check).
            x_adapter = x[batch_idx, tok_idx]
            x_adapter = (
                lora_B(lora_A(dropout(x_adapter))) * scaling
            )  # optionally also * self.config.global_scaling_weight
            # TODO: maybe a small trainable linear layer is needed here.
            result[batch_idx, tok_idx] += x_adapter * weights[
                batch_idx, tok_idx, expert_idx
            ].unsqueeze(-1)
            # TODO: apply F.silu here? Could a pretrained LoRA be adapted to that variation?
result = result.to(previous_dtype)
return result

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        use_rslora,
        use_dora: bool = False,
    ):
        """Register one LoRA adapter (A/B pair, dropout, scaling) under ``adapter_name``."""
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r
        if init_lora_weights == "loftq":
            # loftq_init is not ported into this class; fail with a clear error
            # rather than an AttributeError at runtime.
            raise NotImplementedError("LoftQ initialization is not supported here.")
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)
# check weight and qweight (for GPTQ)
for weight_name in ("weight", "qweight"):
weight = getattr(self.base_layer, weight_name, None)
if weight is not None:
# the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex:
self.to(weight.device, dtype=weight.dtype)
else:
self.to(weight.device)
break
        if use_dora:
            raise NotImplementedError("DoRA is not supported by this layer.")
        self.use_dora[adapter_name] = False
self.active_adapters.append(adapter_name)

    def reset_lora_parameters(self, adapter_name, init_lora_weights):
if init_lora_weights is False:
return
if adapter_name in self.lora_A.keys():
if init_lora_weights is True:
# initialize A the same way as the default for nn.Linear and B to zero
# https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
nn.init.kaiming_uniform_(
self.lora_A[adapter_name].weight, a=math.sqrt(5)
)
elif init_lora_weights.lower() == "gaussian":
nn.init.normal_(
self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]
)
else:
raise ValueError(f"Unknown initialization {init_lora_weights=}")
nn.init.zeros_(self.lora_B[adapter_name].weight)
if hasattr(self, "lora_embedding_A"):
if adapter_name in self.lora_embedding_A.keys():
                # For embedding adapters, initialize A to zeros and B from a normal
                # distribution (the opposite of the linear LoRA A/B convention above).
nn.init.zeros_(self.lora_embedding_A[adapter_name])
nn.init.normal_(self.lora_embedding_B[adapter_name])
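

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module).
    # It builds a plain MoeLayer and runs a forward pass with assumed shapes
    # (batch=2, seq_len=8, features=32) to check that top-k routing and the
    # weighted sum of expert outputs yield a tensor of the expected shape.
    moe = MoeLayer(
        in_features=32,
        out_features=32,
        bias=True,
        num_experts=4,
        num_experts_per_tok=2,
    )
    x = torch.randn(2, 8, 32)
    y = moe(x)
    assert y.shape == (2, 8, 32)
    print("MoeLayer forward OK:", tuple(y.shape))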