mrfakename's picture
Super-squash branch 'main' using huggingface_hub
0102e16 verified
raw
history blame
13.1 kB
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List
class LoRALayer:
    """Mixin that stores the LoRA hyper-parameters shared by the adapted layers.

    Args:
        r: rank of the low-rank update (the layers in this file only create
            LoRA factors when ``r > 0``).
        lora_alpha: numerator of the ``lora_alpha / r`` scaling factor.
        lora_dropout: dropout probability for the input of the LoRA branch;
            any value that is not strictly positive installs a pass-through
            function instead of an ``nn.Dropout``.
        merge_weights: whether the update may be folded into the frozen weight.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Only build a real dropout module for a positive probability; the
        # identity callable keeps call sites uniform otherwise.
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else (lambda x: x)
        )
        # The low-rank update starts out separate from (not added into) the weight.
        self.merged = False
        self.merge_weights = merge_weights
class Embedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` with a trainable low-rank (LoRA) update.

    With ``r > 0`` the pre-trained ``weight`` of shape
    ``(num_embeddings, embedding_dim)`` is frozen. The effective table is
    ``weight + (lora_B @ lora_A).T * (lora_alpha / r)``, where ``lora_A`` is
    ``(r, num_embeddings)`` and ``lora_B`` is ``(embedding_dim, r)``.

    When ``merge_weights`` is True, ``train(False)`` adds the update into
    ``weight`` and ``train(True)`` subtracts it again. ``eval()``, including a
    parent module's ``eval()``, reaches this through ``train(False)``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=0,  # no dropout on the embedding LoRA branch
            merge_weights=merge_weights,
        )
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base table and, if present, the LoRA factors."""
        nn.Embedding.reset_parameters(self)
        # nn.Embedding.__init__ calls this before lora_A exists, hence the check.
        if hasattr(self, "lora_A"):
            # A is zero-initialised and B drawn from a normal distribution, so
            # the initial update B @ A is zero.
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def _delta_weight(self) -> torch.Tensor:
        """Return the scaled update laid out like ``weight`` (num_embeddings, embedding_dim)."""
        # (embedding_dim, r) @ (r, num_embeddings) -> transpose to match weight.
        return (self.lora_B @ self.lora_A).T * self.scaling

    def train(self, mode: bool = True):
        """Set the training mode and merge or unmerge the LoRA update to match.

        ``nn.Module.eval()`` is ``train(False)``, and parent modules propagate
        mode changes by calling ``train(mode)`` on children. The merge logic
        therefore lives here, so it also runs for ``model.eval()``.
        """
        nn.Embedding.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                # Back to training: remove the update so the separate LoRA
                # branch in forward() is used (and trained) again.
                if self.r > 0:
                    self.weight.data -= self._delta_weight()
                self.merged = False
            elif not mode and not self.merged:
                # Inference: fold the update into the table once.
                if self.r > 0:
                    self.weight.data += self._delta_weight()
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Look up ``x`` in the base table, adding the LoRA term while unmerged."""
        result = nn.Embedding.forward(self, x)
        if self.r > 0 and not self.merged:
            # Embedding lookup into A.T gives the rank-r codes for each index.
            after_A = F.embedding(
                x,
                self.lora_A.T,
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            result += (after_A @ self.lora_B.T) * self.scaling
        return result
class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` with a trainable low-rank (LoRA) update.

    With ``r > 0`` the pre-trained ``weight`` is frozen. The layer computes
    ``F.linear(x, W, bias) + dropout(x) @ lora_A.T @ lora_B.T * (lora_alpha / r)``,
    where ``lora_A`` is ``(r, in_features)``, ``lora_B`` is ``(out_features, r)``,
    and ``W`` is ``weight`` (transposed back when ``fan_in_fan_out``).

    Args:
        fan_in_fan_out: set to True if the layer being replaced stores its
            weight as ``(fan_in, fan_out)``. ``weight`` is then kept in that
            layout and transposed on use.
        merge_weights: if True, ``train(False)`` adds the update into
            ``weight`` and ``train(True)`` subtracts it again. ``eval()``,
            including a parent module's ``eval()``, goes through
            ``train(False)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            # Keep the weight in (in_features, out_features) layout.
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialise the base layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the check.
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial update B @ A is zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _delta_weight(self) -> torch.Tensor:
        """Return the scaled update ``B @ A`` in the same layout as ``weight``."""
        delta = self.lora_B @ self.lora_A  # (out_features, in_features)
        if self.fan_in_fan_out:
            delta = delta.T
        return delta * self.scaling

    def train(self, mode: bool = True):
        """Set the training mode and merge or unmerge the LoRA update to match.

        ``nn.Module.eval()`` is ``train(False)``, and parent modules propagate
        mode changes by calling ``train(mode)`` on children. The merge logic
        therefore lives here, so it also runs for ``model.eval()``.
        """
        nn.Linear.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                # Back to training: remove the update so the separate LoRA
                # branch in forward() is used (and trained) again.
                if self.r > 0:
                    self.weight.data -= self._delta_weight()
                self.merged = False
            elif not mode and not self.merged:
                # Inference: fold the update into the weight once.
                if self.r > 0:
                    self.weight.data += self._delta_weight()
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base layer, adding the LoRA branch while unmerged."""
        weight = self.weight.T if self.fan_in_fan_out else self.weight
        result = F.linear(x, weight, bias=self.bias)
        if self.r > 0 and not self.merged:
            result += (
                self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
            ) * self.scaling
        return result
class MergedLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` with separate LoRA updates on groups of output features.

    The ``out_features`` outputs are split into ``len(enable_lora)`` equal,
    contiguous groups. Each group whose flag in ``enable_lora`` is True gets
    its own rank-``r`` update. With ``n_enabled = sum(enable_lora)``, the
    enabled groups share:

    - one ``lora_A`` of shape ``(r * n_enabled, in_features)``;
    - one ``lora_B`` of shape ``(out_features // len(enable_lora) * n_enabled, r)``.

    These are applied as a grouped 1-D convolution, and ``zero_pad`` scatters
    the result back to the full output width (disabled groups get zeros).

    Args:
        fan_in_fan_out: set to True if the replaced layer stores its weight as
            ``(fan_in, fan_out)``. ``weight`` is then kept in that layout.
        merge_weights: if True, ``train(False)`` adds the update into
            ``weight`` and ``train(True)`` subtracts it again. ``eval()``,
            including a parent module's ``eval()``, goes through
            ``train(False)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )
        assert (
            out_features % len(enable_lora) == 0
        ), "The length of enable_lora must divide out_features"
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros(
                    (out_features // len(enable_lora) * sum(enable_lora), r)
                )
            )  # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over output features: True for features in enabled groups.
            self.lora_ind = self.weight.new_zeros(
                (out_features,), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            # Keep the weight in (in_features, out_features) layout.
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialise the base layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the check.
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial update is zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        """Scatter ``x`` into a zero tensor whose last dim is ``out_features``.

        The last dim of ``x`` must equal the number of features in enabled
        groups. Its values land at the columns selected by ``lora_ind``.
        """
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features))

    def _has_lora(self) -> bool:
        """True when LoRA factors exist (they are only created in this case)."""
        return self.r > 0 and any(self.enable_lora)

    def _delta_weight(self) -> torch.Tensor:
        """Return the full scaled update in the same layout as ``weight``."""
        # (1, r * n_enabled, in) conv (out_enabled, r, 1), groups=n_enabled
        # -> (out_enabled, in_features): one B_g @ A_g block per enabled group.
        delta_w = F.conv1d(
            self.lora_A.data.unsqueeze(0),
            self.lora_B.data.unsqueeze(-1),
            groups=sum(self.enable_lora),
        ).squeeze(0)
        # zero_pad scatters along the LAST dim, so move the output dim last,
        # pad to out_features, then restore (out_features, in_features).
        delta_w = self.zero_pad(delta_w.T).T
        if self.fan_in_fan_out:
            delta_w = delta_w.T
        return delta_w * self.scaling

    def train(self, mode: bool = True):
        """Set the training mode and merge or unmerge the LoRA update to match.

        ``nn.Module.eval()`` is ``train(False)``, and parent modules propagate
        mode changes by calling ``train(mode)`` on children. The merge logic
        therefore lives here, so it also runs for ``model.eval()``.
        """
        nn.Linear.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                # Back to training: remove the update from the frozen weight.
                if self._has_lora():
                    self.weight.data -= self._delta_weight()
                self.merged = False
            elif not mode and not self.merged:
                # Inference: fold the update into the weight once.
                if self._has_lora():
                    self.weight.data += self._delta_weight()
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base layer, adding the grouped LoRA branch while unmerged."""
        weight = self.weight.T if self.fan_in_fan_out else self.weight
        result = F.linear(x, weight, bias=self.bias)
        # lora_A / lora_B / scaling / lora_ind only exist when _has_lora().
        if not self.merged and self._has_lora():
            after_A = F.linear(self.lora_dropout(x), self.lora_A)
            # conv1d wants channels in dim -2: move the r * n_enabled features
            # there for the grouped projection, then move them back.
            after_B = F.conv1d(
                after_A.transpose(-2, -1),
                self.lora_B.unsqueeze(-1),
                groups=sum(self.enable_lora),
            ).transpose(-2, -1)
            result += self.zero_pad(after_B) * self.scaling
        return result
class Conv2d(nn.Conv2d, LoRALayer):
    """``nn.Conv2d`` with a trainable low-rank (LoRA) update on its kernel.

    ``kernel_size`` must be a plain ``int`` (k). With ``r > 0`` the pre-trained
    kernel is frozen. ``lora_B @ lora_A`` is reshaped to the kernel shape,
    scaled by ``lora_alpha / r`` and added to the kernel. The factor shapes
    are ``lora_B``: ``(out_channels * k, r * k)`` and ``lora_A``:
    ``(r * k, in_channels * k)``.
    NOTE(review): that reshape only matches the kernel when ``groups == 1``.

    When ``merge_weights`` is True, ``train(False)`` adds the update into
    ``weight`` and ``train(True)`` subtracts it again. ``eval()``, including a
    parent module's ``eval()``, goes through ``train(False)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )
        assert type(kernel_size) is int
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base convolution and, if present, the LoRA factors."""
        nn.Conv2d.reset_parameters(self)
        # nn.Conv2d.__init__ calls this before lora_A exists, hence the check.
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial update is zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _delta_weight(self) -> torch.Tensor:
        """Return the scaled update reshaped to the kernel's shape."""
        return (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling

    def train(self, mode: bool = True):
        """Set the training mode and merge or unmerge the LoRA update to match.

        ``nn.Module.eval()`` is ``train(False)``, and parent modules propagate
        mode changes by calling ``train(mode)`` on children. The merge logic
        therefore lives here, so it also runs for ``model.eval()``.
        """
        nn.Conv2d.train(self, mode)
        if self.merge_weights:
            if mode and self.merged:
                # Back to training: remove the update from the frozen kernel.
                # lora_A / lora_B only exist when r > 0.
                if self.r > 0:
                    self.weight.data -= self._delta_weight()
                self.merged = False
            elif not mode and not self.merged:
                # Inference: fold the update into the kernel once.
                if self.r > 0:
                    self.weight.data += self._delta_weight()
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Convolve with the kernel, adding the LoRA update while unmerged."""
        if self.r > 0 and not self.merged:
            # NOTE(review): this path passes self.padding straight to
            # F.conv2d, so a padding_mode other than 'zeros' is presumably
            # not honoured here, unlike nn.Conv2d.forward. Confirm.
            return F.conv2d(
                x,
                self.weight + self._delta_weight(),
                self.bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
            )
        return nn.Conv2d.forward(self, x)