LINC-BIT's picture
Upload 1912 files
b84549f verified
import torch
from torch import nn
from abc import ABC, abstractmethod
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, set_module
from utils.common.log import logger
from .base import FMLoRA_Util, LoRA
class ToQKV_WrappedWithLoRA(nn.Module):
    """Wrap a linear projection with an additive low-rank (LoRA) branch.

    ``forward(x)`` returns ``fc(x) + ab(x)``, where ``ab`` is a two-layer
    bottleneck ``in_features -> out_features // ab_r -> out_features``.
    The second (up) projection is zero-initialised, so right after
    construction the wrapper computes the same output as ``fc`` alone.

    Args:
        fc: the linear layer to wrap; kept unchanged as ``self.fc``.
        ab_r: reduction factor, NOT the rank itself. The LoRA rank is
            ``fc.weight.size(0) // ab_r``.
    """

    def __init__(self, fc: nn.Linear, ab_r: int):
        super(ToQKV_WrappedWithLoRA, self).__init__()
        self.fc = fc
        self.ab = self.create_ab_as_linear(fc.weight.data, ab_r)

    def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int):
        """Build the LoRA down (``A``) and up (``B``) projections for ``fc_weight``.

        Args:
            fc_weight: weight of the wrapped ``nn.Linear``, laid out as
                ``(out_features, in_features)``.
            ab_r: reduction factor; the rank is ``out_features // ab_r``.
                NOTE: if ``ab_r > out_features`` the rank is 0 and the branch
                contributes nothing.

        Returns:
            ``nn.Sequential(A, B)`` on the same device and dtype as ``fc_weight``.
            ``A`` is Kaiming-uniform initialised and ``B`` is zero-initialised.

        Raises:
            ValueError: if ``ab_r`` is not positive.
        """
        if ab_r <= 0:
            raise ValueError(f'ab_r must be a positive integer, got {ab_r}')

        out_features, in_features = fc_weight.size(0), fc_weight.size(1)
        rank = out_features // ab_r
        res = nn.Sequential(
            LoRA(in_features, rank, bias=False),
            LoRA(rank, out_features, bias=False),
        # Match dtype as well as device. Otherwise fc(x) + ab(x) fails for
        # half/bfloat16 models, because the LoRA layers would stay float32.
        ).to(device=fc_weight.device, dtype=fc_weight.dtype)
        # Same init as the reference LoRA recipe: random A, zero B.
        # Zero B makes the wrapped output identical to fc's at insertion time.
        nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5)
        nn.init.zeros_(res[1].weight)
        return res

    def forward(self, x):
        """Return the base projection plus the low-rank update."""
        return self.fc(x) + self.ab(x)
class FMLoRA_CLIP_Util(FMLoRA_Util):
    """Insert and absorb LoRA adapters on the q/k/v projections of a CLIP-style model.

    Both public methods run ``fm`` on ``samples`` before and after the
    structural change. They check that the image and text logits did not
    drift, then return the same (in-place modified) ``fm``.
    """

    # Upper bound on the summed squared logit difference for each sanity check.
    # Adding a zero-initialised adapter leaves outputs unchanged in exact
    # arithmetic. Absorbing B @ A into the weight reorders floating-point ops.
    # Subclasses may override these values.
    add_lora_tolerance = 1e-5
    absorb_lora_tolerance = 1e-6

    @staticmethod
    def _move_samples_to_device(fm: nn.Module, samples: dict):
        """Move every tensor value of ``samples`` to ``fm``'s device, in place."""
        device = get_model_device(fm)
        for k, v in samples.items():
            if isinstance(v, torch.Tensor):
                samples[k] = v.to(device)
                logger.debug(f'moved sample input {k!r} to {device}')

    @staticmethod
    def _assert_clip_outputs_match(o1, o2, tol: float, stage: str):
        """Raise if two CLIP outputs' logits differ by ``tol`` or more.

        The difference is the sum of squared differences of
        ``logits_per_image`` plus that of ``logits_per_text``.
        """
        output_diff = ((o1.logits_per_image - o2.logits_per_image) ** 2).sum() \
            + ((o1.logits_per_text - o2.logits_per_text) ** 2).sum()
        # An explicit raise (not ``assert``) keeps the check under ``python -O``.
        # AssertionError is kept so existing callers see the same exception type.
        # ``not <`` also rejects NaN.
        if not output_diff < tol:
            raise AssertionError(
                f'{stage} changed the model outputs: diff={output_diff} (tol={tol})')

    @torch.no_grad()
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: dict):
        """Wrap every ``q_proj``/``k_proj``/``v_proj`` module of ``fm`` with LoRA.

        Args:
            fm: CLIP-style model. ``fm(**samples)`` must return an object with
                ``logits_per_image`` and ``logits_per_text``.
            ab_r: reduction factor forwarded to ``ToQKV_WrappedWithLoRA``
                (rank = out_features // ab_r).
            samples: model inputs used for the before/after check. NOTE: tensor
                values are moved to ``fm``'s device in place, so the caller's
                dict is modified.

        Returns:
            ``fm``, modified in place and left in eval mode.

        Raises:
            AssertionError: if the outputs change by ``add_lora_tolerance`` or more.
        """
        fm.eval()
        self._move_samples_to_device(fm, samples)
        o1 = fm(**samples)

        # Snapshot the module list, because set_module mutates the module tree.
        for name, module in list(fm.named_modules()):
            if not name.endswith(('k_proj', 'q_proj', 'v_proj')):
                continue
            # Skip modules that are already wrapped. They have no ``.weight``,
            # so wrapping them again would crash.
            if isinstance(module, ToQKV_WrappedWithLoRA):
                continue
            set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))

        o2 = fm(**samples)
        self._assert_clip_outputs_match(o1, o2, self.add_lora_tolerance, 'adding LoRA')
        return fm

    @torch.no_grad()
    def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict):
        """Fold each LoRA branch into its base linear layer and unwrap it.

        For every ``ToQKV_WrappedWithLoRA`` the update ``W <- W + B @ A`` is
        applied to ``fc.weight``, and the wrapper is replaced by ``fc``.

        Args:
            fm: model previously processed by ``add_lora_ab_to_fm``.
            samples: model inputs for the before/after check. Tensor values are
                moved to ``fm``'s device in place.

        Returns:
            ``fm``, modified in place and left in eval mode.

        Raises:
            AssertionError: if the outputs change by ``absorb_lora_tolerance`` or more.
        """
        fm.eval()
        self._move_samples_to_device(fm, samples)
        o1 = fm(**samples)

        # Snapshot the module list, because set_module mutates the module tree.
        for name, module in list(fm.named_modules()):
            if not isinstance(module, ToQKV_WrappedWithLoRA):
                continue
            fc, ab = module.fc, module.ab
            # ab(x) = x A^T B^T = x (B A)^T, so the merged weight is W + B @ A.
            fc.weight.add_(ab[1].weight @ ab[0].weight)
            set_module(fm, name, fc)

        o2 = fm(**samples)
        self._assert_clip_outputs_match(o1, o2, self.absorb_lora_tolerance, 'absorbing LoRA')
        return fm