|
from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.models.clip import CLIPTextConfig, CLIPTextModel
from transformers.models.clip.modeling_clip import (
    CLIP_TEXT_INPUTS_DOCSTRING,
    CLIPTextTransformer,
    _expand_mask,
    _make_causal_mask,
)
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
CLIP_SKIP_TEXT_INPUTS_DOCSTRING = (
    CLIP_TEXT_INPUTS_DOCSTRING
    + r"""
        clip_skip (`int`, *optional*, defaults to 1):
            Use the hidden states from the `clip_skip`-th layer from the end of the CLIP text encoder
            instead of only the final layer. Some diffusion models were trained on the hidden states
            from the 2nd-to-last layer (i.e. `clip_skip=2`), so we reproduce that behavior here for use
            with those models. The default of 1 matches the standard CLIP text encoder output.
"""
)
|
|
class CLIPSkipTextTransformer(CLIPTextTransformer):
    @add_start_docstrings_to_model_forward(CLIP_SKIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        clip_skip: int = 1,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
r""" |
|
Returns: |
|
|
|
""" |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
        # CLIP's text model uses a causal mask: each token may only attend to itself and
        # earlier tokens in the sequence.
        causal_attention_mask = _make_causal_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )

        if attention_mask is not None:
            # Expand the [bsz, seq_len] padding mask to [bsz, 1, tgt_seq_len, src_seq_len].
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        # Always request the intermediate hidden states so clip_skip can index into them.
        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        # clip_skip=1 selects the final encoder layer (standard CLIP behavior), clip_skip=2 the
        # 2nd-to-last layer, and so on. The selected hidden state still goes through the final
        # layer norm, matching how CLIPTextTransformer treats the last layer.
        last_hidden_state = encoder_outputs.hidden_states[-clip_skip]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # Pool by taking the features at the EOS token, which has the highest token id in each sequence.
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # Build an additive causal attention mask: start from the most negative value and
        # zero out the diagonal and below, so each token can attend only to itself and
        # earlier positions.
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the diagonal and the lower triangle
        mask = mask.unsqueeze(1)  # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
        return mask
|
|
class CLIPSkipTextModel(CLIPTextModel):
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        # Swap in the clip_skip-aware text transformer.
        self.text_model = CLIPSkipTextTransformer(config)

        # Initialize weights and apply final processing.
        self.post_init()

    @add_start_docstrings_to_model_forward(CLIP_SKIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        clip_skip: int = 1,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
r""" |
|
Returns: |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, CLIPSkipTextModel |
|
|
|
>>> model = CLIPSkipTextModel.from_pretrained("openai/clip-vit-base-patch32") |
|
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") |
|
|
|
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") |
|
|
|
>>> outputs = model(**inputs) |
|
>>> last_hidden_state = outputs.last_hidden_state |
|
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states |
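        >>> # Pass clip_skip=2 to take the 2nd-to-last layer's hidden states, as some diffusion
        >>> # models expect (clip_skip=1, the default, reproduces standard CLIP behavior).
        >>> outputs_skip = model(**inputs, clip_skip=2)
        >>> skipped_hidden_state = outputs_skip.last_hidden_state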
|
```""" |
|
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            clip_skip=clip_skip,
        )
|
|
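# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): load the standard
# CLIP text encoder weights into CLIPSkipTextModel and compare the default
# output with the 2nd-to-last layer's hidden states (clip_skip=2). The
# checkpoint name and prompt below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPSkipTextModel.from_pretrained("openai/clip-vit-base-patch32")

    inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")

    with torch.no_grad():
        default_out = model(**inputs)  # clip_skip=1: standard CLIP text encoder output
        skipped_out = model(**inputs, clip_skip=2)  # 2nd-to-last layer, as used by some diffusion models

    # Both have shape [batch, seq_len, hidden]; the values differ because a different
    # encoder layer is normalized and returned.
    print(default_out.last_hidden_state.shape, skipped_out.last_hidden_state.shape)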