ideityfy / lora.py
Yegiiii's picture
Upload 5 files
c209d46 verified
raw
history blame contribute delete
779 Bytes
# This code is adapted from https://github.com/JamesQFreeman/LoRA-ViT
import torch.nn as nn
class LoRA_qkv(nn.Module):
    """LoRA qkv module for Vision Transformer.

    Wraps an existing qkv projection and adds the outputs of two extra
    branches to it: ``linear_b_q(linear_a_q(x))`` is added to the first
    ``dim`` features of the projection output and
    ``linear_b_v(linear_a_v(x))`` to the last ``dim`` features, where
    ``dim = qkv.in_features``. Features in between are returned as the
    wrapped projection produced them.

    NOTE(review): this presumes the projection output's last axis is laid
    out as ``[q | k | v]`` with each part ``dim`` wide -- confirm against
    the attention block that consumes it.

    Args:
        qkv: The projection being wrapped. Must expose ``in_features``.
        linear_a_q: First layer of the branch added to the leading slice.
        linear_b_q: Second layer of that branch; its output must be
            broadcastable to ``(..., dim)``.
        linear_a_v: First layer of the branch added to the trailing slice.
        linear_b_v: Second layer of that branch; its output must be
            broadcastable to ``(..., dim)``.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ) -> None:
        super().__init__()
        self.qkv = qkv
        # Width of one slice of the projection output.
        self.dim = qkv.in_features
        # Attribute names are kept as-is so existing state-dict keys
        # (``q_lora.0``, ``q_lora.1``, ...) still load.
        self.q_lora = nn.Sequential(linear_a_q, linear_b_q)
        self.v_lora = nn.Sequential(linear_a_v, linear_b_v)

    def forward(self, x):
        """Return ``self.qkv(x)`` with the two branch outputs added in.

        The slices are taken on the last axis with ``...`` so that ``x``
        may carry any number of leading dimensions, e.g. ``(N, C)``,
        ``(B, N, C)`` or ``(B, H, W, C)``.

        Args:
            x: Input tensor whose last axis has ``self.dim`` features.

        Returns:
            The tensor produced by ``self.qkv(x)``, updated in place.
        """
        qkv = self.qkv(x)
        # In-place slice updates avoid allocating a second full-size tensor.
        qkv[..., : self.dim] += self.q_lora(x)
        qkv[..., -self.dim :] += self.v_lora(x)
        return qkv