NagaSaiAbhinay
commited on
Correct the forward call
Browse files
model.py
CHANGED
@@ -13,7 +13,7 @@ class CSDModel(PreTrainedModel):
|
|
13 |
|
14 |
@torch.inference_mode()
|
15 |
def forward(self, pixel_values):
|
16 |
-
features = self.backbone(pixel_values)
|
17 |
style_embeds = self.out_style(features)
|
18 |
content_embeds = self.out_content(features)
|
19 |
return features, style_embeds, content_embeds
|
|
|
13 |
|
14 |
@torch.inference_mode()
|
15 |
def forward(self, pixel_values):
|
16 |
+
features = self.backbone(pixel_values, return_dict=False)[1]
|
17 |
style_embeds = self.out_style(features)
|
18 |
content_embeds = self.out_content(features)
|
19 |
return features, style_embeds, content_embeds
|