File size: 5,279 Bytes
890b6a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import os,sys
import math
from pretrain.track.model import build_track_model
import torch.nn as nn
class Downstream_cage_model(nn.Module):
    """Pretrained backbone followed by a small MLP head with one output per position.

    The input is passed through ``pretrain_model``, ``crop`` positions are
    trimmed from each end of dimension 1, and a two-layer MLP
    (``embed_dim -> 128 -> 1``) is applied to the last dimension of the
    remaining positions.

    Args:
        pretrain_model: Module whose output is indexed as ``[:, start:stop, :]``
            and whose last dimension must equal ``embed_dim``.
            Presumably the output is (batch, bins, embed_dim) -- TODO confirm
            against ``build_track_model``.
        embed_dim: Size of the last dimension fed to the MLP head.
        crop: Number of positions removed from *each* end of dimension 1
            before the head is applied. ``0`` keeps every position.
    """

    def __init__(self, pretrain_model, embed_dim, crop):
        super().__init__()
        # Registration order (mlp first, then pretrain_model) is kept so that
        # parameter ordering and state_dict keys remain unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.pretrain_model = pretrain_model
        self.crop = crop

    def forward(self, x):
        """Run the backbone, crop dimension 1 on both sides, apply the MLP head.

        Args:
            x: Input forwarded unchanged to ``pretrain_model``.

        Returns:
            Tensor whose last dimension has size 1 and whose dimension 1 is
            shortened by ``2 * crop`` relative to the backbone output.
        """
        x = self.pretrain_model(x)
        # ``-0 == 0``, so a plain ``x[:, crop:-crop]`` would yield an EMPTY
        # slice when crop is 0. Use an open-ended stop in that case so that
        # "no cropping" really keeps the full sequence.
        stop = -self.crop if self.crop else None
        out = self.mlp(x[:, self.crop:stop, :])
        return out
def build_cage_model(args):
    """Assemble the downstream CAGE model from a configuration object.

    Args:
        args: Configuration object. It is passed as-is to
            ``build_track_model`` and must additionally expose the
            attributes ``embed_dim`` and ``crop``.

    Returns:
        A ``Downstream_cage_model`` wrapping the module returned by
        ``build_track_model(args)``.
    """
    # Keyword arguments are evaluated left to right, so the backbone is
    # built before ``args.embed_dim`` / ``args.crop`` are read.
    return Downstream_cage_model(
        pretrain_model=build_track_model(args),
        embed_dim=args.embed_dim,
        crop=args.crop,
    )
# import os,sys
# # import inspect
# # currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# # parentdir = os.path.dirname(currentdir)
# # sys.path.insert(0, parentdir)
# from pretrain.track.layers import AttentionPool,Enformer,CNN
# from pretrain.track.transformers import Transformer
# from einops.layers.torch import Rearrange
# from einops import rearrange
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
#
#
# class Convblock(nn.Module):
# def __init__(self,in_channel,kernel_size,dilate_size,dropout=0.1):
# super().__init__()
# self.conv=nn.Sequential(
# nn.Conv2d(
# in_channel, in_channel,
# kernel_size, padding=self.pad(kernel_size, dilate_size),
# dilation=dilate_size),
# nn.GroupNorm(16, in_channel),
# nn.Dropout(dropout)
# )
# def pad(self,kernelsize, dialte_size):
# return (kernelsize - 1) * dialte_size // 2
# def symmetric(self,x):
# return (x + x.permute(0,1,3,2)) / 2
# def forward(self,x):
# identity=x
# out=self.conv(x)
# x=out+identity
# x=self.symmetric(x)
# return F.relu(x)
#
# class dilated_tower(nn.Module):
# def __init__(self,embed_dim,in_channel=48,kernel_size=9,dilate_rate=4):
# super().__init__()
# dilate_convs=[]
# for i in range(dilate_rate+1):
# dilate_convs.append(
# Convblock(in_channel,kernel_size=kernel_size,dilate_size=2**i))
#
# self.cnn=nn.Sequential(
# Rearrange('b l n d -> b d l n'),
# nn.Conv2d(embed_dim, in_channel, kernel_size=1),
# *dilate_convs,
# nn.Conv2d(in_channel, in_channel, kernel_size=1),
# Rearrange('b d l n -> b l n d'),
# )
# def forward(self,x,crop):
# x=self.cnn(x)
# x=x[:,crop:-crop,crop:-crop,:]
# return x
#
#
# class Tranmodel(nn.Module):
# def __init__(self, backbone, transfomer):
# super().__init__()
# self.backbone = backbone
# self.transformer = transfomer
# hidden_dim = transfomer.d_model
# self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
# def forward(self, input):
# input=rearrange(input,'b n c l -> (b n) c l')
# src = self.backbone(input)
# src=self.input_proj(src)
# src = self.transformer(src)
# return src
#
# class finetunemodel(nn.Module):
# def __init__(self, pretrain_model, hidden_dim, embed_dim, bins, crop=25):
# super().__init__()
# self.pretrain_model = pretrain_model
# self.bins = bins
# self.crop = crop
# self.attention_pool = AttentionPool(hidden_dim)
# self.project = nn.Sequential(
# Rearrange('(b n) c -> b c n', n=bins),
# nn.Conv1d(hidden_dim, hidden_dim, kernel_size=9, padding=4, groups=hidden_dim),
# nn.InstanceNorm1d(hidden_dim, affine=True),
# nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
# nn.ReLU(inplace=True),
# nn.Dropout(0.2)
# )
# self.transformer = Enformer(dim=embed_dim, depth=4, heads=6)
# self.prediction_head = nn.Sequential(
# nn.Linear(embed_dim, 1)
# )
#
#
# def forward(self, x):
# # x = rearrange(x, 'b n c l -> (b n) c l')
# x = self.pretrain_model(x)
# x = self.attention_pool(x)
# x = self.project(x)
# x = rearrange(x, 'b c n -> b n c')
# x = self.transformer(x)
# x = self.prediction_head(x[:, self.crop:-self.crop, :])
# return x
#
# def build_backbone():
# model = CNN()
# return model
# def build_transformer(args):
# return Transformer(
# d_model=args.hidden_dim,
# dropout=args.dropout,
# nhead=args.nheads,
# dim_feedforward=args.dim_feedforward,
# num_encoder_layers=args.enc_layers,
# num_decoder_layers=args.dec_layers
# )
# def build_cage_model(args):
# backbone = build_backbone()
# transformer = build_transformer(args)
# pretrain_model = Tranmodel(
# backbone=backbone,
# transfomer=transformer,
# )
#
# model=finetunemodel(pretrain_model,hidden_dim=args.hidden_dim,embed_dim=args.embed_dim,
# bins=args.bins,crop=args.crop)
# return model |