File size: 17,032 Bytes
071945c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
# coding=utf-8

# Copyright 2024 LY Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import logging
from typing import Optional, Union

import timm
import torch
import torch.distributed as dist
import torch.distributed.nn
import torch.nn as nn
import torch.nn.functional as F
from timm.models.swin_transformer import SwinTransformer as TimmSwinTransformer
from transformers import PreTrainedModel
from transformers.utils.logging import get_logger

from .configuration_clyp import (
    CLYPTextBackboneConfig,
    CLYPTextEncoderConfig,
    CLYPVisionBackboneConfig,
    CLYPVisionEncoderConfig,
)
from .model_rinna import RinnaCLIPConfig, RinnaCLIPModel

DEFAULT_LOGGER = get_logger(__name__)


class VisionEncoder(nn.Module):
    """Vision encoder that turns a batch of images into features.

    The computation is ``backbone -> pooler -> neck``. Pooler and neck are
    optional and are skipped when not provided, so algorithm classes may define
    them instead of this encoder.

    Attributes:
        backbone (nn.Module): feature extractor (e.g. loaded from timm).
        pooler (Optional[nn.Module]): module that reduces the backbone output,
            e.g. to a class token.
        neck (Optional[nn.Module]): module that adjusts the feature dimension.
    """

    def __init__(
        self,
        backbone: nn.Module,
        pooler: Optional[nn.Module] = None,
        neck: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Attribute names are kept stable: they define the state_dict key prefixes.
        self.backbone = backbone
        self.pooler = pooler
        self.neck = neck

    def forward(self, imgs: torch.Tensor):
        """Extract image features.

        Args:
            imgs (torch.Tensor): image batch, expected shape=(batch_size, channels, height, width).

        Returns:
            torch.Tensor: backbone output passed through the optional pooler and
                neck. The final shape depends on the configured pooler, e.g.
                (batch_size, embed_dim) when a class token is extracted.
        """
        features = self.backbone(imgs)
        # Apply the optional stages in order; a missing (None) stage is skipped.
        for stage in (self.pooler, self.neck):
            if stage:
                features = stage(features)
        return features


class SwinTransformerPerm(nn.Module):
    """Channels-first adapter around a timm SwinTransformer.

    timm's SwinTransformer emits features laid out as
    (batch_size, height, width, channels); this wrapper moves the channel axis
    forward so callers receive (batch_size, channels, height, width).
    """

    def __init__(self, swin: nn.Module) -> None:
        super().__init__()
        # Keep the attribute name `swin`: it prefixes the wrapped model's state_dict keys.
        self.swin = swin

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and return its 4-D output in channels-first order."""
        return self.swin(x).permute(0, 3, 1, 2)


def load_from_timm(
    config: CLYPVisionBackboneConfig,
    use_gradient_checkpointing: bool = False,
    path_weights: Optional[str] = None,
    logger: logging.Logger = DEFAULT_LOGGER,
) -> nn.Module:
    """Create a vision backbone with ``timm.create_model``.

    Args:
        config (CLYPVisionBackboneConfig): provides ``model_name``, ``pretrained``
            and ``extra_kwargs``, all forwarded to ``timm.create_model``.
        use_gradient_checkpointing (bool): True to enable gradient checkpointing.
        path_weights (Optional[str]): optional path to a state dict that is
            loaded non-strictly on top of the created backbone.
        logger (logging.Logger): logger used for progress messages.

    Returns:
        nn.Module: the backbone. timm ``SwinTransformer`` instances are wrapped in
        ``SwinTransformerPerm`` so their output is channels-first.
    """
    assert config is not None
    backbone = timm.create_model(
        model_name=config.model_name,
        pretrained=config.pretrained,
        **config.extra_kwargs,
    )
    # num_classes=0 removes the classification head; the empty string is passed
    # as timm's global_pool argument (presumably disabling pooling so the pooler
    # downstream sees unpooled features -- confirm against the timm model used).
    backbone.reset_classifier(0, "")

    logger.info(
        f"    - load from timm: model_name={config.model_name}, pretrained={config.pretrained}"
    )

    # gradient checkpointing (explicitly set in both directions)
    backbone.set_grad_checkpointing(enable=use_gradient_checkpointing)
    if use_gradient_checkpointing:
        logger.info("    - gradient checkpointing is enabled")

    # optional weight initialization
    if path_weights:
        # NOTE(review): torch.load unpickles arbitrary objects; only pass trusted
        # checkpoint paths (consider weights_only=True on recent PyTorch).
        state_dict = torch.load(path_weights, map_location="cpu")
        checks = backbone.load_state_dict(state_dict, strict=False)
        logger.info(f"    - load weights from {path_weights}")
        # strict=False: report missing / unexpected keys instead of failing.
        logger.info(f"    - state dict checks: {checks}")

    # timm Swin outputs (batch, height, width, channels); convert to channels-first.
    if isinstance(backbone, TimmSwinTransformer):
        backbone = SwinTransformerPerm(backbone)
    return backbone


def create_vision_encoder(
    config: CLYPVisionEncoderConfig, logger: logging.Logger = DEFAULT_LOGGER
) -> VisionEncoder:
    """Assemble a VisionEncoder: timm backbone, CLS-token pooler and linear neck.

    Args:
        config (CLYPVisionEncoderConfig): provides ``backbone_config``,
            ``pooler_config`` and ``neck_config``.
        logger (logging.Logger): logger forwarded to ``load_from_timm``.

    Returns:
        VisionEncoder: the assembled encoder.
    """
    pooling = config.pooler_config
    assert pooling.input_type
    backbone = load_from_timm(config.backbone_config, logger=logger)
    return VisionEncoder(
        backbone,
        pooler=CLSTokenPooling(pooling.input_type, pooling.return_patch_features),
        neck=Linear(
            config.neck_config.in_channels,
            config.neck_config.out_channels,
            bias=config.neck_config.bias,
        ),
    )


class TextEncoder(nn.Module):
    """Text encoder that turns tokenized text into features.

    The computation is ``backbone -> pooler -> neck``. Pooler and neck are
    optional and are skipped when not provided, so algorithm classes may define
    them instead of this encoder.

    Attributes:
        backbone (nn.Module): text feature extractor (e.g. a huggingface model wrapper).
        pooler (Optional[nn.Module]): module that reduces the backbone output,
            e.g. to the class token.
        neck (Optional[nn.Module]): module that adjusts the feature dimension.
    """

    def __init__(
        self,
        backbone: nn.Module,
        pooler: Optional[nn.Module] = None,
        neck: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Attribute names are kept stable: they define the state_dict key prefixes.
        self.backbone = backbone
        self.pooler = pooler
        self.neck = neck

    def forward(self, inputs: dict) -> torch.Tensor:
        """Extract text features.

        Args:
            inputs (dict): keyword arguments unpacked into the backbone call,
                typically ``input_ids`` plus optional ``attention_mask``,
                ``position_ids`` and ``token_type_ids``. The backbone must accept
                every key that is passed.

        Returns:
            torch.Tensor: backbone output passed through the optional pooler and
                neck. The final shape depends on the configured pooler, e.g.
                (batch_size, embed_dim) when the class token is extracted.
        """
        features = self.backbone(**inputs)
        # Apply the optional stages in order; a missing (None) stage is skipped.
        for stage in (self.pooler, self.neck):
            if stage:
                features = stage(features)
        return features


class TextBackboneModelWrapper(nn.Module):
    """Expose only the ``text_model`` sub-module of a CLIP-style model.

    Attributes:
        model (nn.Module): ``model.text_model`` of the wrapped model.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Attribute name `model` is part of the state_dict key layout.
        self.model = model.text_model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ):
        """Run the text model with the given (optional) inputs.

        Returns:
            Whatever the text model returns. Downstream ``CLSTokenPooling`` with
            input_type="huggingface" reads ``.last_hidden_state`` from it, so it
            is presumably a huggingface model output object rather than a plain
            tensor -- confirm against the text model implementation.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )

    def set_gradient_checkpointing(self, enabled: bool) -> None:
        """Enable or disable gradient checkpointing on the wrapped text model.

        Previously ``enabled=False`` was a silent no-op, so checkpointing could
        never be turned off once enabled.
        """
        if enabled:
            self.model.gradient_checkpointing_enable()
        else:
            self.model.gradient_checkpointing_disable()


def load_from_huggingface(
    config: CLYPTextBackboneConfig,
    use_gradient_checkpointing: bool = False,
    path_weights: Optional[str] = None,
    logger: logging.Logger = DEFAULT_LOGGER,
) -> nn.Module:
    """Build the Rinna CLIP backbone from its huggingface config (no pretrained weights).

    Args:
        config (CLYPTextBackboneConfig): provides ``model_name``, passed to
            ``RinnaCLIPConfig.from_pretrained``.
        use_gradient_checkpointing (bool): True to enable gradient checkpointing.
        path_weights (Optional[str]): not supported; passing a value raises.
        logger (logging.Logger): logger used for progress messages.

    Returns:
        nn.Module: a randomly initialized ``RinnaCLIPModel``.

    Raises:
        NotImplementedError: if the created backbone is not a ``PreTrainedModel``,
            or if ``path_weights`` is given.
    """
    # NOTE:
    # Initialize Rinna CLIP without pretrained weights here,
    # because CLYP model loads its whole weights afterward
    auto_config = RinnaCLIPConfig.from_pretrained(config.model_name)
    backbone = RinnaCLIPModel(auto_config)

    logger.info(f"    - load from huggingface: model_name={config.model_name}")

    # gradient checkpointing relies on the PreTrainedModel API
    if not isinstance(backbone, PreTrainedModel):
        raise NotImplementedError(
            "load_from_huggingface: only PreTrainedModel backbones are supported, "
            f"got {type(backbone).__name__}"
        )
    if use_gradient_checkpointing:
        backbone.gradient_checkpointing_enable()
        logger.info("    - gradient checkpointing is enabled")

    # init weights from a separate file is not implemented for this backbone
    if path_weights:
        raise NotImplementedError(
            "load_from_huggingface: loading weights from path_weights is not supported "
            f"(got {path_weights!r})"
        )
    return backbone


def create_text_encoder(
    config: CLYPTextEncoderConfig, logger: logging.Logger = DEFAULT_LOGGER
) -> TextEncoder:
    """Assemble a TextEncoder: wrapped huggingface backbone, CLS-token pooler and linear neck.

    Args:
        config (CLYPTextEncoderConfig): provides ``backbone_config``,
            ``pooler_config`` and ``neck_config``.
        logger (logging.Logger): logger forwarded to ``load_from_huggingface``.

    Returns:
        TextEncoder: the assembled encoder.
    """
    pooling = config.pooler_config
    assert pooling.input_type
    raw_backbone = load_from_huggingface(config.backbone_config, logger=logger)
    return TextEncoder(
        TextBackboneModelWrapper(raw_backbone),
        pooler=CLSTokenPooling(pooling.input_type, pooling.return_patch_features),
        neck=Linear(
            config.neck_config.in_channels,
            config.neck_config.out_channels,
            bias=config.neck_config.bias,
        ),
    )


class Linear(nn.Module):
    """Thin wrapper around ``nn.Linear`` applied over the last dimension."""

    def __init__(self, in_channels: int, out_channels: int, bias: bool) -> None:
        """
        Args:
            in_channels (int): size of the input feature (last) dimension.
            out_channels (int): size of the output feature dimension.
            bias (bool): whether ``nn.Linear`` learns an additive bias.
        """
        super().__init__()
        # Attribute name `linear` is part of the state_dict key layout.
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x``.

        Args:
            x (torch.Tensor): shape=(batch_size, ..., in_channels).

        Returns:
            torch.Tensor: shape=(batch_size, ..., out_channels).
        """
        return self.linear(x)


class CLSTokenPooling(nn.Module):
    """Select the class (first) token from a token sequence, or return all tokens."""

    def __init__(self, input_type: str, return_patch_features: bool) -> None:
        """
        Args:
            input_type (str): "timm" or "huggingface".
                - "timm": the input is a tensor and ``x[:, 0]`` is the class token.
                - "huggingface": the input is an object with ``last_hidden_state``
                  and ``x.last_hidden_state[:, 0]`` is the class token.
            return_patch_features (bool): True to return every token instead of
                only the class token.
        """
        super().__init__()
        assert input_type in ("timm", "huggingface")
        self.input_type = input_type
        self.return_patch_features = return_patch_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: for "timm", a tensor of shape=(batch_size, length, dim); for
                "huggingface", an output object whose ``last_hidden_state`` holds
                the token features.

        Returns:
            torch.Tensor: shape=(batch_size, dim) for the class token, or the full
                token tensor when ``return_patch_features`` is True.
        """
        if self.input_type == "timm":
            assert x.ndim == 3, "CLSTokenPooling: dimension of input tensor must be 3."
            tokens = x
        elif self.input_type == "huggingface":
            tokens = x.last_hidden_state
        else:
            # Unreachable after __init__'s check unless input_type is mutated later.
            return None
        return tokens if self.return_patch_features else tokens[:, 0]


class InfoNCELoss(nn.Module):
    """Symmetric (image-to-text and text-to-image) InfoNCE contrastive loss.

    Local features are compared against features gathered from all processes
    via ``concat_all_gather``; the positive for local sample ``j`` is the gathered
    sample at index ``rank * batch_size + j``.
    """

    def __init__(
        self,
        learn_temperature: bool,
        init_temperature: float,
        max_temperature: Optional[float] = None,
        min_temperature: Optional[float] = None,
        label_smoothing: float = 0.0,
        gather_with_grad: bool = False,
    ):
        """
        Args:
            learn_temperature (bool): True to make the temperature an ``nn.Parameter``.
            init_temperature (float): initial temperature value.
            max_temperature (Optional[float]): upper clamp bound used by
                ``clip_temperature`` (only when the temperature is learned).
            min_temperature (Optional[float]): lower clamp bound used by
                ``clip_temperature`` (only when the temperature is learned).
            label_smoothing (float): forwarded to ``F.cross_entropy``.
            gather_with_grad (bool): True to backpropagate through gathered features.
        """
        super().__init__()
        self.label_smoothing = label_smoothing
        self.gather_with_grad = gather_with_grad

        # set temperature; when not learned it stays a plain 0-dim tensor
        # (not a parameter or buffer, so it is not part of state_dict).
        self.learn_temperature = learn_temperature
        initial = torch.ones([]) * init_temperature
        self.temperature = nn.Parameter(initial) if learn_temperature else initial
        # Always defined so attribute access never depends on learn_temperature.
        self.max_temperature = max_temperature
        self.min_temperature = min_temperature

        # Clip only a learned temperature with at least one bound. Use `is not None`
        # so a bound of 0.0 still enables clipping (truthiness would drop it).
        self.require_temperature_clipping = bool(learn_temperature) and (
            max_temperature is not None or min_temperature is not None
        )

    def clip_temperature(self) -> None:
        """Clamp the learned temperature into [min_temperature, max_temperature] in place."""
        if self.require_temperature_clipping:
            with torch.no_grad():
                self.temperature.clamp_(self.min_temperature, self.max_temperature)

    def forward(
        self,
        image_feats: torch.Tensor,
        text_feats: torch.Tensor,
        return_similarity: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute the symmetric InfoNCE loss.

        Args:
            image_feats (torch.Tensor): local image features, shape=(batch_size, embed_dim)
                or (batch_size, num_query_tokens, embed_dim).
            text_feats (torch.Tensor): local text features, shape=(batch_size, embed_dim).
            return_similarity (bool): True to also return the unscaled similarities.

        Returns:
            The scalar loss, or ``(loss, sim_i2t, sim_t2i)`` where the similarity
            matrices compare local features against the gathered features.
        """
        # gather image and text features from all processes
        image_feats_all = concat_all_gather(
            image_feats, with_grad=self.gather_with_grad
        )
        text_feats_all = concat_all_gather(text_feats, with_grad=self.gather_with_grad)

        # cosine similarity of local features against the gathered ones
        sim_i2t = image_to_text_similarity(
            image_feats=image_feats,
            text_feats=text_feats_all,
        )
        sim_t2i = text_to_image_similarity(
            text_feats=text_feats,
            image_feats=image_feats_all,
        )

        # logits: cosine similarity scaled by temperature
        logits_i2t = sim_i2t / self.temperature
        logits_t2i = sim_t2i / self.temperature

        # Targets point at this rank's slice of the gathered batch.
        # NOTE(review): assumes every rank contributes the same batch size -- confirm.
        distributed = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if distributed else 0
        batch_size = image_feats.size(0)
        targets = (
            torch.arange(batch_size, dtype=torch.long, device=image_feats.device)
            + batch_size * rank
        )

        loss_i2t = F.cross_entropy(
            logits_i2t, targets, label_smoothing=self.label_smoothing
        )
        loss_t2i = F.cross_entropy(
            logits_t2i, targets, label_smoothing=self.label_smoothing
        )
        loss = (loss_i2t + loss_t2i) / 2.0

        if not return_similarity:
            return loss
        return loss, sim_i2t, sim_t2i


def image_to_text_similarity(
    image_feats: torch.Tensor, text_feats: torch.Tensor
) -> torch.Tensor:
    """Cosine similarity from images to texts.

    Args:
        image_feats (torch.Tensor): shape=(num_imgs, embed_dim) or (num_imgs, num_query_tokens, embed_dim).
        text_feats (torch.Tensor): shape=(num_texts, embed_dim).

    Returns:
        sim_i2t (torch.Tensor): shape=(num_imgs, num_texts). With query tokens, each
            entry is the maximum similarity over the image's query tokens.
    """
    assert image_feats.ndim in [2, 3]
    assert text_feats.ndim == 2

    # normalize features so dot products are cosine similarities
    image_feats = F.normalize(image_feats, dim=-1)
    text_feats = F.normalize(text_feats, dim=-1)

    if image_feats.ndim == 2:
        return image_feats @ text_feats.T

    # (num_imgs, 1, Q, D) @ (1, num_texts, D, 1) -> (num_imgs, num_texts, Q, 1).
    # Squeeze ONLY the trailing singleton: a bare .squeeze() would also drop
    # num_imgs / num_texts / Q when any of them is 1 and break the max below.
    sim_i2t = torch.matmul(
        image_feats.unsqueeze(1), text_feats.unsqueeze(0).unsqueeze(-1)
    ).squeeze(-1)  # shape=(num_imgs, num_texts, num_query_tokens)
    # a query token with maximum cosine similarity is selected
    sim_i2t, _ = sim_i2t.max(dim=-1)  # shape=(num_imgs, num_texts)
    return sim_i2t


def text_to_image_similarity(
    text_feats: torch.Tensor, image_feats: torch.Tensor
) -> torch.Tensor:
    """Cosine similarity from texts to images.

    Args:
        text_feats (torch.Tensor): shape=(num_texts, embed_dim).
        image_feats (torch.Tensor): shape=(num_imgs, embed_dim) or (num_imgs, num_query_tokens, embed_dim).

    Returns:
        sim_t2i (torch.Tensor): shape=(num_texts, num_imgs). With query tokens, each
            entry is the maximum similarity over the image's query tokens.
    """
    assert image_feats.ndim in [2, 3]
    assert text_feats.ndim == 2

    # normalize features so dot products are cosine similarities
    image_feats = F.normalize(image_feats, dim=-1)
    text_feats = F.normalize(text_feats, dim=-1)

    if image_feats.ndim == 2:
        return text_feats @ image_feats.T

    # (num_texts, 1, 1, D) @ (1, num_imgs, D, Q) -> (num_texts, num_imgs, 1, Q).
    # Squeeze ONLY the singleton row axis: a bare .squeeze() would also drop
    # num_texts / num_imgs / Q when any of them is 1 and break the max below.
    sim_t2i = torch.matmul(
        text_feats.unsqueeze(1).unsqueeze(1),
        image_feats.permute(0, 2, 1).unsqueeze(0),
    ).squeeze(2)  # shape=(num_texts, num_imgs, num_query_tokens)
    # a query token with maximum cosine similarity is selected
    sim_t2i, _ = sim_t2i.max(dim=-1)  # shape=(num_texts, num_imgs)
    return sim_t2i


def concat_all_gather(tensor: torch.Tensor, with_grad: bool) -> torch.Tensor:
    """Gather ``tensor`` from every process and concatenate along dim 0 (rank order).

    Args:
        tensor (torch.Tensor): local tensor; assumed to have the same shape on every rank.
        with_grad (bool): if True, use ``torch.distributed.nn.all_gather``, which
            propagates gradients through the gathered result. If False, use
            ``torch.distributed.all_gather``, whose output carries no gradient.

    Returns:
        torch.Tensor: concatenation of all ranks' tensors. When torch.distributed
        is unavailable or no process group is initialized (single process), the
        local tensor itself is returned -- detached when ``with_grad`` is False,
        matching the no-gradient semantics of the collective path.

    Another implementation: https://github.com/salesforce/LAVIS/blob/main/lavis/models/base_model.py#L202-L237
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process fallback instead of failing inside get_world_size().
        return tensor if with_grad else tensor.detach()

    if with_grad:
        return torch.cat(torch.distributed.nn.all_gather(tensor), dim=0)

    # Receive buffers are fully overwritten by all_gather, so no initialization needed.
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)