import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(input_dim, rank) * std_dev)  # Low-rank matrix A (random init, scaled by 1/sqrt(rank))
        self.B = nn.Parameter(torch.zeros(rank, output_dim))  # Low-rank matrix B (zero init, so the LoRA update starts at zero)
        self.alpha = alpha  # Scaling factor

    def forward(self, x):
        # Return only the low-rank update alpha * (x @ A @ B);
        # the original layer's output is added back in LinearWithLoRA.
        return self.alpha * (x @ self.A @ self.B)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear_layer, rank, alpha):
        super().__init__()
        self.linear = linear_layer  # Original linear layer
        self.lora = LoRALayer(linear_layer.in_features, linear_layer.out_features, rank, alpha)  # LoRA layer

    def forward(self, x):
        # Combine original linear layer output with LoRA adaptation
        return self.linear(x) + self.lora(x)
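

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of wrapping an existing nn.Linear with LinearWithLoRA.
# The dimensions, rank, and alpha below are arbitrary values chosen for demonstration.
if __name__ == "__main__":
    base = nn.Linear(128, 64)          # pretrained layer to adapt
    base.weight.requires_grad = False  # freeze the original weights
    base.bias.requires_grad = False
    lora_linear = LinearWithLoRA(base, rank=8, alpha=16)

    x = torch.randn(4, 128)            # dummy input batch
    out = lora_linear(x)
    print(out.shape)                   # torch.Size([4, 64])
    # Because B is zero-initialized, out equals base(x) before any training step.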