Spaces:
Sleeping
Sleeping
File size: 3,201 Bytes
749745d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
def tf2th(conv_weights):
    """Turn a NumPy weight array into a torch tensor, reordering conv kernels.

    A 4-D array is treated as an HWIO convolution kernel and permuted to
    OIHW; arrays of any other rank are handed to torch unchanged.

    Args:
        conv_weights: NumPy array holding the weights.

    Returns:
        A tensor created by ``torch.from_numpy`` (so it shares memory with
        the array, or with the transposed view of it).
    """
    is_conv_kernel = conv_weights.ndim == 4
    # Axis order (3, 2, 0, 1): out, in, height, width.
    array = conv_weights.transpose([3, 2, 0, 1]) if is_conv_kernel else conv_weights
    return torch.from_numpy(array)
def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
import re
layer_keys = sorted(state_dict.keys())
for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
if not stage_with_dcn:
continue
for old_key in layer_keys:
pattern = ".*block{}.*conv2.*".format(ix)
r = re.match(pattern, old_key)
if r is None:
continue
for param in ["weight", "bias"]:
if old_key.find(param) is -1:
continue
if "unit01" in old_key:
continue
new_key = old_key.replace("conv2.{}".format(param), "conv2.conv.{}".format(param))
print("pattern: {}, old_key: {}, new_key: {}".format(pattern, old_key, new_key))
# Calculate SD conv weight
w = state_dict[old_key]
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = (w - m) / torch.sqrt(v + 1e-10)
state_dict[new_key] = w
del state_dict[old_key]
return state_dict
def load_big_format(cfg, f):
    """Load ``.npz`` ResNet weights (TF naming) into a PyTorch state dict.

    Each archive key, with every ``"resnet/"`` removed, is mapped to a
    PyTorch parameter name:

    * any key containing ``root_block`` -> ``root.conv.weight``
    * ``<block>/<unit>/<a|b|c>/proj/standardized_conv2d/kernel``
      -> ``<block>.downsample.<unit>.conv<1|2|3>.weight``
    * ``<block>/<unit>/<a|b|c>/standardized_conv2d/kernel``
      -> ``<block>.<unit>.conv<1|2|3>.weight``
    * ``<block>/<unit>/<a|b|c>/group_norm/gamma`` -> ``<block>.<unit>.gn<1|2|3>.weight``
    * ``<block>/<unit>/<a|b|c>/group_norm/beta`` -> ``<block>.<unit>.gn<1|2|3>.bias``

    Unrecognized keys are reported with ``print`` and skipped. Arrays are
    converted with ``tf2th``, and the resulting mapping is post-processed by
    ``_rename_conv_weights_for_deformable_conv_layers``.

    Args:
        cfg: config forwarded to the deformable-conv renaming step.
        f: path or file object that ``np.load`` opens as an ``.npz`` archive.

    Returns:
        ``{"model": state_dict}``.

    Raises:
        ValueError / KeyError: if a matched key does not split into exactly
        ``block/unit/{a,b,c}``.
    """
    cmap = {"a": 1, "b": 2, "c": 3}
    # Order matters: the generic conv suffix is a substring of the proj one,
    # so the proj suffix must be tested first.
    suffix_templates = (
        ("/proj/standardized_conv2d/kernel", "{}.downsample.{}.conv{}.weight"),
        ("/standardized_conv2d/kernel", "{}.{}.conv{}.weight"),
        ("/group_norm/gamma", "{}.{}.gn{}.weight"),
        ("/group_norm/beta", "{}.{}.gn{}.bias"),
    )
    model = OrderedDict()
    # np.load returns an NpzFile that keeps the archive open; close it when done.
    with np.load(f) as weights:
        for key, val in weights.items():
            old_key = key.replace("resnet/", "")
            if "root_block" in old_key:
                new_key = "root.conv.weight"
            else:
                new_key = None
                for suffix, template in suffix_templates:
                    if suffix in old_key:
                        bname, uname, cidx = old_key.replace(suffix, "").split("/")
                        new_key = template.format(bname, uname, cmap[cidx])
                        break
                if new_key is None:
                    print("Unknown key {}".format(old_key))
                    continue
            print("Map {} -> {}".format(key, new_key))
            model[new_key] = tf2th(val)
    model = _rename_conv_weights_for_deformable_conv_layers(model, cfg)
    return dict(model=model)
|