Doven
update code.
f7009b3
raw
history blame
885 Bytes
import torch.nn as nn
import torch
import timm
import os
def Model():
    """Build the timm ResNet-18 and freeze all but four BatchNorm parameters.

    The network is created with pretrained weights, its ``fc`` layer is
    replaced by ``nn.Linear(512, 100)``, and, if a ``full_model.pth`` file
    sits next to this module, that state dict is loaded on CPU over it.
    Afterwards only the weight/bias of ``layer4.1.bn1`` and ``layer4.1.bn2``
    keep ``requires_grad=True``; every other parameter is frozen.

    Returns:
        A tuple ``(model, model.fc)``: the full network and its ``fc`` layer.
    """
    # Names of the parameters that stay trainable; everything else is frozen.
    trainable_names = [
        "layer4.1.bn1.weight",
        "layer4.1.bn1.bias",
        "layer4.1.bn2.weight",
        "layer4.1.bn2.bias",
    ]

    net = timm.create_model("resnet18", pretrained=True)
    net.fc = nn.Linear(512, 100)

    # Optional checkpoint stored alongside this file.
    checkpoint_path = os.path.join(os.path.dirname(__file__), "full_model.pth")
    if os.path.exists(checkpoint_path):
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        net.load_state_dict(state_dict)

    for name, param in net.named_parameters():
        param.requires_grad = name in trainable_names

    return net, net.fc
if __name__ == "__main__":
    # Manual check: print the architecture, every parameter name, and the
    # total number of parameter elements (trainable and frozen alike).
    net, _ = Model()
    print(net)
    total_elements = 0
    for param_name, param in net.named_parameters():
        total_elements = total_elements + param.numel()
        print(param_name)
    print("num_param:", total_elements)