File size: 4,800 Bytes
c220a5a
 
d5d4c61
c220a5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5d4c61
 
 
c220a5a
 
 
 
 
 
 
d5d4c61
 
 
 
 
 
 
c220a5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5d4c61
 
 
c220a5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import logging
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_layout_fidnet_v3 import LayoutFIDNetV3Config

logger = logging.getLogger(__name__)


@dataclass
class LayoutFIDNetV3Output(ModelOutput):
    """Outputs returned by ``LayoutFIDNetV3.forward`` when ``return_dict`` is truthy.

    ``ModelOutput`` validates its dataclass subclasses at construction time and
    rejects any subclass that declares more than one required field, so every
    field after the first defaults to ``None``. Callers that pass all three
    values (by keyword or position) are unaffected.
    """

    # Output of the ``fc_out_disc`` head with the trailing size-1 dim squeezed.
    logit_disc: torch.Tensor
    # Per-element class logits, [B, N, L] per the comment in ``forward``.
    logit_cls: Optional[torch.Tensor] = None
    # Per-element box predictions after a sigmoid, [B, N, 4].
    bbox_pred: Optional[torch.Tensor] = None


class TransformerWithToken(nn.Module):
    """Transformer encoder that prepends one learnable token to every sequence.

    The token is stored as a parameter of shape ``[1, 1, d_model]`` and is
    broadcast over the batch in ``forward``. A matching one-column ``False``
    mask is kept as a buffer so the token position is never treated as padding.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        num_layers: int,
    ) -> None:
        super().__init__()

        # Created before the encoder layers; keep this order so that seeded
        # initialisation stays reproducible.
        self.token = nn.Parameter(torch.randn(1, 1, d_model))
        # Constant ``False`` entry: the prepended token is always a valid position.
        self.register_buffer("token_mask", torch.zeros(1, 1, dtype=torch.bool))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
        )
        self.core = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x, src_key_padding_mask):
        """Run the encoder on ``x`` with the learnable token placed at index 0.

        Args:
            x: Sequence-first input of shape ``[N, B, E]``.
            src_key_padding_mask: Boolean mask of shape ``[B, N]``; ``False``
                marks valid values and ``True`` marks padded values.

        Returns:
            The encoder output for the extended sequence (token first), i.e.
            one more step along dim 0 than ``x``.
        """
        batch_size = x.size(1)

        # Broadcast the single token / mask entry across the batch.
        token_row = self.token.expand(-1, batch_size, -1)
        token_pad = self.token_mask.expand(batch_size, -1)

        sequence = torch.cat([token_row, x], dim=0)
        full_mask = torch.cat([token_pad, src_key_padding_mask], dim=1)

        return self.core(sequence, src_key_padding_mask=full_mask)


class LayoutFIDNetV3(PreTrainedModel):
    """Transformer model over layout elements given as (bbox, label) pairs.

    The encoder embeds each element, runs ``TransformerWithToken`` and keeps
    the output at the prepended token as a single feature per layout. That
    feature feeds a one-unit linear head (``logit_disc``) and a decoder that
    predicts class logits and a box for every element position.
    """

    config_class = LayoutFIDNetV3Config

    def __init__(self, config: LayoutFIDNetV3Config) -> None:
        super().__init__(config)

        hidden = config.d_model
        ff_dim = config.d_model // 2

        # encoder
        self.emb_label = nn.Embedding(config.num_labels, hidden)
        self.fc_bbox = nn.Linear(4, hidden)
        self.enc_fc_in = nn.Linear(hidden * 2, hidden)

        self.enc_transformer = TransformerWithToken(
            d_model=hidden,
            dim_feedforward=ff_dim,
            nhead=config.nhead,
            num_layers=config.num_layers,
        )

        self.fc_out_disc = nn.Linear(hidden, 1)

        # decoder
        # One learnable position vector per element slot, up to ``max_bbox``.
        self.pos_token = nn.Parameter(torch.rand(config.max_bbox, 1, hidden))
        self.dec_fc_in = nn.Linear(hidden * 2, hidden)

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=config.nhead,
            dim_feedforward=ff_dim,
        )
        self.dec_transformer = nn.TransformerEncoder(
            decoder_layer, num_layers=config.num_layers
        )

        self.fc_out_cls = nn.Linear(hidden, config.num_labels)
        self.fc_out_bbox = nn.Linear(hidden, 4)

    def extract_features(
        self, bbox: torch.Tensor, label: torch.Tensor, padding_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encode a batch of layouts into one feature vector per layout.

        Args:
            bbox: Box tensor whose last dimension has size 4, ``[B, N, 4]``.
            label: Integer label ids looked up in ``emb_label``, ``[B, N]``.
            padding_mask: Boolean mask ``[B, N]`` forwarded to the encoder as
                its key padding mask.

        Returns:
            The encoder output at the prepended token position, ``[B, d_model]``.
        """
        box_emb = self.fc_bbox(bbox)
        label_emb = self.emb_label(label)

        fused = self.enc_fc_in(torch.cat([box_emb, label_emb], dim=-1))
        # The encoder is sequence-first, so move the element axis to dim 0.
        fused = torch.relu(fused).permute(1, 0, 2)

        encoded = self.enc_transformer(fused, padding_mask)
        # Index 0 along the sequence axis is the learnable token.
        return encoded[0]

    def forward(
        self,
        bbox: torch.Tensor,
        label: torch.Tensor,
        padding_mask: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LayoutFIDNetV3Output]:
        """Compute the scalar logit and the per-element reconstructions.

        Args:
            bbox: Box tensor of shape ``[B, N, 4]``.
            label: Integer label ids of shape ``[B, N]``.
            padding_mask: Boolean mask of shape ``[B, N]`` used as the key
                padding mask in both the encoder and the decoder.
            return_dict: If truthy, return a ``LayoutFIDNetV3Output``;
                otherwise (including the default ``None``) return a tuple.

        Returns:
            ``(logit_disc, logit_cls, bbox_pred)`` or the same values wrapped
            in ``LayoutFIDNetV3Output``. ``logit_cls`` is ``[B, N, L]`` and
            ``bbox_pred`` is ``[B, N, 4]``.
        """
        batch_size, num_elements, _ = bbox.size()

        feature = self.extract_features(bbox, label, padding_mask)
        logit_disc = self.fc_out_disc(feature).squeeze(-1)

        # Repeat the layout feature for every element slot and pair it with
        # that slot's learnable position vector.
        tiled = feature.unsqueeze(0).expand(num_elements, -1, -1)
        positions = self.pos_token[:num_elements].expand(-1, batch_size, -1)
        decoder_in = torch.relu(
            self.dec_fc_in(torch.cat([tiled, positions], dim=-1))
        )

        decoded = self.dec_transformer(
            decoder_in, src_key_padding_mask=padding_mask
        )
        # Back to batch-first; padded positions are kept in the output.
        decoded = decoded.permute(1, 0, 2)

        # logit_cls: [B, N, L]    bbox_pred: [B, N, 4]
        logit_cls = self.fc_out_cls(decoded)
        bbox_pred = torch.sigmoid(self.fc_out_bbox(decoded))

        if return_dict:
            return LayoutFIDNetV3Output(
                logit_disc=logit_disc, logit_cls=logit_cls, bbox_pred=bbox_pred
            )

        return logit_disc, logit_cls, bbox_pred


def convert_from_checkpoint(
    repo_id: str, filename: str, config: Optional[LayoutFIDNetV3Config] = None
) -> LayoutFIDNetV3:
    """Build a ``LayoutFIDNetV3`` from a checkpoint fetched with ``hf_hub_download``.

    Args:
        repo_id: Repository id passed to ``hf_hub_download``.
        filename: Name of the checkpoint file inside that repository.
        config: Model configuration; a default ``LayoutFIDNetV3Config()`` is
            used when this is falsy.

    Returns:
        The model in eval mode, with weights loaded strictly from the
        checkpoint's ``"state_dict"`` entry.
    """
    # Imported lazily so the hub client is only needed when converting.
    from huggingface_hub import hf_hub_download

    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    model = LayoutFIDNetV3(config or LayoutFIDNetV3Config())

    logger.info(f"Loading model from {checkpoint_path}")
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's defaults; only use with trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    weights = checkpoint["state_dict"]

    # strict=True: any missing or unexpected key is an error.
    model.load_state_dict(weights, strict=True)
    model.eval()

    return model