import torch.nn as nn

from pretrain.track.model import build_track_model


class Downstream_cage_model(nn.Module):
    """Downstream head that fine-tunes a pretrained track model for CAGE prediction."""

    def __init__(self, pretrain_model, embed_dim, crop):
        super().__init__()
        self.pretrain_model = pretrain_model
        self.crop = crop
        # Per-bin MLP mapping each embedding to a scalar CAGE signal.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = self.pretrain_model(x)
        # Drop `crop` bins on each side to discard edge effects before predicting.
        out = self.mlp(x[:, self.crop:-self.crop, :])
        return out


def build_cage_model(args):
    pretrain_model = build_track_model(args)
    model = Downstream_cage_model(
        pretrain_model=pretrain_model,
        embed_dim=args.embed_dim,
        crop=args.crop
    )
    return model
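
# Minimal shape-check sketch (illustrative assumption, not part of the repo's
# training pipeline): it relies only on the interface used above, namely that
# the pretrained model maps (batch, bins, embed_dim) -> (batch, bins, embed_dim).
# nn.Identity() stands in for the real track model and the sizes are hypothetical.
if __name__ == '__main__':
    import torch

    demo = Downstream_cage_model(pretrain_model=nn.Identity(), embed_dim=128, crop=10)
    x = torch.randn(2, 100, 128)  # (batch, bins, embed_dim)
    out = demo(x)
    print(out.shape)  # torch.Size([2, 80, 1]): 10 bins cropped from each side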
# ----------------------------------------------------------------------
# Legacy implementation, kept commented out for reference: a CNN backbone
# feeding a Transformer encoder (Tranmodel), fine-tuned with an
# Enformer-style head (finetunemodel).
# ----------------------------------------------------------------------
# from pretrain.track.layers import AttentionPool, Enformer, CNN
# from pretrain.track.transformers import Transformer
# from einops.layers.torch import Rearrange
# from einops import rearrange
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
#
#
# class Convblock(nn.Module):
#     def __init__(self, in_channel, kernel_size, dilate_size, dropout=0.1):
#         super().__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(
#                 in_channel, in_channel,
#                 kernel_size, padding=self.pad(kernel_size, dilate_size),
#                 dilation=dilate_size),
#             nn.GroupNorm(16, in_channel),
#             nn.Dropout(dropout)
#         )
#
#     def pad(self, kernel_size, dilate_size):
#         # "same" padding for an odd kernel with the given dilation
#         return (kernel_size - 1) * dilate_size // 2
#
#     def symmetric(self, x):
#         # average with the transpose so the 2D map is symmetric
#         return (x + x.permute(0, 1, 3, 2)) / 2
#
#     def forward(self, x):
#         identity = x
#         out = self.conv(x)
#         x = out + identity
#         x = self.symmetric(x)
#         return F.relu(x)
#
#
# class dilated_tower(nn.Module):
#     def __init__(self, embed_dim, in_channel=48, kernel_size=9, dilate_rate=4):
#         super().__init__()
#         dilate_convs = []
#         for i in range(dilate_rate + 1):
#             dilate_convs.append(
#                 Convblock(in_channel, kernel_size=kernel_size, dilate_size=2 ** i))
#         self.cnn = nn.Sequential(
#             Rearrange('b l n d -> b d l n'),
#             nn.Conv2d(embed_dim, in_channel, kernel_size=1),
#             *dilate_convs,
#             nn.Conv2d(in_channel, in_channel, kernel_size=1),
#             Rearrange('b d l n -> b l n d'),
#         )
#
#     def forward(self, x, crop):
#         x = self.cnn(x)
#         x = x[:, crop:-crop, crop:-crop, :]
#         return x
#
#
# class Tranmodel(nn.Module):
#     def __init__(self, backbone, transformer):
#         super().__init__()
#         self.backbone = backbone
#         self.transformer = transformer
#         hidden_dim = transformer.d_model
#         self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
#
#     def forward(self, input):
#         input = rearrange(input, 'b n c l -> (b n) c l')
#         src = self.backbone(input)
#         src = self.input_proj(src)
#         src = self.transformer(src)
#         return src
#
#
# class finetunemodel(nn.Module):
#     def __init__(self, pretrain_model, hidden_dim, embed_dim, bins, crop=25):
#         super().__init__()
#         self.pretrain_model = pretrain_model
#         self.bins = bins
#         self.crop = crop
#         self.attention_pool = AttentionPool(hidden_dim)
#         self.project = nn.Sequential(
#             Rearrange('(b n) c -> b c n', n=bins),
#             nn.Conv1d(hidden_dim, hidden_dim, kernel_size=9, padding=4, groups=hidden_dim),
#             nn.InstanceNorm1d(hidden_dim, affine=True),
#             nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.2)
#         )
#         self.transformer = Enformer(dim=embed_dim, depth=4, heads=6)
#         self.prediction_head = nn.Sequential(
#             nn.Linear(embed_dim, 1)
#         )
#
#     def forward(self, x):
#         x = self.pretrain_model(x)
#         x = self.attention_pool(x)
#         x = self.project(x)
#         x = rearrange(x, 'b c n -> b n c')
#         x = self.transformer(x)
#         x = self.prediction_head(x[:, self.crop:-self.crop, :])
#         return x
#
#
# def build_backbone():
#     return CNN()
#
#
# def build_transformer(args):
#     return Transformer(
#         d_model=args.hidden_dim,
#         dropout=args.dropout,
#         nhead=args.nheads,
#         dim_feedforward=args.dim_feedforward,
#         num_encoder_layers=args.enc_layers,
#         num_decoder_layers=args.dec_layers
#     )
#
#
# def build_cage_model(args):
#     backbone = build_backbone()
#     transformer = build_transformer(args)
#     pretrain_model = Tranmodel(
#         backbone=backbone,
#         transformer=transformer,
#     )
#     model = finetunemodel(pretrain_model, hidden_dim=args.hidden_dim, embed_dim=args.embed_dim,
#                           bins=args.bins, crop=args.crop)
#     return model