import torch
import torch.nn as nn
from transformers import PreTrainedModel, CLIPVisionModel

from .config import CSDConfig


class CSDModel(PreTrainedModel):
    """CLIP vision backbone with separate style and content projection heads."""

    config_class = CSDConfig

    def __init__(self, config: CSDConfig):
        super().__init__(config)
        # CSDConfig is expected to extend CLIPVisionConfig, so it can be
        # passed straight to the CLIP vision tower.
        self.backbone = CLIPVisionModel(config)
        # Bias-free linear projections from the pooled CLIP feature into the
        # style and content embedding spaces.
        self.out_style = nn.Linear(config.hidden_size, config.style_projection_dim, bias=False)
        self.out_content = nn.Linear(config.hidden_size, config.content_projection_dim, bias=False)

    @torch.inference_mode()  # inference-only: autograd is disabled inside forward
    def forward(self, pixel_values: torch.Tensor):
        # With return_dict=False the backbone returns
        # (last_hidden_state, pooler_output); index 1 is the pooled [CLS]
        # feature of shape (batch, hidden_size).
        features = self.backbone(pixel_values, return_dict=False)[1]
        style_embeds = self.out_style(features)
        content_embeds = self.out_content(features)
        return features, style_embeds, content_embeds
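

# Minimal usage sketch, assuming CSDConfig subclasses CLIPVisionConfig and
# provides defaults for hidden_size, style_projection_dim, and
# content_projection_dim. Because this file uses a relative import, run it as a
# module (e.g. `python -m <package>.modeling_csd`; the module name is
# hypothetical). For real embeddings, load trained weights with
# CSDModel.from_pretrained and feed CLIP-preprocessed images instead of the
# random tensor used here.
if __name__ == "__main__":
    model = CSDModel(CSDConfig()).eval()
    pixel_values = torch.randn(2, 3, 224, 224)  # stand-in for a CLIP-normalized batch
    features, style_embeds, content_embeds = model(pixel_values)
    print(features.shape, style_embeds.shape, content_embeds.shape)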