CSD / model.py
NagaSaiAbhinay's picture
Correct the forward call
ff4f79e verified
raw
history blame
794 Bytes
import torch
import torch.nn as nn
from .config import CSDConfig
from transformers import PreTrainedModel, CLIPVisionModel
class CSDModel(PreTrainedModel):
    """Style/content embedding model built on a CLIP vision backbone.

    A ``CLIPVisionModel`` produces a pooled image feature, which is passed
    through two bias-free linear heads: one yielding a style embedding and
    one yielding a content embedding.
    """

    config_class = CSDConfig

    def __init__(self, config: CSDConfig):
        """Build the backbone and the two projection heads from ``config``.

        Args:
            config: Model configuration. It is handed unchanged to
                ``CLIPVisionModel`` and must also expose ``hidden_size``,
                ``style_projection_dim`` and ``content_projection_dim``.
        """
        super().__init__(config)
        # NOTE(review): assumes CSDConfig is (or subclasses) CLIPVisionConfig,
        # since it is passed to CLIPVisionModel as-is -- confirm in config.py.
        self.backbone = CLIPVisionModel(config)
        self.out_style = nn.Linear(
            config.hidden_size, config.style_projection_dim, bias=False
        )
        self.out_content = nn.Linear(
            config.hidden_size, config.content_projection_dim, bias=False
        )
        # PreTrainedModel subclasses are expected to finish __init__ with
        # post_init(): it runs weight initialisation and sets up bookkeeping
        # that the library's loading/saving machinery may rely on.
        self.post_init()

    @torch.inference_mode()
    def forward(self, pixel_values):
        """Embed a batch of images.

        Runs under ``torch.inference_mode()``: no autograd graph is recorded,
        so this forward cannot be used for training or for gradients with
        respect to ``pixel_values``, and the returned tensors are inference
        tensors.

        Args:
            pixel_values: Image tensor fed directly to the CLIP vision
                backbone (presumably ``(batch, channels, height, width)`` as
                produced by a CLIP image processor -- TODO confirm).

        Returns:
            A tuple ``(features, style_embeds, content_embeds)``:
            ``features`` is the backbone's pooled output (last dim
            ``config.hidden_size``); ``style_embeds`` and ``content_embeds``
            are its linear projections with last dims
            ``config.style_projection_dim`` and
            ``config.content_projection_dim``. No normalisation is applied
            here; normalise before computing cosine similarity if needed.
        """
        backbone_out = self.backbone(pixel_values, return_dict=True)
        # pooler_output is the same tensor the tuple form exposes at index 1.
        features = backbone_out.pooler_output
        style_embeds = self.out_style(features)
        content_embeds = self.out_content(features)
        return features, style_embeds, content_embeds