Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import torch.nn as nn | |
from collections import OrderedDict | |
def tf2th(conv_weights): | |
"""Possibly convert HWIO to OIHW.""" | |
if conv_weights.ndim == 4: | |
conv_weights = conv_weights.transpose([3, 2, 0, 1]) | |
return torch.from_numpy(conv_weights) | |
def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg): | |
import re | |
layer_keys = sorted(state_dict.keys()) | |
for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1): | |
if not stage_with_dcn: | |
continue | |
for old_key in layer_keys: | |
pattern = ".*block{}.*conv2.*".format(ix) | |
r = re.match(pattern, old_key) | |
if r is None: | |
continue | |
for param in ["weight", "bias"]: | |
if old_key.find(param) is -1: | |
continue | |
if "unit01" in old_key: | |
continue | |
new_key = old_key.replace("conv2.{}".format(param), "conv2.conv.{}".format(param)) | |
print("pattern: {}, old_key: {}, new_key: {}".format(pattern, old_key, new_key)) | |
# Calculate SD conv weight | |
w = state_dict[old_key] | |
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) | |
w = (w - m) / torch.sqrt(v + 1e-10) | |
state_dict[new_key] = w | |
del state_dict[old_key] | |
return state_dict | |
def load_big_format(cfg, f): | |
model = OrderedDict() | |
weights = np.load(f) | |
cmap = {"a": 1, "b": 2, "c": 3} | |
for key, val in weights.items(): | |
old_key = key.replace("resnet/", "") | |
if "root_block" in old_key: | |
new_key = "root.conv.weight" | |
elif "/proj/standardized_conv2d/kernel" in old_key: | |
key_pattern = old_key.replace("/proj/standardized_conv2d/kernel", "").replace("resnet/", "") | |
bname, uname, cidx = key_pattern.split("/") | |
new_key = "{}.downsample.{}.conv{}.weight".format(bname, uname, cmap[cidx]) | |
elif "/standardized_conv2d/kernel" in old_key: | |
key_pattern = old_key.replace("/standardized_conv2d/kernel", "").replace("resnet/", "") | |
bname, uname, cidx = key_pattern.split("/") | |
new_key = "{}.{}.conv{}.weight".format(bname, uname, cmap[cidx]) | |
elif "/group_norm/gamma" in old_key: | |
key_pattern = old_key.replace("/group_norm/gamma", "").replace("resnet/", "") | |
bname, uname, cidx = key_pattern.split("/") | |
new_key = "{}.{}.gn{}.weight".format(bname, uname, cmap[cidx]) | |
elif "/group_norm/beta" in old_key: | |
key_pattern = old_key.replace("/group_norm/beta", "").replace("resnet/", "") | |
bname, uname, cidx = key_pattern.split("/") | |
new_key = "{}.{}.gn{}.bias".format(bname, uname, cmap[cidx]) | |
else: | |
print("Unknown key {}".format(old_key)) | |
continue | |
print("Map {} -> {}".format(key, new_key)) | |
model[new_key] = tf2th(val) | |
model = _rename_conv_weights_for_deformable_conv_layers(model, cfg) | |
return dict(model=model) | |