jiayicccc commited on
Commit
8ac1d9a
·
verified ·
1 Parent(s): 422bd2d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -6
README.md CHANGED
@@ -54,7 +54,6 @@ class CustomConvNeXtModel(nn.Module):
54
  def forward(self, x):
55
  return self.convnext(x)
56
 
57
-
58
  class CustomMobileNetModel(nn.Module):
59
  def __init__(self, weights=MobileNet_V2_Weights.DEFAULT, num_classes=2):
60
  super().__init__()
@@ -81,9 +80,9 @@ class CustomMobileNetModel(nn.Module):
81
  return self.mobilenet(x)
82
 
83
  class EnsembleModel(nn.Module):
84
- def __init__(self, convnext_model, mobilenet_model, num_classes=2):
85
  super().__init__()
86
- self.convnext = convnext_model
87
  self.mobilenet = mobilenet_model
88
  self.fc = nn.Sequential(
89
  nn.Linear(num_classes * 2, 512),
@@ -93,13 +92,12 @@ class EnsembleModel(nn.Module):
93
  )
94
 
95
  def forward(self, x):
96
- convnext_out = self.convnext(x)
97
  mobilenet_out = self.mobilenet(x)
98
- combined = torch.cat((convnext_out, mobilenet_out), dim=1)
99
  output = self.fc(combined)
100
  return output
101
 
102
-
103
  convnext_model = CustomConvNeXtModel()
104
  mobilenet_model = CustomMobileNetModel()
105
  ensemble_model = EnsembleModel(convnext_model, mobilenet_model)
 
54
  def forward(self, x):
55
  return self.convnext(x)
56
 
 
57
  class CustomMobileNetModel(nn.Module):
58
  def __init__(self, weights=MobileNet_V2_Weights.DEFAULT, num_classes=2):
59
  super().__init__()
 
80
  return self.mobilenet(x)
81
 
82
  class EnsembleModel(nn.Module):
83
+ def __init__(self, resnet_model, mobilenet_model, num_classes=2):
84
  super().__init__()
85
+ self.resnet = resnet_model
86
  self.mobilenet = mobilenet_model
87
  self.fc = nn.Sequential(
88
  nn.Linear(num_classes * 2, 512),
 
92
  )
93
 
94
  def forward(self, x):
95
+ resnet_out = self.resnet(x)
96
  mobilenet_out = self.mobilenet(x)
97
+ combined = torch.cat((resnet_out, mobilenet_out), dim=1)
98
  output = self.fc(combined)
99
  return output
100
 
 
101
  convnext_model = CustomConvNeXtModel()
102
  mobilenet_model = CustomMobileNetModel()
103
  ensemble_model = EnsembleModel(convnext_model, mobilenet_model)