"""LoRA utilities for ViT-style models whose attention uses a fused `qkv` nn.Linear.

`add_lora_ab_to_fm` wraps every `.qkv` projection with zero-initialized low-rank
A/B branches; `absorb_lora_and_recover_net_structure` merges the trained branches
back into the original qkv weights so the network structure is unchanged afterwards.
"""
import torch
from torch import nn

from utils.dl.common.model import set_module
from .base import FMLoRA_Util, LoRA


class ToQKV_WrappedWithLoRA(nn.Module):
    """Wraps a fused qkv projection and adds one low-rank A/B branch per q/k/v chunk."""

    def __init__(self, qkv: nn.Linear, ab_r: int):
        super().__init__()

        self.qkv = qkv
        # one LoRA branch per q/k/v weight chunk; the rank is out_dim // ab_r
        self.abs = nn.ModuleList([self.create_ab_as_linear(w, ab_r) for w in qkv.weight.data.chunk(3, dim=0)])

    def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int):
        res = nn.Sequential(
            LoRA(fc_weight.size(1), fc_weight.size(0) // ab_r, bias=False),
            LoRA(fc_weight.size(0) // ab_r, fc_weight.size(0), bias=False)
        ).to(fc_weight.device)
        # standard LoRA init: A is Kaiming-uniform, B is zero, so the branch
        # initially contributes nothing and the wrapped output equals qkv's
        nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5)
        nn.init.zeros_(res[1].weight)
        return res

    def forward(self, x):
        x1 = self.qkv(x)
        # concatenate the three branch outputs so they align with qkv's fused output
        x2 = torch.cat([ab(x) for ab in self.abs], dim=-1)
        return x1 + x2
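
# The absorption step below relies on a simple identity: concatenating the three
# branch outputs equals multiplying by the row-stacked (B @ A) matrices. A minimal
# self-contained sketch of that identity, with plain nn.Linear standing in for the
# LoRA class from .base (an assumption made purely for illustration):
#
#     d, r = 8, 2
#     x = torch.rand(4, d)
#     branches = [nn.Sequential(nn.Linear(d, r, bias=False),
#                               nn.Linear(r, d, bias=False)) for _ in range(3)]
#     per_branch = torch.cat([ab(x) for ab in branches], dim=-1)               # forward path
#     merged = torch.cat([ab[1].weight @ ab[0].weight for ab in branches], dim=0)
#     assert torch.allclose(per_branch, x @ merged.T, atol=1e-5)               # absorbed path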
    

class FMLoRA_ViT_Util(FMLoRA_Util):

    @torch.no_grad()
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: torch.Tensor):
        """Wrap every fused `.qkv` projection in `fm` with zero-initialized LoRA branches."""
        fm.eval()
        o1 = fm(samples)

        for name, module in fm.named_modules():
            if not name.endswith('.qkv'):
                continue
            set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))

        o2 = fm(samples)

        # B is zero-initialized, so wrapping must not change the model's output
        output_diff = ((o1 - o2) ** 2).sum()
        assert output_diff < 1e-5, output_diff

        return fm
    
    @torch.no_grad()
    def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: torch.Tensor):
        """Merge each LoRA branch into its qkv weight and restore the original module structure."""
        fm.eval()
        o1 = fm(samples)

        for name, module in fm.named_modules():
            if not isinstance(module, ToQKV_WrappedWithLoRA):
                continue

            qkv = module.qkv
            fm_abs = module.abs

            # stack the per-branch products B @ A so they line up with the fused qkv weight,
            # then fold them in: W_qkv <- W_qkv + [B_q A_q; B_k A_k; B_v A_v]
            fm_abs_weight = torch.cat([_abs[1].weight @ _abs[0].weight for _abs in fm_abs], dim=0)
            qkv.weight.add_(fm_abs_weight)

            set_module(fm, name, qkv)

        o2 = fm(samples)

        # absorbing the branches must not change the model's output
        output_diff = ((o1 - o2) ** 2).sum()
        assert output_diff < 1e-5, output_diff

        return fm
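

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library API. It assumes a timm ViT
    # whose attention blocks expose a fused `qkv` nn.Linear (e.g. vit_base_patch16_224)
    # and that timm is installed; both are assumptions made only for this example.
    import timm

    vit = timm.create_model('vit_base_patch16_224', pretrained=False)
    samples = torch.rand(2, 3, 224, 224)

    util = FMLoRA_ViT_Util()
    vit = util.add_lora_ab_to_fm(vit, ab_r=8, samples=samples)       # inject LoRA branches
    # ... fine-tune only the LoRA parameters here ...
    vit = util.absorb_lora_and_recover_net_structure(vit, samples)   # merge them back into qkv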