# This code is adapted from https://github.com/JamesQFreeman/LoRA-ViT
import torch.nn as nn


class LoRA_qkv(nn.Module):
    """LoRA adapter for the fused qkv projection of a Vision Transformer.

    Low-rank updates are applied to the query and value projections only;
    the key projection is left untouched.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv  # original fused qkv projection: dim -> 3 * dim
        self.dim = qkv.in_features
        # Each LoRA branch is a down-projection (A) followed by an up-projection (B).
        self.q_lora = nn.Sequential(linear_a_q, linear_b_q)
        self.v_lora = nn.Sequential(linear_a_v, linear_b_v)

    def forward(self, x):
        qkv = self.qkv(x)  # (B, N, 3 * dim), laid out as [q | k | v]
        new_q = self.q_lora(x)
        new_v = self.v_lora(x)
        # Add the low-rank updates to the q and v slices of the fused output.
        qkv[:, :, : self.dim] += new_q
        qkv[:, :, -self.dim :] += new_v
        return qkv
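

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): wrap a fused qkv
# linear layer with LoRA adapters for q and v. The dimensions, rank, and the
# stand-in `qkv` layer below are illustrative assumptions; in practice the
# wrapped layer would be the attention qkv projection of an existing ViT
# (e.g. a timm-style `block.attn.qkv`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    dim, rank = 768, 4
    qkv = nn.Linear(dim, dim * 3, bias=True)  # stands in for the ViT qkv layer

    # Low-rank adapters: A projects down to `rank`, B projects back up to `dim`.
    linear_a_q = nn.Linear(dim, rank, bias=False)
    linear_b_q = nn.Linear(rank, dim, bias=False)
    linear_a_v = nn.Linear(dim, rank, bias=False)
    linear_b_v = nn.Linear(rank, dim, bias=False)
    # Zero-init the B matrices so the wrapped layer starts identical to the original.
    nn.init.zeros_(linear_b_q.weight)
    nn.init.zeros_(linear_b_v.weight)

    lora_qkv = LoRA_qkv(qkv, linear_a_q, linear_b_q, linear_a_v, linear_b_v)

    x = torch.randn(2, 197, dim)  # (batch, tokens, dim)
    out = lora_qkv(x)
    assert out.shape == (2, 197, dim * 3)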