NagaSaiAbhinay commited on
Commit
0dea13a
·
verified ·
1 Parent(s): 9b25486

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -0
model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch.nn as nn
2
  from .config import CSDConfig
3
  from transformers import PreTrainedModel, CLIPVisionModel
@@ -10,6 +11,7 @@ class CSDModel(PreTrainedModel):
10
  self.out_style = nn.Linear(config.hidden_size, config.style_projection_dim, bias=False)
11
  self.out_content = nn.Linear(config.hidden_size, config.content_projection_dim, bias=False)
12
 
 
13
  def forward(self, pixel_values):
14
  features = self.backbone(pixel_values)
15
  style_embeds = self.out_style(features)
 
1
+ import torch
2
  import torch.nn as nn
3
  from .config import CSDConfig
4
  from transformers import PreTrainedModel, CLIPVisionModel
 
11
  self.out_style = nn.Linear(config.hidden_size, config.style_projection_dim, bias=False)
12
  self.out_content = nn.Linear(config.hidden_size, config.content_projection_dim, bias=False)
13
 
14
+ @torch.inference_mode()
15
  def forward(self, pixel_values):
16
  features = self.backbone(pixel_values)
17
  style_embeds = self.out_style(features)