# CSD / model.py
# Last change: "Update model.py" (commit 0dea13a) by NagaSaiAbhinay.
import torch
import torch.nn as nn
from .config import CSDConfig
from transformers import PreTrainedModel, CLIPVisionModel
class CSDModel(PreTrainedModel):
    """CSD model: a CLIP vision backbone followed by style and content heads.

    The backbone is a ``transformers.CLIPVisionModel`` built from ``config``.
    Two bias-free linear layers project its pooled image features into a
    style embedding and a content embedding.
    """

    config_class = CSDConfig

    def __init__(self, config: CSDConfig):
        """Build the backbone and both projection heads from ``config``.

        Args:
            config: Model configuration. It is passed unchanged to
                ``CLIPVisionModel`` and must also provide ``hidden_size``,
                ``style_projection_dim`` and ``content_projection_dim``.
                NOTE(review): presumably ``CSDConfig`` subclasses
                ``CLIPVisionConfig``; confirm in ``config.py``.
        """
        super().__init__(config)
        self.backbone = CLIPVisionModel(config)
        # Both heads read the same backbone features (width hidden_size).
        self.out_style = nn.Linear(config.hidden_size, config.style_projection_dim, bias=False)
        self.out_content = nn.Linear(config.hidden_size, config.content_projection_dim, bias=False)

    @torch.inference_mode()
    def forward(self, pixel_values):
        """Embed a batch of images into the style and content spaces.

        Runs under ``torch.inference_mode()``: no autograd graph is recorded,
        so this method cannot be used for training as written.

        Args:
            pixel_values: Image tensor accepted by ``CLIPVisionModel``
                (presumably ``(batch, channels, height, width)`` from a CLIP
                image processor; confirm against callers).

        Returns:
            Tuple ``(features, style_embeds, content_embeds)``:
            ``features`` is the backbone's ``pooler_output``, and
            ``style_embeds`` / ``content_embeds`` are its projections through
            ``out_style`` and ``out_content``.
        """
        # CLIPVisionModel returns a ModelOutput, not a tensor. Force
        # return_dict=True so a config with use_return_dict=False cannot turn
        # it into a plain tuple, then take the pooled embedding tensor that
        # nn.Linear can consume.
        outputs = self.backbone(pixel_values=pixel_values, return_dict=True)
        features = outputs.pooler_output
        style_embeds = self.out_style(features)
        content_embeds = self.out_content(features)
        return features, style_embeds, content_embeds