# CSD / model.py
# Last commit ff4f79e (verified) by NagaSaiAbhinay: "Correct the forward call"
import torch
import torch.nn as nn
from .config import CSDConfig
from transformers import PreTrainedModel, CLIPVisionModel
class CSDModel(PreTrainedModel):
    """CLIP vision backbone with separate style and content projection heads.

    Both heads are bias-free ``nn.Linear`` layers. They map the backbone's
    ``hidden_size`` feature to ``style_projection_dim`` and
    ``content_projection_dim`` respectively. All three sizes are read from the
    ``CSDConfig`` passed to the constructor.
    """

    config_class = CSDConfig

    def __init__(self, config: CSDConfig):
        super().__init__(config)
        # Submodules are registered in a fixed order (backbone, style head,
        # content head) so parameter / state-dict ordering stays stable.
        self.backbone = CLIPVisionModel(config)
        hidden = config.hidden_size
        self.out_style = nn.Linear(hidden, config.style_projection_dim, bias=False)
        self.out_content = nn.Linear(hidden, config.content_projection_dim, bias=False)

    @torch.inference_mode()
    def forward(self, pixel_values):
        """Embed a batch of images.

        Args:
            pixel_values: image tensor handed straight to the CLIP vision
                backbone. It is presumably already preprocessed (e.g. by a
                CLIP image processor) — TODO confirm the expected shape and
                normalization.

        Returns:
            ``(features, style_embeds, content_embeds)``: the backbone's
            pooled feature, followed by its projections through the style
            head and the content head. No normalization is applied to the
            projections here.

        Note:
            The whole call runs under ``torch.inference_mode()``. Autograd
            is disabled, so this forward cannot be used to train the model.
        """
        backbone_outputs = self.backbone(pixel_values, return_dict=False)
        # With return_dict=False the backbone returns a plain tuple.
        # Element 1 is the pooled output (``pooler_output`` in transformers'
        # CLIPVisionModel).
        pooled = backbone_outputs[1]
        style = self.out_style(pooled)
        content = self.out_content(pooled)
        return pooled, style, content