"""Factory helpers that build a denoiser backbone (HDiT or SwinIR) from a named size preset."""

from arch.hourglass import image_transformer_v2 as itv2
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR


def create_arch(arch, condition_channels=0):
    # `arch` is "<name>_<size>", e.g. 'swinir_L', 'hdit_XL2', or 'hdit_ImageNet256Sp4'.
    arch_name, arch_size = arch.split('_')
    arch_config = arch_configs[arch_name][arch_size].copy()
    # Conditioning images are concatenated along the channel dimension,
    # so widen the network input accordingly.
    arch_config['in_channels'] += condition_channels
    return arch_name_to_object[arch_name](**arch_config)


arch_configs = {
    'hdit': {
        "ImageNet256Sp4": {
            'in_channels': 3,
            'out_channels': 3,
            'widths': [256, 512, 1024],
            'depths': [2, 2, 8],
            'patch_size': [4, 4],
            'self_attns': [
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "global", "d_head": 64},
            ],
            'mapping_depth': 2,
            'mapping_width': 768,
            'dropout_rate': [0, 0, 0],
            'mapping_dropout_rate': 0.0,
        },
        "XL2": {
            'in_channels': 3,
            'out_channels': 3,
            'widths': [384, 768],
            'depths': [2, 11],
            'patch_size': [4, 4],
            'self_attns': [
                {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
                {"type": "global", "d_head": 64},
            ],
            'mapping_depth': 2,
            'mapping_width': 768,
            'dropout_rate': [0, 0],
            'mapping_dropout_rate': 0.0,
        },
    },
    'swinir': {
        "M": {
            'in_channels': 3,
            'out_channels': 3,
            'embed_dim': 120,
            'depths': [6, 6, 6, 6, 6],
            'num_heads': [6, 6, 6, 6, 6],
            'resi_connection': '1conv',
            'sf': 8,
        },
        "L": {
            'in_channels': 3,
            'out_channels': 3,
            'embed_dim': 180,
            'depths': [6, 6, 6, 6, 6, 6, 6, 6],
            'num_heads': [6, 6, 6, 6, 6, 6, 6, 6],
            'resi_connection': '1conv',
            'sf': 8,
        },
    },
}


def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads,
                        resi_connection, sf):
    return SwinIR(
        img_size=64,
        patch_size=1,
        in_chans=in_channels,
        num_out_ch=out_channels,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=8,
        mlp_ratio=2,
        sf=sf,
        img_range=1.0,
        upsampler="nearest+conv",
        resi_connection=resi_connection,
        unshuffle=True,
        unshuffle_scale=8,
    )


def create_hdit_model(widths, depths, self_attns, dropout_rate,
                      mapping_depth, mapping_width, mapping_dropout_rate,
                      in_channels, out_channels, patch_size):
    # Each entry in `widths`, `depths`, `self_attns`, and `dropout_rate`
    # describes one resolution level of the hourglass transformer.
    assert len(widths) == len(depths)
    assert len(widths) == len(self_attns)
    assert len(widths) == len(dropout_rate)

    # Feed-forward widths are 3x the model width at every level.
    mapping_d_ff = mapping_width * 3
    d_ffs = [width * 3 for width in widths]

    levels = []
    for depth, width, d_ff, self_attn, dropout in zip(depths, widths, d_ffs, self_attns, dropout_rate):
        if self_attn['type'] == 'global':
            self_attn = itv2.GlobalAttentionSpec(self_attn.get('d_head', 64))
        elif self_attn['type'] == 'neighborhood':
            self_attn = itv2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64),
                                                       self_attn.get('kernel_size', 7))
        elif self_attn['type'] == 'shifted-window':
            self_attn = itv2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64),
                                                        self_attn['window_size'])
        elif self_attn['type'] == 'none':
            self_attn = itv2.NoAttentionSpec()
        else:
            raise ValueError(f'unsupported self attention type {self_attn["type"]}')
        levels.append(itv2.LevelSpec(depth, width, d_ff, self_attn, dropout))

    mapping = itv2.MappingSpec(mapping_depth, mapping_width, mapping_d_ff, mapping_dropout_rate)
    model = ImageTransformerDenoiserModelV2(
        levels=levels,
        mapping=mapping,
        in_channels=in_channels,
        out_channels=out_channels,
        patch_size=patch_size,
        num_classes=0,
        mapping_cond_dim=0,
    )
    return model


arch_name_to_object = {
    'hdit': create_hdit_model,
    'swinir': create_swinir_model,
}
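
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; assumes this module is importable, e.g. as
# `create_arch_module`, and that the `arch` package with the HDiT/SwinIR
# implementations is on the path):
#
#   from create_arch_module import create_arch
#
#   # Unconditional HDiT at the "XL2" preset defined above.
#   model = create_arch('hdit_XL2')
#
#   # SwinIR-L with a 3-channel conditioning image concatenated to the input,
#   # so the network receives 3 + 3 = 6 input channels.
#   cond_model = create_arch('swinir_L', condition_channels=3)
# ---------------------------------------------------------------------------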