import torch
import torch.nn as nn

from transformers import CLIPVisionModel, PreTrainedModel

from .config import CSDConfig


class CSDModel(PreTrainedModel):
    config_class = CSDConfig

    def __init__(self, config: CSDConfig):
        super().__init__(config)
        # CLIP vision backbone; assumes CSDConfig is compatible with
        # CLIPVisionConfig (hidden_size, image_size, etc.).
        self.backbone = CLIPVisionModel(config)
        # Linear heads projecting the pooled CLIP features into separate
        # style and content embedding spaces.
        self.out_style = nn.Linear(
            config.hidden_size, config.style_projection_dim, bias=False
        )
        self.out_content = nn.Linear(
            config.hidden_size, config.content_projection_dim, bias=False
        )

    @torch.inference_mode()  # embedding extraction only; disables autograd
    def forward(self, pixel_values):
        # With return_dict=False, CLIPVisionModel returns a tuple of
        # (last_hidden_state, pooler_output); index 1 is the pooled feature.
        features = self.backbone(pixel_values, return_dict=False)[1]
        style_embeds = self.out_style(features)
        content_embeds = self.out_content(features)
        return features, style_embeds, content_embeds
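

# Illustrative smoke test (a minimal sketch, not part of the original module).
# It assumes CSDConfig supplies usable defaults and inherits image_size from
# CLIPVisionConfig; run via `python -m <package>.<module>` so the relative
# import above resolves.
if __name__ == "__main__":
    config = CSDConfig()  # assumption: defaults exist for all config fields
    model = CSDModel(config).eval()

    # Dummy batch of two RGB images at the backbone's expected resolution.
    pixel_values = torch.randn(2, 3, config.image_size, config.image_size)
    features, style_embeds, content_embeds = model(pixel_values)

    print(features.shape)        # (2, hidden_size)
    print(style_embeds.shape)    # (2, style_projection_dim)
    print(content_embeds.shape)  # (2, content_projection_dim)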