File size: 779 Bytes
c209d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# this is code adapted from https://github.com/JamesQFreeman/LoRA-ViT

import torch.nn as nn

class LoRA_qkv(nn.Module):
    """Wrap a fused qkv projection and add low-rank updates to two slices.

    The wrapped ``qkv`` layer is applied to the input unchanged. Two
    additional branches, each the composition ``linear_b(linear_a(x))``, are
    computed from the same input and added onto the projection's output:

    * the ``q`` branch is added to the first ``dim`` features of the output;
    * the ``v`` branch is added to the last ``dim`` features of the output.

    Any features between those two slices are passed through untouched.
    ``dim`` is taken from ``qkv.in_features``.

    NOTE(review): this presumes the wrapped layer emits ``[q | k | v]`` along
    the last axis with each part ``dim`` wide (the usual fused ViT layout) --
    confirm against the model being patched.

    Args:
        qkv: The original projection; must expose ``in_features``.
        linear_a_q: First (down) layer of the branch added to the q slice.
        linear_b_q: Second (up) layer of the branch added to the q slice;
            its output must have ``dim`` features.
        linear_a_v: First (down) layer of the branch added to the v slice.
        linear_b_v: Second (up) layer of the branch added to the v slice;
            its output must have ``dim`` features.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv
        # Width of one slice of the fused output (q at the front, v at the back).
        self.dim = qkv.in_features
        # Attribute names are kept stable so existing state_dict keys
        # ("q_lora.0.weight", ...) continue to load.
        self.q_lora = nn.Sequential(linear_a_q, linear_b_q)
        self.v_lora = nn.Sequential(linear_a_v, linear_b_v)

    def forward(self, x):
        """Return ``qkv(x)`` with the q and v low-rank updates added.

        Args:
            x: Input tensor whose last axis is the feature axis. Any number
                of leading axes is accepted, e.g. ``(N, dim)``,
                ``(B, N, dim)`` or ``(B, H, W, dim)``.

        Returns:
            The tensor produced by ``self.qkv(x)``, updated in place on its
            first and last ``dim`` features.
        """
        qkv = self.qkv(x)
        # ``...`` targets the last axis regardless of how many leading axes the
        # input has; for 3-D input this is identical to ``[:, :, ...]``.
        qkv[..., : self.dim] += self.q_lora(x)
        qkv[..., -self.dim :] += self.v_lora(x)
        return qkv