Update README.md
Browse files
README.md
CHANGED
@@ -54,7 +54,6 @@ class CustomConvNeXtModel(nn.Module):
|
|
54 |
def forward(self, x):
|
55 |
return self.convnext(x)
|
56 |
|
57 |
-
|
58 |
class CustomMobileNetModel(nn.Module):
|
59 |
def __init__(self, weights=MobileNet_V2_Weights.DEFAULT, num_classes=2):
|
60 |
super().__init__()
|
@@ -81,9 +80,9 @@ class CustomMobileNetModel(nn.Module):
|
|
81 |
return self.mobilenet(x)
|
82 |
|
83 |
class EnsembleModel(nn.Module):
|
84 |
-
def __init__(self,
|
85 |
super().__init__()
|
86 |
-
self.
|
87 |
self.mobilenet = mobilenet_model
|
88 |
self.fc = nn.Sequential(
|
89 |
nn.Linear(num_classes * 2, 512),
|
@@ -93,13 +92,12 @@ class EnsembleModel(nn.Module):
|
|
93 |
)
|
94 |
|
95 |
def forward(self, x):
|
96 |
-
|
97 |
mobilenet_out = self.mobilenet(x)
|
98 |
-
combined = torch.cat((
|
99 |
output = self.fc(combined)
|
100 |
return output
|
101 |
|
102 |
-
|
103 |
convnext_model = CustomConvNeXtModel()
|
104 |
mobilenet_model = CustomMobileNetModel()
|
105 |
ensemble_model = EnsembleModel(convnext_model, mobilenet_model)
|
|
|
54 |
def forward(self, x):
|
55 |
return self.convnext(x)
|
56 |
|
|
|
57 |
class CustomMobileNetModel(nn.Module):
|
58 |
def __init__(self, weights=MobileNet_V2_Weights.DEFAULT, num_classes=2):
|
59 |
super().__init__()
|
|
|
80 |
return self.mobilenet(x)
|
81 |
|
82 |
class EnsembleModel(nn.Module):
|
83 |
+
def __init__(self, resnet_model, mobilenet_model, num_classes=2):
|
84 |
super().__init__()
|
85 |
+
self.resnet = resnet_model
|
86 |
self.mobilenet = mobilenet_model
|
87 |
self.fc = nn.Sequential(
|
88 |
nn.Linear(num_classes * 2, 512),
|
|
|
92 |
)
|
93 |
|
94 |
def forward(self, x):
|
95 |
+
resnet_out = self.resnet(x)
|
96 |
mobilenet_out = self.mobilenet(x)
|
97 |
+
combined = torch.cat((resnet_out, mobilenet_out), dim=1)
|
98 |
output = self.fc(combined)
|
99 |
return output
|
100 |
|
|
|
101 |
convnext_model = CustomConvNeXtModel()
|
102 |
mobilenet_model = CustomMobileNetModel()
|
103 |
ensemble_model = EnsembleModel(convnext_model, mobilenet_model)
|