import torch
from torch import nn

from utils.dl.common.model import get_model_device, set_module
from utils.common.log import logger
from .base import FMLoRA_Util, LoRA


class ToQKV_WrappedWithLoRA(nn.Module):
    """Wraps a linear q/k/v projection with a parallel low-rank (LoRA) branch."""

    def __init__(self, fc: nn.Linear, ab_r: int):
        super(ToQKV_WrappedWithLoRA, self).__init__()

        self.fc = fc
        self.ab = self.create_ab_as_linear(fc.weight.data, ab_r)

    def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int):
        # Low-rank decomposition: A (in_features -> r) followed by B (r -> out_features),
        # where the rank r = out_features // ab_r.
        res = nn.Sequential(
            LoRA(fc_weight.size(1), fc_weight.size(0) // ab_r, bias=False),
            LoRA(fc_weight.size(0) // ab_r, fc_weight.size(0), bias=False)
        ).to(fc_weight.device)
        # Standard LoRA initialization: A is Kaiming-initialized, B starts at zero,
        # so the wrapped module initially computes exactly the same output as `fc`.
        nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5)
        nn.init.zeros_(res[1].weight)
        return res

    def forward(self, x):
        # fc(x) + B(A(x)): the original projection plus the low-rank update.
        x1 = self.fc(x)
        x2 = self.ab(x)
        return x1 + x2


class FMLoRA_CLIP_Util(FMLoRA_Util):

    @torch.no_grad()
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: dict):
        fm.eval()

        for k, v in samples.items():
            if isinstance(v, torch.Tensor):
                samples[k] = v.to(get_model_device(fm))
                logger.debug(f'move sample {k} to the device of fm')

        o1 = fm(**samples)

        # Wrap every q/k/v projection in the CLIP attention blocks with a LoRA branch.
        for name, module in fm.named_modules():
            if name.endswith(('k_proj', 'q_proj', 'v_proj')):
                set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))

        o2 = fm(**samples)

        # Since B is zero-initialized, the wrapped model must reproduce the original outputs.
        output_diff = ((o1.logits_per_image - o2.logits_per_image) ** 2).sum() + \
            ((o1.logits_per_text - o2.logits_per_text) ** 2).sum()
        assert output_diff < 1e-5, output_diff

        return fm

    @torch.no_grad()
    def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict):
        fm.eval()

        for k, v in samples.items():
            if isinstance(v, torch.Tensor):
                samples[k] = v.to(get_model_device(fm))
                logger.debug(f'move sample {k} to the device of fm')

        o1 = fm(**samples)

        for name, module in fm.named_modules():
            if not isinstance(module, ToQKV_WrappedWithLoRA):
                continue

            fc = module.fc
            ab = module.ab

            # Merge the low-rank update into the original weight (W <- W + B @ A),
            # then put the plain nn.Linear back so the network structure is recovered.
            fc.weight.add_(ab[1].weight @ ab[0].weight)

            set_module(fm, name, fc)

        o2 = fm(**samples)

        # Merging is mathematically exact, so outputs should match up to float error.
        output_diff = ((o1.logits_per_image - o2.logits_per_image) ** 2).sum() + \
            ((o1.logits_per_text - o2.logits_per_text) ** 2).sum()
        assert output_diff < 1e-6, output_diff

        return fm
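

# --- Usage sketch (illustrative; not part of the module above) ---
# A minimal end-to-end example of the intended inject -> fine-tune -> absorb
# workflow, assuming the HuggingFace `transformers` CLIP API. The checkpoint
# name, the sample text/image, and ab_r=8 are assumptions for illustration.
# (This module uses a relative import, so run it as part of its package.)
if __name__ == '__main__':
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
    samples = dict(processor(text=['a photo of a cat'],
                             images=Image.new('RGB', (224, 224)),
                             return_tensors='pt', padding=True))

    util = FMLoRA_CLIP_Util()

    # 1. Inject zero-initialized LoRA branches into every q/k/v projection.
    model = util.add_lora_ab_to_fm(model, ab_r=8, samples=samples)

    # 2. (Fine-tune only the LoRA parameters here; omitted in this sketch.)

    # 3. Merge the trained LoRA weights back and recover the original structure.
    model = util.absorb_lora_and_recover_net_structure(model, samples)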