import torch
from torch import nn
from copy import deepcopy
from .base import FM_to_MD_Util
from utils.common.log import logger
from utils.dl.common.model import set_module, get_module, get_super_module
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from typing import Optional, Tuple
from transformers.models.clip.modeling_clip import CLIPAttention
from transformers import CLIPVisionConfig
class CLIPAttentionPrunable(CLIPAttention):
    """Prunable drop-in replacement for HF `CLIPAttention` (multi-headed attention from
    'Attention Is All You Need'); its reshaping logic tolerates pruned head dimensions."""

    def __init__(self):
        # the ViT-B/16 vision config is hard-coded here; the actual weights are shared later
        # via `init_from_exist_self_attn()`
        config = CLIPVisionConfig.from_pretrained('openai/clip-vit-base-patch16')
        super().__init__(config)
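
    # An earlier `_shape`/`forward` implementation is kept below, commented out, for reference;
    # the active implementation further down infers head_dim at run time so that width-pruned
    # q/k/v projections still reshape correctly.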
# def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
# # print(tensor.size(), self.num_heads, self.head_dim, bsz) # torch.Size([1, 197, 192]) 8 64 1
# # head_dim should be modified
# # 'b n (h d) -> b h n d', h = self.num_heads
# if seq_len == -1:
# seq_len = tensor.size(1)
# # print(tensor.size(), bsz, seq_len, self.num_heads, -1)
# return tensor.view(bsz, seq_len, self.num_heads, -1).transpose(1, 2).contiguous()
# def forward(
# self,
# hidden_states: torch.Tensor,
# attention_mask: Optional[torch.Tensor] = None,
# causal_attention_mask: Optional[torch.Tensor] = None,
# output_attentions: Optional[bool] = False,
# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# """Input shape: Batch x Time x Channel"""
# bsz, tgt_len, embed_dim = hidden_states.size()
# # get query proj
# query_states = self.q_proj(hidden_states) * self.scale
# key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
# value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
# proj_shape = (-1, tgt_len, self.head_dim)
# # print(proj_shape, key_states.size())
# query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
# key_states = key_states.view(*proj_shape)
# value_states = value_states.view(*proj_shape)
# src_len = key_states.size(1)
# attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
# # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
# # raise ValueError(
# # f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
# # f" {attn_weights.size()}"
# # )
# # apply the causal_attention_mask first
# if causal_attention_mask is not None:
# if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
# raise ValueError(
# f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
# f" {causal_attention_mask.size()}"
# )
# attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
# attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# if attention_mask is not None:
# if attention_mask.size() != (bsz, 1, tgt_len, src_len):
# raise ValueError(
# f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
# )
# attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
# attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# attn_weights = nn.functional.softmax(attn_weights, dim=-1)
# if output_attentions:
# # this operation is a bit akward, but it's required to
# # make sure that attn_weights keeps its gradient.
# # In order to do so, attn_weights have to reshaped
# # twice and have to be reused in the following
# attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
# attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
# else:
# attn_weights_reshaped = None
# attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
# attn_output = torch.bmm(attn_probs, value_states)
# # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
# # raise ValueError(
# # f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
# # f" {attn_output.size()}"
# # )
# attn_output = attn_output.view(bsz, self.num_heads, tgt_len, -1)
# attn_output = attn_output.transpose(1, 2)
# attn_output = attn_output.reshape(bsz, tgt_len, -1)
# attn_output = self.out_proj(attn_output)
# return attn_output, attn_weights_reshaped
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # standard (b, n, h*d) -> (b, h, n, d) reshape with fixed num_heads and head_dim
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _shape_dynamic_head_dim(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # keep num_heads fixed and infer head_dim, so pruned per-head widths still reshape correctly
        return tensor.view(bsz, seq_len, self.num_heads, -1).transpose(1, 2).contiguous()

    def _shape_dynamic_num_head(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # keep head_dim fixed and infer num_heads, for pruning whole heads instead of channels
        return tensor.view(bsz, seq_len, -1, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
# logger.info(f'hidden state size: {hidden_states.size()}') # (64, 197, 768)
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape_dynamic_head_dim(self.k_proj(hidden_states), tgt_len, bsz)
value_states = self._shape_dynamic_head_dim(self.v_proj(hidden_states), tgt_len, bsz)
# (64, 197, 768), numhead: 12, head_dim: 64, seq_len: 197
# logger.info(f'key states: {self.k_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
# f'seq_len: {self.k_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')
# (64, 197, 768), numhead: 12, head_dim: 64, seq_len: 197
# logger.info(f'value states: {self.v_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
# f'seq_len: {self.v_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')
proj_shape = (bsz * self.num_heads, tgt_len, -1)
query_states = self._shape_dynamic_head_dim(query_states, tgt_len, bsz).view(*proj_shape)
# (64, 12, 197, 64), -1 means 197
# logger.info(f'query states: {self._shape(query_states, tgt_len, bsz).size()}, '
# f'-1 in proj_shape: {self._shape(query_states, tgt_len, bsz).numel() / bsz / self.num_heads / self.head_dim}')
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights has to be reshaped
            # twice and reused in the lines below
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
# if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
# raise ValueError(
# f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
# f" {attn_output.size()}"
# )
# print(attn_output.size(), bsz, tgt_len, embed_dim)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, -1)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
    # Alternative forward (kept for reference, commented out): prunes whole heads instead of
    # per-head channels by reshaping with `_shape_dynamic_num_head` (fixed head_dim, inferred
    # number of heads).
# def forward(
# self,
# hidden_states: torch.Tensor,
# attention_mask: Optional[torch.Tensor] = None,
# causal_attention_mask: Optional[torch.Tensor] = None,
# output_attentions: Optional[bool] = False,
# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# """Input shape: Batch x Time x Channel"""
# bsz, tgt_len, embed_dim = hidden_states.size()
# # logger.info(f'hidden state size: {hidden_states.size()}') # (64, 197, 768)
# # get query proj
# query_states = self.q_proj(hidden_states) * self.scale
# key_states = self._shape_dynamic_num_head(self.k_proj(hidden_states), tgt_len, bsz)
# value_states = self._shape_dynamic_num_head(self.v_proj(hidden_states), tgt_len, bsz)
# # (64, 197, 768), numhead: 12, head_dim: 64, seq_len: 197
# # logger.info(f'key states: {self.k_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
# # f'seq_len: {self.k_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')
# # (64, 197, 768), numhead: 12, head_dim: 64, seq_len: 197
# # logger.info(f'value states: {self.v_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
# # f'seq_len: {self.v_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')
# proj_shape = (-1, tgt_len, self.head_dim)
# query_states = self._shape_dynamic_head_dim(query_states, tgt_len, bsz).view(*proj_shape)
# # (64, 12, 197, 64), -1 means 197
# # logger.info(f'query states: {self._shape(query_states, tgt_len, bsz).size()}, '
# # f'-1 in proj_shape: {self._shape(query_states, tgt_len, bsz).numel() / bsz / self.num_heads / self.head_dim}')
# key_states = key_states.view(*proj_shape)
# value_states = value_states.view(*proj_shape)
# src_len = key_states.size(1)
# attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
# # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
# # raise ValueError(
# # f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
# # f" {attn_weights.size()}"
# # )
# # apply the causal_attention_mask first
# if causal_attention_mask is not None:
# if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
# raise ValueError(
# f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
# f" {causal_attention_mask.size()}"
# )
# attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
# attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# if attention_mask is not None:
# if attention_mask.size() != (bsz, 1, tgt_len, src_len):
# raise ValueError(
# f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
# )
# attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
# attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# attn_weights = nn.functional.softmax(attn_weights, dim=-1)
# if output_attentions:
# # this operation is a bit akward, but it's required to
# # make sure that attn_weights keeps its gradient.
# # In order to do so, attn_weights have to reshaped
# # twice and have to be reused in the following
# attn_weights_reshaped = attn_weights.view(bsz, -1, tgt_len, src_len)
# attn_weights = attn_weights_reshaped.view(-1, tgt_len, src_len)
# else:
# attn_weights_reshaped = None
# attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
# attn_output = torch.bmm(attn_probs, value_states)
# # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
# # raise ValueError(
# # f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
# # f" {attn_output.size()}"
# # )
# # print(attn_output.size(), bsz, tgt_len, embed_dim)
# attn_output = attn_output.view(bsz, -1, tgt_len, self.head_dim)
# attn_output = attn_output.transpose(1, 2)
# attn_output = attn_output.reshape(bsz, tgt_len, -1)
# attn_output = self.out_proj(attn_output)
# return attn_output, attn_weights_reshaped
    @staticmethod
    def init_from_exist_self_attn(attn: CLIPAttention):
        """Build a CLIPAttentionPrunable that shares every sub-module
        (q/k/v/out projections) of an existing CLIPAttention."""
        res = CLIPAttentionPrunable()

        for attr in dir(attn):
            if isinstance(getattr(attn, attr), nn.Module):
                try:
                    setattr(res, attr, getattr(attn, attr))
                except Exception as e:
                    logger.warning(f'failed to share sub-module {attr}: {e}')
        return res
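

# Illustrative sketch (not part of the conversion pipeline below): how CLIPAttentionPrunable
# can be swapped into a HF CLIP vision encoder while sharing the original weights. The
# argument is assumed to be a plain `transformers.CLIPModel`; the helper name is hypothetical.
def _example_swap_in_clip_attention_prunable(clip_model):
    for layer in clip_model.vision_model.encoder.layers:
        set_module(layer, 'self_attn',
                   CLIPAttentionPrunable.init_from_exist_self_attn(layer.self_attn))
    return clip_model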
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class PrunableAttention(nn.Module):
    """
    Fused-QKV multi-head attention adapted from https://github.com/lucidrains/vit-pytorch.
    The forward pass re-derives the inner dimension from the current layer shapes, so the
    qkv/proj linears can be width-pruned in place without breaking the reshapes.
    """
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qkv_bias = False):
super().__init__()
self.inner_dim = inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.num_heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.qkv = nn.Linear(dim, inner_dim * 3, bias = qkv_bias)
# self.proj = nn.Sequential(
# nn.Linear(inner_dim, dim),
# nn.Dropout(dropout)
# ) if project_out else nn.Identity()
self.proj = nn.Linear(inner_dim, dim) if project_out else nn.Identity()
self.proj_dropout = nn.Dropout(dropout)
def forward(self, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,):
x = hidden_states
assert attention_mask is None
assert causal_attention_mask is None
assert not output_attentions
        # a plain `.chunk(3, dim=-1)` assumes q/k/v have equal width; after pruning this may
        # not hold, so re-derive the q/k width from the (possibly pruned) layer shapes:
        # qkv.out_features = q_dim + k_dim + v_dim, with v_dim = proj.in_features and q_dim = k_dim
        raw_qkv = self.qkv(x)
        self.inner_dim = (raw_qkv.size(-1) - self.proj.in_features) // 2
        qkv = (raw_qkv[:, :, 0: self.inner_dim],
               raw_qkv[:, :, self.inner_dim: self.inner_dim * 2],
               raw_qkv[:, :, self.inner_dim * 2:])
# print('v', qkv[0].size(), qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size())
# raw_v = qkv[2]
# print('after_fbs_q, after_fbs_k', qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size(),
# qkv[1].sum((0, 1))[0: 10], qkv[1].sum((0, 1)).nonzero(as_tuple=True)[0].size(),)
# print('after_fbs_v', raw_v.size(), raw_v.sum((0, 1))[0: 10], raw_v.sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # split heads: (b, n, (h d)) -> (b, h, n, d); the per-head width d is inferred, so it
        # shrinks automatically when the qkv projection has been width-pruned
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# print('q, k, dots, after rearrage', q.size(), k.transpose(-1, -2).size(), dots.size())
attn = self.attend(dots)
# attn = dots
attn = self.dropout(attn)
# print(attn)
# print('attn', attn.size(), attn.sum((0, 1))[0: 10], attn.sum((0, 1)).nonzero(as_tuple=True)[0].size())
# print('attn', attn.size(), attn.sum((0, 1))[0: 10], attn.sum((0, 1)).nonzero(as_tuple=True)[0].size())
# print('v2', v.size())
out = torch.matmul(attn, v)
# print('out1', out.size())
# NOTE: just for trial debug
# out = v
# print('out before rerange', out.size())
# print(v.size(), v)
# exit()
out = rearrange(out, 'b h n d -> b n (h d)')
# print('out', out.size(), out.sum((0, 1))[0: 10], out.sum((0, 1)).nonzero(as_tuple=True)[0].size())
# exit()
res = self.proj_dropout(self.proj(out))
# res = self.proj_dropout(
# F.linear(self.proj.weight.T, out.T, self.proj.bias)
# )
# print(self.proj, self.proj_dropout)
# print('res', res.size(), res.sum((0, 1))[0: 10], res.sum((0, 1)).nonzero(as_tuple=True)[0].size())
return res, None
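

# Minimal shape smoke test for PrunableAttention (an illustrative sketch; the ViT-B sizes and
# the helper name below are assumptions, not part of the original code).
def _example_prunable_attention_smoke_test():
    attn = PrunableAttention(dim=768, heads=12, dim_head=64, dropout=0., qkv_bias=True)
    x = torch.rand(2, 197, 768)  # (batch, seq_len, embed_dim)
    out, _ = attn(x)
    assert out.shape == x.shape
    return out.shape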
class FM_to_MD_CLIP_Util(FM_to_MD_Util):
def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
fm_vit = deepcopy(fm)
# for block in fm_vit.model.text_model.encoder.layers:
# set_module(block, 'self_attn', CLIPAttentionPrunable.init_from_exist_self_attn(block.self_attn))
        # sanity-check input; NOTE: a CUDA device and a 32x32 input resolution are assumed here
        debug_input = torch.rand((1, 3, 32, 32)).cuda()
        fm.eval()
        o1 = fm.model.vision_model(debug_input).pooler_output
for block in fm_vit.model.vision_model.encoder.layers:
# set_module(block, 'self_attn', CLIPAttentionPrunable.init_from_exist_self_attn(block.self_attn))
attn: CLIPAttention = block.self_attn
            # replace HF CLIPAttention with the fused-QKV PrunableAttention defined above;
            # dim/heads/dim_head are those of the CLIP ViT-B vision tower
            new_attn = PrunableAttention(
                dim=768,
                heads=12,
                dim_head=64,
                dropout=0,
                qkv_bias=True
            )
new_attn.qkv.weight.data.copy_(torch.cat([
attn.q_proj.weight,
attn.k_proj.weight,
attn.v_proj.weight
], dim=0))
new_attn.qkv.bias.data.copy_(torch.cat([
attn.q_proj.bias,
attn.k_proj.bias,
attn.v_proj.bias
], dim=0))
new_attn.proj.weight.data.copy_(attn.out_proj.weight)
new_attn.proj.bias.data.copy_(attn.out_proj.bias)
set_module(block, 'self_attn', new_attn)
        # sanity check: the rewritten attention should reproduce the original outputs.
        # NOTE: the check must run on the modified copy `fm_vit`; evaluating `fm` here would
        # compare the unchanged model with itself and always give a zero diff.
        fm_vit.eval()
        o2 = fm_vit.model.vision_model(debug_input).pooler_output
        diff = ((o1 - o2) ** 2).sum()
        logger.info(f'diff before/after replacing self-attention: {diff}')
        assert diff < 1e-4
# print('\n\nDEBUG: WITHOUT ADDING CLIPAttentionPrunable\n\n')
# exit()
# return fm
        def _f(n):
            # target width after pruning: keep 1/reducing_width_ratio of the original channels
            return int(n // reducing_width_ratio)
# def _rand_indexes(n):
# return torch.randperm(n)[0: int(n // reducing_width_ratio)]
        def l1_max_indexes(p: torch.Tensor, dim=0):
            """Sorted indexes of the 1/reducing_width_ratio rows (dim=0) or columns (dim=1)
            of `p` with the largest L1 norm, i.e. the channels that survive pruning."""
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]

            if dim == 1:
                p = p.T
            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]
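
        # Worked example for l1_max_indexes (illustrative, assuming reducing_width_ratio == 2):
        # for p = torch.tensor([[1., 1.], [5., 5.], [0., 2.], [3., 3.]]) the row L1 norms are
        # [2, 10, 2, 6], so l1_max_indexes(p, 0) keeps the 4 // 2 = 2 strongest rows -> tensor([1, 3]).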
# first_attn = True
# for block_i, block in enumerate(fm_vit.model.text_model.encoder.layers):
# for k in ['k_proj', 'q_proj', 'v_proj']:
# qkv = get_module(block, f'self_attn.{k}')
# new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
# qkv.bias is not None, qkv.weight.device)
# indexes = l1_max_indexes(qkv.weight.data, 0)
# new_qkv.weight.data.copy_(qkv.weight.data[indexes])
# if qkv.bias is not None:
# new_qkv.bias.data.copy_(qkv.bias.data[indexes])
# set_module(block, f'self_attn.{k}', new_qkv)
# proj = block.self_attn.out_proj
# new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
# proj.bias is not None, proj.weight.device)
# new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
# if proj.bias is not None:
# new_proj.bias.data.copy_(proj.bias.data)
# set_module(block, f'self_attn.out_proj', new_proj)
# fc1 = block.mlp.fc1
# new_fc1 = nn.Linear(fc1.in_features, _f(fc1.out_features),
# fc1.bias is not None, fc1.weight.device)
# indexes = l1_max_indexes(fc1.weight.data, 0)
# new_fc1.weight.data.copy_(fc1.weight.data[indexes])
# if fc1.bias is not None:
# new_fc1.bias.data.copy_(fc1.bias.data[indexes])
# set_module(block, f'mlp.fc1', new_fc1)
# fc2 = block.mlp.fc2
# new_fc2 = nn.Linear(_f(fc2.in_features), fc2.out_features,
# fc2.bias is not None, fc2.weight.device)
# new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes(fc2.weight.data, 1)])
# if fc2.bias is not None:
# new_fc2.bias.data.copy_(fc2.bias.data)
# set_module(block, f'mlp.fc2', new_fc2)
for block_i, block in enumerate(fm_vit.model.vision_model.encoder.layers):
# for k in ['k_proj', 'q_proj', 'v_proj']:
# qkv = get_module(block, f'self_attn.{k}')
# new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
# qkv.bias is not None, qkv.weight.device)
# indexes = l1_max_indexes(qkv.weight.data, 0)
# new_qkv.weight.data.copy_(qkv.weight.data[indexes])
# if qkv.bias is not None:
# new_qkv.bias.data.copy_(qkv.bias.data[indexes])
# set_module(block, f'self_attn.{k}', new_qkv)
# proj = block.self_attn.out_proj
# new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
# proj.bias is not None, proj.weight.device)
# new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
# if proj.bias is not None:
# new_proj.bias.data.copy_(proj.bias.data)
# set_module(block, f'self_attn.out_proj', new_proj)
# ------------------
            # prune the fused qkv projection's output channels by L1 magnitude
            qkv = block.self_attn.qkv
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
qkv.bias is not None, qkv.weight.device)
indexes = l1_max_indexes(qkv.weight.data, 0)
new_qkv.weight.data.copy_(qkv.weight.data[indexes])
if qkv.bias is not None:
new_qkv.bias.data.copy_(qkv.bias.data[indexes])
set_module(block, f'self_attn.qkv', new_qkv)
            # prune the output projection's input channels by the L1 magnitude of its columns
            proj = block.self_attn.proj
new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
proj.bias is not None, proj.weight.device)
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
if proj.bias is not None:
new_proj.bias.data.copy_(proj.bias.data)
set_module(block, f'self_attn.proj', new_proj)
# --------------------
            # prune the MLP hidden width (fc1 output rows / fc2 input columns)
            fc1 = block.mlp.fc1
new_fc1 = nn.Linear(fc1.in_features, _f(fc1.out_features),
fc1.bias is not None, fc1.weight.device)
indexes = l1_max_indexes(fc1.weight.data, 0)
new_fc1.weight.data.copy_(fc1.weight.data[indexes])
if fc1.bias is not None:
new_fc1.bias.data.copy_(fc1.bias.data[indexes])
set_module(block, f'mlp.fc1', new_fc1)
fc2 = block.mlp.fc2
new_fc2 = nn.Linear(_f(fc2.in_features), fc2.out_features,
fc2.bias is not None, fc2.weight.device)
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes(fc2.weight.data, 1)])
if fc2.bias is not None:
new_fc2.bias.data.copy_(fc2.bias.data)
set_module(block, f'mlp.fc2', new_fc2)
return fm_vit
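
    # Worked example of the widths produced above (illustrative; reducing_width_ratio = 4 on the
    # ViT-B vision tower): self_attn.qkv 768->2304 becomes 768->576, self_attn.proj 768->768
    # becomes 192->768, mlp.fc1 768->3072 becomes 768->768, mlp.fc2 3072->768 becomes 768->768.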
def init_md_from_fm_by_reducing_width_with_perf_test(self, fm: nn.Module, reducing_width_ratio: int,
samples: torch.Tensor) -> nn.Module:
fm_size = get_model_size(fm, True)
fm_latency = self._get_model_latency(fm, samples, 20,
get_model_device(fm), 20, False)
master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio)
master_dnn_size = get_model_size(master_dnn, True)
logger.debug(f'inited master DNN: {master_dnn}')
# from utils.dl.common.model import get_module
# print('after generating')
# get_module(fm, 'head').debug()
# get_module(master_dnn, 'head').debug()
# print('test master latency')
master_dnn_latency = self._get_model_latency(master_dnn, samples, 20,
get_model_device(master_dnn), 20, False)
logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)')
logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> '
f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n'
f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, '
f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)')
return master_dnn
def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int,
device: str, warmup_sample_num: int, return_detail=False):
        import time

        if isinstance(model_input_size, tuple):
            dummy_input = torch.rand(model_input_size).to(device)
        else:
            dummy_input = model_input_size

        def _forward():
            # a bare tensor is passed positionally; anything else (e.g. a kwargs dict) is unpacked
            if isinstance(dummy_input, torch.Tensor):
                model(dummy_input)
            else:
                model(**dummy_input)

        model = model.to(device)
        model.eval()

        # warm up
        with torch.no_grad():
            for _ in range(warmup_sample_num):
                _forward()
infer_time_list = []
        if 'cuda' in str(device):
with torch.no_grad():
for _ in range(sample_num):
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
s.record()
                    _forward()
e.record()
torch.cuda.synchronize()
cur_model_infer_time = s.elapsed_time(e) / 1000.
infer_time_list += [cur_model_infer_time]
else:
with torch.no_grad():
for _ in range(sample_num):
start = time.time()
                    _forward()
cur_model_infer_time = time.time() - start
infer_time_list += [cur_model_infer_time]
avg_infer_time = sum(infer_time_list) / sample_num
if return_detail:
return avg_infer_time, infer_time_list
return avg_infer_time
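

# Illustrative latency-measurement sketch (an assumption-based example, not part of the original
# pipeline): `_get_model_latency` turns a tuple input size into a random tensor, so any module
# accepting a single tensor can be timed. Assumes FM_to_MD_CLIP_Util() needs no constructor args.
def _example_measure_latency():
    util = FM_to_MD_CLIP_Util()
    toy = nn.Sequential(nn.Linear(768, 768), nn.GELU(), nn.Linear(768, 768))
    # 20 timed runs after 5 warm-up runs on CPU, feeding random (1, 768) inputs
    return util._get_model_latency(toy, (1, 768), 20, 'cpu', 5, return_detail=False)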