Spaces:
Paused
Paused
Create BaseModel.py
Browse files- xdecoder/BaseModel.py +37 -0
xdecoder/BaseModel.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import os
|
9 |
+
import logging
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from utils.model_loading import align_and_update_state_dicts
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class BaseModel(nn.Module):
|
20 |
+
def __init__(self, opt, module: nn.Module):
|
21 |
+
super(BaseModel, self).__init__()
|
22 |
+
self.opt = opt
|
23 |
+
self.model = module
|
24 |
+
|
25 |
+
def forward(self, *inputs, **kwargs):
|
26 |
+
outputs = self.model(*inputs, **kwargs)
|
27 |
+
return outputs
|
28 |
+
|
29 |
+
def save_pretrained(self, save_dir):
|
30 |
+
save_path = os.path.join(save_dir, 'model_state_dict.pt')
|
31 |
+
torch.save(self.model.state_dict(), save_path)
|
32 |
+
|
33 |
+
def from_pretrained(self, load_path):
|
34 |
+
state_dict = torch.load(load_path, map_location=self.opt['device'])
|
35 |
+
state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
|
36 |
+
self.model.load_state_dict(state_dict, strict=False)
|
37 |
+
return self
|