File size: 5,006 Bytes
7d55fca a9a821c 7d55fca a9a821c 7d55fca a9a821c 7d55fca a9a821c 7d55fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from collections import OrderedDict
from typing import Dict, Final, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import CLIPVisionModelWithProjection, logging
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
from .configuration_predictor import AestheticsPredictorConfig
# Lower the `transformers` library logger (the `logging` imported from
# transformers above) to ERROR level at import time.
# NOTE(review): this is a process-wide side effect that also silences warnings
# for any other code using transformers — confirm it is intended here rather
# than left to the application.
logging.set_verbosity_error()
# Download locations of the released aesthetics-predictor head checkpoints,
# keyed by the name callers pass as ``predictor_head_name``.
# Heads listed in URLS_LINEAR are loaded into AestheticsPredictorV2Linear,
# heads listed in URLS_RELU into AestheticsPredictorV2ReLU.
URLS_LINEAR: Final[Dict[str, str]] = dict(
    [
        (
            "sac+logos+ava1-l14-linearMSE",
            "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
        ),
        (
            "ava+logos-l14-linearMSE",
            "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/ava%2Blogos-l14-linearMSE.pth",
        ),
    ]
)
URLS_RELU: Final[Dict[str, str]] = dict(
    [
        (
            "ava+logos-l14-reluMSE",
            "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/ava%2Blogos-l14-reluMSE.pth",
        ),
    ]
)
class AestheticsPredictorV2Linear(CLIPVisionModelWithProjection):
    """CLIP vision model with a linear MLP head that regresses an aesthetics score.

    The projected CLIP image embedding is L2-normalised and passed through
    ``self.layers`` (Linear and Dropout modules only, no activations), whose
    final ``Linear(16, 1)`` yields one score per image.
    """

    def __init__(self, config: AestheticsPredictorConfig) -> None:
        super().__init__(config)
        # The module order fixes the head's ``state_dict`` indices
        # (``0.weight``, ``2.weight``, ...), which pretrained heads are loaded
        # against, so it must not be changed.
        self.layers = nn.Sequential(
            nn.Linear(config.projection_dim, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        """Predict aesthetics scores for a batch of images.

        Args:
            pixel_values: Preprocessed images, forwarded to the CLIP vision model.
            output_attentions: Forwarded to the CLIP vision model.
            output_hidden_states: Forwarded to the CLIP vision model.
            labels: Optional target scores. When given, the mean-squared error
                between the flattened predictions and flattened labels is
                returned as ``loss``; labels are cast to the predictions'
                dtype/device, so integer or ``(batch,)``-shaped targets work.
            return_dict: Return a ModelOutput instead of a tuple; defaults to
                ``self.config.use_return_dict``.

        Returns:
            ``ImageClassifierOutputWithNoAttention`` with ``loss`` (``None``
            without labels), ``logits`` (the scores, last dimension 1) and
            ``hidden_states`` set to the normalised image embedding; or the
            tuple ``(loss, logits, image_embeds)`` when ``return_dict`` is false.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = super().forward(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = outputs[0]  # projected image embeddings
        # Out-of-place division: an in-place ``/=`` overwrites the tensor that
        # ``norm`` saved for backward, which makes ``loss.backward()`` fail.
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        prediction = self.layers(image_embeds)

        loss = None
        if labels is not None:
            loss_fct = nn.MSELoss()
            # Flatten both sides so (batch, 1) predictions are never broadcast
            # against (batch,) labels into a (batch, batch) error matrix.
            loss = loss_fct(
                prediction.reshape(-1),
                labels.reshape(-1).to(
                    device=prediction.device, dtype=prediction.dtype
                ),
            )

        if not return_dict:
            return (loss, prediction, image_embeds)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=prediction,
            hidden_states=image_embeds,
        )
class AestheticsPredictorV2ReLU(AestheticsPredictorV2Linear):
    """Aesthetics predictor whose MLP head applies a ReLU after each hidden Linear.

    ``forward`` is inherited unchanged from ``AestheticsPredictorV2Linear``;
    only the ``layers`` head differs.
    """

    def __init__(self, config: AestheticsPredictorConfig) -> None:
        # The parent constructor also builds its own (linear) head; that head
        # is discarded and replaced below.
        super().__init__(config)

        # (in_features, out_features, dropout probability or None) for each
        # hidden block. Each block expands to Linear -> ReLU [-> Dropout].
        # This reproduces the exact module order, and therefore the state_dict
        # indices, that pretrained heads are keyed by.
        hidden_blocks = (
            (config.projection_dim, 1024, 0.2),
            (1024, 128, 0.2),
            (128, 64, 0.1),
            (64, 16, None),
        )
        head_modules = []
        for in_features, out_features, drop_prob in hidden_blocks:
            head_modules.append(nn.Linear(in_features, out_features))
            head_modules.append(nn.ReLU())
            if drop_prob is not None:
                head_modules.append(nn.Dropout(drop_prob))
        head_modules.append(nn.Linear(16, 1))

        self.layers = nn.Sequential(*head_modules)
        self.post_init()
def _resolve_head_url(urls: Dict[str, str], predictor_head_name: str) -> str:
    """Return the checkpoint URL registered for ``predictor_head_name``.

    Raises:
        KeyError: the name is not a key of ``urls``; the message lists the
            available names.
    """
    try:
        return urls[predictor_head_name]
    except KeyError:
        raise KeyError(
            f"unknown predictor head {predictor_head_name!r}; "
            f"available heads: {', '.join(sorted(urls))}"
        ) from None


def _load_pretrained_weights(
    model: AestheticsPredictorV2Linear, openai_model_name: str, head_url: str
) -> None:
    """Fill ``model`` with CLIP backbone weights and a pretrained head; set eval mode.

    Raises:
        RuntimeError: the CLIP checkpoint does not cover every non-head
            parameter of ``model``, or has keys ``model`` lacks.
        TypeError: the downloaded head checkpoint is not a dict-like state dict.
    """
    clip_model = CLIPVisionModelWithProjection.from_pretrained(openai_model_name)
    # strict=False is needed only because the CLIP checkpoint has no head
    # (``layers.*``). Any other mismatch means the backbone was not fully loaded.
    incompatible = model.load_state_dict(clip_model.state_dict(), strict=False)
    del clip_model  # drop the duplicate copy of the CLIP weights early
    missing = [k for k in incompatible.missing_keys if not k.startswith("layers.")]
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f"CLIP weights from {openai_model_name!r} do not match the predictor: "
            f"missing={missing}, unexpected={unexpected}"
        )

    state_dict = torch.hub.load_state_dict_from_url(head_url, map_location="cpu")
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"expected a state dict from {head_url}, got {type(state_dict).__name__}"
        )
    # Strip the leading ``layers.`` from the head checkpoint's keys so they match
    # ``model.layers``'s own state_dict. Only the prefix is removed.
    prefix = "layers."
    head_state = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    model.layers.load_state_dict(head_state)
    model.eval()


def convert_v2_linear_from_openai_clip(
    predictor_head_name: str,
    openai_model_name: str = "openai/clip-vit-large-patch14",
    config: Optional[AestheticsPredictorConfig] = None,
) -> AestheticsPredictorV2Linear:
    """Build an ``AestheticsPredictorV2Linear`` from CLIP weights and a released head.

    Args:
        predictor_head_name: Key of ``URLS_LINEAR`` naming the head checkpoint.
        openai_model_name: Model name/path passed to ``from_pretrained`` for
            both the config and the CLIP backbone weights.
        config: Predictor config; loaded from ``openai_model_name`` if omitted.

    Returns:
        The predictor in eval mode.

    Raises:
        KeyError: unknown ``predictor_head_name`` (checked before any download).
        RuntimeError: backbone weights do not match the predictor.
        TypeError: the head checkpoint is not a state dict.
    """
    head_url = _resolve_head_url(URLS_LINEAR, predictor_head_name)
    if config is None:
        config = AestheticsPredictorConfig.from_pretrained(openai_model_name)
    model = AestheticsPredictorV2Linear(config)
    _load_pretrained_weights(model, openai_model_name, head_url)
    return model


def convert_v2_relu_from_openai_clip(
    predictor_head_name: str,
    openai_model_name: str = "openai/clip-vit-large-patch14",
    config: Optional[AestheticsPredictorConfig] = None,
) -> AestheticsPredictorV2ReLU:
    """Build an ``AestheticsPredictorV2ReLU`` from CLIP weights and a released head.

    Args:
        predictor_head_name: Key of ``URLS_RELU`` naming the head checkpoint.
        openai_model_name: Model name/path passed to ``from_pretrained`` for
            both the config and the CLIP backbone weights.
        config: Predictor config; loaded from ``openai_model_name`` if omitted.

    Returns:
        The predictor in eval mode.

    Raises:
        KeyError: unknown ``predictor_head_name`` (checked before any download).
        RuntimeError: backbone weights do not match the predictor.
        TypeError: the head checkpoint is not a state dict.
    """
    head_url = _resolve_head_url(URLS_RELU, predictor_head_name)
    if config is None:
        config = AestheticsPredictorConfig.from_pretrained(openai_model_name)
    model = AestheticsPredictorV2ReLU(config)
    _load_pretrained_weights(model, openai_model_name, head_url)
    return model
|