leftthomas commited on
Commit
21e720a
·
1 Parent(s): 901629b

Upload modeling_resnet.py

Browse files
Files changed (1) hide show
  1. modeling_resnet.py +22 -0
modeling_resnet.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock
3
+ import torch.nn.functional as F
4
+ from .configuration_resnet import ResnetConfig
5
+
6
+
7
+ BLOCK_MAPPING = {'basic': BasicBlock, 'bottleneck': Bottleneck}
8
+
9
+
10
+ class ResnetModelForImageClassification(PreTrainedModel):
11
+ config_class = ResnetConfig
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+ block_layer = BLOCK_MAPPING[config.block_type]
15
+ self.model = ResNet(block_layer, config.layers, config.num_classes)
16
+
17
+ def forward(self, tensor, labels=None):
18
+ logits = self.model(tensor)
19
+ if labels is not None:
20
+ loss = F.cross_entropy(logits, labels)
21
+ return {'loss': loss, 'logits': logits}
22
+ return {'logits': logits}