Spaces:
Sleeping
Sleeping
File size: 5,121 Bytes
e1ab149 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from typing import Any, Dict, Optional

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseModel(nn.Module):
    """Base model class for animal classification.

    Subclasses provide ``forward`` and ``get_num_classes``. To be restorable
    through ``load_from_checkpoint`` a subclass must also be constructible
    with a single ``num_classes`` keyword argument.
    """

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Get probability predictions.

        Runs the forward pass with gradient tracking disabled and applies a
        softmax over dimension 1 of the logits.

        Note:
            This method does not switch the module to evaluation mode. Call
            ``model.eval()`` beforehand if mode-dependent layers (e.g.
            dropout) should be disabled.

        Args:
            x: Input batch passed unchanged to ``forward``.

        Returns:
            Tensor of the same shape as the logits, normalised along dim 1.
        """
        with torch.no_grad():
            logits = self(x)
            return F.softmax(logits, dim=1)

    @classmethod
    def load_from_checkpoint(
        cls,
        path: str,
        map_location: Any = None
    ) -> 'BaseModel':
        """Load model from checkpoint.

        The checkpoint is expected to have the layout written by
        ``save_checkpoint``: a ``'config'`` dict containing ``'num_classes'``
        and a ``'model_state_dict'`` entry.

        Args:
            path: Location of the checkpoint file.
            map_location: Forwarded to ``torch.load`` to remap storages.

        Returns:
            A new instance of ``cls`` with the stored weights loaded.

        Raises:
            KeyError: If the checkpoint lacks the expected entries.
        """
        # SECURITY: torch.load deserialises with pickle; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(path, map_location=map_location)
        model = cls(num_classes=checkpoint['config']['num_classes'])
        model.load_state_dict(checkpoint['model_state_dict'])
        return model

    def save_checkpoint(
        self,
        path: str,
        extra_data: Optional[Dict[str, Any]] = None
    ) -> None:
        """Save model checkpoint.

        The file contains ``'model_state_dict'`` and a ``'config'`` dict with
        ``'num_classes'`` and ``'model_type'``. Entries of ``extra_data`` are
        merged in at the top level, except for an optional ``'config'`` entry
        whose items are merged into the checkpoint's ``'config'`` dict.

        Args:
            path: Destination file.
            extra_data: Optional additional entries to store. The mapping
                passed in is never modified.
        """
        data = {
            'model_state_dict': self.state_dict(),
            'config': {
                'num_classes': self.get_num_classes(),
                'model_type': self.__class__.__name__
            }
        }
        if extra_data:
            # Work on a shallow copy: removing 'config' from the caller's
            # dict would silently drop it from any later save that reuses
            # the same ``extra_data`` object.
            extras = dict(extra_data)
            if 'config' in extras:
                data['config'].update(extras.pop('config'))
            data.update(extras)
        torch.save(data, path)

    def get_num_classes(self) -> int:
        """Get number of output classes.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError
class CNNModel(BaseModel):
def __init__(self, num_classes: int, input_size: int = 224):
super(CNNModel, self).__init__()
self.conv_layers = nn.Sequential(
# First block: 32 filters
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2),
# Second block: 64 filters
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
# Third block: 128 filters
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2),
# Global Average Pooling
nn.AdaptiveAvgPool2d(1)
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Dropout(0.5),
nn.Linear(128, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
)
self._initialize_weights()
def _initialize_weights(self):
"""Initialize model weights."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_layers(x)
return self.classifier(x)
def get_num_classes(self) -> int:
return self.classifier[-1].out_features
class EfficientNetModel(BaseModel):
    """EfficientNet-based model for animal classification.

    A timm backbone created with ``num_classes=0`` produces the features,
    which are fed to a dropout + linear classifier head.
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str = "efficientnet_b0",
        pretrained: bool = True
    ):
        """Create the backbone and the classification head.

        Args:
            num_classes: Width of the final linear layer.
            model_name: Architecture name passed to ``timm.create_model``.
            pretrained: Whether timm should load pretrained weights.
        """
        super(EfficientNetModel, self).__init__()
        self.base_model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0
        )

        # Probe the backbone's output width with a dummy batch. The probe
        # must run in eval mode: in training mode BatchNorm layers update
        # their running statistics even under no_grad, which would nudge
        # the (possibly pretrained) statistics towards random noise.
        was_training = self.base_model.training
        self.base_model.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, 224, 224)
                features = self.base_model(dummy_input)
        finally:
            # Restore whatever mode the backbone was in before the probe.
            self.base_model.train(was_training)
        feature_dim = features.shape[1]

        # Simpler classifier structure matching the saved model
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_dim, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier output for the image batch ``x``."""
        features = self.base_model(x)
        return self.classifier(features)

    def get_num_classes(self) -> int:
        """Return the output width of the last classifier layer."""
        return self.classifier[-1].out_features
def get_model(model_type: str, num_classes: int, **kwargs) -> BaseModel:
    """Build a model instance from its short type name.

    Args:
        model_type: Registry key, either ``'cnn'`` or ``'efficientnet'``.
        num_classes: Forwarded to the model constructor.
        **kwargs: Additional keyword arguments forwarded to the constructor.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``model_type`` is not a registered key.
    """
    models = {
        'cnn': CNNModel,
        'efficientnet': EfficientNetModel
    }

    model_cls = models.get(model_type)
    if model_cls is None:
        raise ValueError(f"Model type {model_type} not supported. Available models: {list(models.keys())}")

    return model_cls(num_classes=num_classes, **kwargs)