# PMRF/utils/create_arch.py

from arch.hourglass import image_transformer_v2 as itv2
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR


def create_arch(arch, condition_channels=0):
    # arch is '<name>_<size>' and must match a key pair in arch_configs below,
    # e.g. 'swinir_L' or 'hdit_XL2'.
    # condition_channels extra input channels are added to accommodate conditioning
    # signals concatenated to the image along the channel dimension.
    arch_name, arch_size = arch.split('_')
    arch_config = arch_configs[arch_name][arch_size].copy()
    arch_config['in_channels'] += condition_channels
    return arch_name_to_object[arch_name](**arch_config)
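
# Usage sketch (illustrative; the exact call sites in the training/inference scripts may differ):
# create_arch('swinir_L', condition_channels=3) selects arch_configs['swinir']['L'], raises
# in_channels from 3 to 6 (e.g., when a degraded image is concatenated to the network input),
# and returns create_swinir_model(**config).
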
arch_configs = {
'hdit': {
"ImageNet256Sp4": {
'in_channels': 3,
'out_channels': 3,
'widths': [256, 512, 1024],
'depths': [2, 2, 8],
'patch_size': [4, 4],
'self_attns': [
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
{"type": "global", "d_head": 64}
],
'mapping_depth': 2,
'mapping_width': 768,
'dropout_rate': [0, 0, 0],
'mapping_dropout_rate': 0.0
},
"XL2": {
'in_channels': 3,
'out_channels': 3,
'widths': [384, 768],
'depths': [2, 11],
'patch_size': [4, 4],
'self_attns': [
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
{"type": "global", "d_head": 64}
],
'mapping_depth': 2,
'mapping_width': 768,
'dropout_rate': [0, 0],
'mapping_dropout_rate': 0.0
}
},
'swinir': {
"M": {
'in_channels': 3,
'out_channels': 3,
'embed_dim': 120,
'depths': [6, 6, 6, 6, 6],
'num_heads': [6, 6, 6, 6, 6],
'resi_connection': '1conv',
'sf': 8
},
"L": {
'in_channels': 3,
'out_channels': 3,
'embed_dim': 180,
'depths': [6, 6, 6, 6, 6, 6, 6, 6],
'num_heads': [6, 6, 6, 6, 6, 6, 6, 6],
'resi_connection': '1conv',
'sf': 8
},
},
}
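
# Additional variants could be registered by adding entries above; e.g., a hypothetical smaller
# SwinIR (the name and values below are illustrative, not part of the released configs):
# arch_configs['swinir']['S'] = {'in_channels': 3, 'out_channels': 3, 'embed_dim': 60,
#                                'depths': [6, 6, 6, 6], 'num_heads': [6, 6, 6, 6],
#                                'resi_connection': '1conv', 'sf': 8}
# after which create_arch('swinir_S') would build it.
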
def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads, resi_connection,
sf):
return SwinIR(
img_size=64,
patch_size=1,
in_chans=in_channels,
num_out_ch=out_channels,
embed_dim=embed_dim,
depths=depths,
num_heads=num_heads,
window_size=8,
mlp_ratio=2,
sf=sf,
img_range=1.0,
upsampler="nearest+conv",
resi_connection=resi_connection,
unshuffle=True,
unshuffle_scale=8
)
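
# Geometry note (a sketch of the intended behavior, assuming the standard semantics of this
# SwinIR variant's arguments): unshuffle=True with unshuffle_scale=8 pixel-unshuffles the input
# by a factor of 8 before the transformer body (e.g., 512x512 -> 64x64, matching img_size=64),
# and the 'nearest+conv' upsampler with sf=8 scales it back, so the output resolution equals
# the input resolution. See arch/swinir/swinir.py for the authoritative behavior.
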
def create_hdit_model(widths,
depths,
self_attns,
dropout_rate,
mapping_depth,
mapping_width,
mapping_dropout_rate,
in_channels,
out_channels,
patch_size
):
assert len(widths) == len(depths)
assert len(widths) == len(self_attns)
assert len(widths) == len(dropout_rate)
    # Feed-forward widths are three times the corresponding model widths.
    mapping_d_ff = mapping_width * 3
    d_ffs = [width * 3 for width in widths]
    levels = []
    # Translate each per-level attention config dict into the corresponding attention spec.
    for depth, width, d_ff, self_attn, dropout in zip(depths, widths, d_ffs, self_attns, dropout_rate):
if self_attn['type'] == 'global':
self_attn = itv2.GlobalAttentionSpec(self_attn.get('d_head', 64))
elif self_attn['type'] == 'neighborhood':
self_attn = itv2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7))
elif self_attn['type'] == 'shifted-window':
self_attn = itv2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size'])
elif self_attn['type'] == 'none':
self_attn = itv2.NoAttentionSpec()
else:
raise ValueError(f'unsupported self attention type {self_attn["type"]}')
levels.append(itv2.LevelSpec(depth, width, d_ff, self_attn, dropout))
mapping = itv2.MappingSpec(mapping_depth, mapping_width, mapping_d_ff, mapping_dropout_rate)
    model = ImageTransformerDenoiserModelV2(
        levels=levels,
        mapping=mapping,
        in_channels=in_channels,
        out_channels=out_channels,
        patch_size=patch_size,
        num_classes=0,        # no class-label conditioning
        mapping_cond_dim=0,   # no extra conditioning input to the mapping network
    )
return model
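
# ImageTransformerDenoiserModelV2 follows the HDiT (Hourglass Diffusion Transformer) design.
# In the reference k-diffusion implementation it appears to be adapted from, the denoiser
# forward pass is roughly model(x, sigma, ...) with a per-sample noise level; consult
# arch/hourglass/image_transformer_v2.py for the exact signature used in this repository.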

# Maps the arch-name prefix parsed by create_arch to its factory; keys must match arch_configs.
arch_name_to_object = {
    'hdit': create_hdit_model,
    'swinir': create_swinir_model,
}
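

if __name__ == '__main__':
    # Minimal sanity-check sketch: instantiate each registered architecture and report its
    # parameter count. This assumes the optional attention backends used by the HDiT
    # neighborhood-attention levels (e.g., natten) are available; only construction is
    # exercised here, since the forward-pass signatures of SwinIR and HDiT differ.
    for arch in ('swinir_M', 'swinir_L', 'hdit_XL2', 'hdit_ImageNet256Sp4'):
        model = create_arch(arch, condition_channels=0)
        n_params = sum(p.numel() for p in model.parameters())
        print(f'{arch}: {n_params / 1e6:.1f}M parameters')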