shunk031 committed on
Commit
ccd6fbb
1 Parent(s): 2ad102c

Upload LayoutDmFIDNetV3

Browse files
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LayoutDmFIDNetV3"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_fidnet_v3.LayoutDmFIDNetV3Config",
7
+ "AutoModel": "modeling_fidnet_v3.LayoutDmFIDNetV3"
8
+ },
9
+ "d_model": 256,
10
+ "id2label": {
11
+ "0": "text",
12
+ "1": "title",
13
+ "2": "list",
14
+ "3": "table",
15
+ "4": "figure"
16
+ },
17
+ "label2id": {
18
+ "figure": 4,
19
+ "list": 2,
20
+ "table": 3,
21
+ "text": 0,
22
+ "title": 1
23
+ },
24
+ "max_bbox": 25,
25
+ "model_type": "layoutdm_fidnet_v3",
26
+ "nhead": 4,
27
+ "num_layers": 4,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.36.2"
30
+ }
configuration_fidnet_v3.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
class LayoutDmFIDNetV3Config(PretrainedConfig):
    """Configuration for the LayoutDM FIDNetV3 model.

    Args:
        d_model: Hidden size of the transformer encoder/decoder.
        nhead: Number of attention heads.
        num_layers: Number of transformer layers in encoder and decoder.
        max_bbox: Maximum number of bbox elements per layout (size of the
            learned positional-token table).
        **kwargs: Forwarded to :class:`PretrainedConfig` (e.g. ``id2label``).
    """

    model_type = "layoutdm_fidnet_v3"

    def __init__(
        self,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 4,
        max_bbox: int = 50,
        **kwargs,
    ) -> None:
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.max_bbox = max_bbox
        super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f8e2ed0173dbc3942b4f58b0664fa2917ec46bfad5e8b59e4301d2a2ea44c80
3
+ size 11673377
modeling_fidnet_v3.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.modeling_utils import PreTrainedModel
6
+
7
+ from .configuration_fidnet_v3 import LayoutDmFIDNetV3Config
8
+
9
+
10
@dataclass
class LayoutDmFIDNetV3Output:
    """Output of :class:`LayoutDmFIDNetV3.forward`.

    Attributes:
        logit_disc: Discriminator logit per layout (forward squeezes the
            last dim of a Linear(d_model, 1) head — presumably shape [B]).
        logit_cls: Per-element class logits, shape [B, N, num_labels].
        bbox_pred: Per-element bbox predictions in [0, 1] (sigmoid output),
            shape [B, N, 4].
    """

    # BUG FIX: the field was misspelled ``logit_dict`` while forward()
    # constructs this class with ``logit_disc=...``, raising TypeError on
    # every call. Renamed to match the call site.
    logit_disc: torch.Tensor
    logit_cls: torch.Tensor
    bbox_pred: torch.Tensor
15
+
16
+
17
+ class TransformerWithToken(nn.Module):
18
+ def __init__(self, d_model: int, nhead: int, dim_feedforward: int, num_layers: int):
19
+ super().__init__()
20
+
21
+ self.token = nn.Parameter(torch.randn(1, 1, d_model))
22
+ token_mask = torch.zeros(1, 1, dtype=torch.bool)
23
+ self.register_buffer("token_mask", token_mask)
24
+
25
+ self.core = nn.TransformerEncoder(
26
+ nn.TransformerEncoderLayer(
27
+ d_model=d_model,
28
+ nhead=nhead,
29
+ dim_feedforward=dim_feedforward,
30
+ ),
31
+ num_layers=num_layers,
32
+ )
33
+
34
+ def forward(self, x, src_key_padding_mask):
35
+ # x: [N, B, E]
36
+ # padding_mask: [B, N]
37
+ # `False` for valid values
38
+ # `True` for padded values
39
+
40
+ B = x.size(1)
41
+
42
+ token = self.token.expand(-1, B, -1)
43
+ x = torch.cat([token, x], dim=0)
44
+
45
+ token_mask = self.token_mask.expand(B, -1)
46
+ padding_mask = torch.cat([token_mask, src_key_padding_mask], dim=1)
47
+
48
+ x = self.core(x, src_key_padding_mask=padding_mask)
49
+
50
+ return x
51
+
52
+
53
class LayoutDmFIDNetV3(PreTrainedModel):
    """FIDNetV3 feature extractor for layout generation (FID computation).

    Encodes a padded set of (bbox, label) layout elements into a single
    per-layout feature via a transformer with a summary token, then decodes
    per-element class logits and bbox reconstructions from that feature.
    """

    config_class = LayoutDmFIDNetV3Config

    def __init__(self, config: LayoutDmFIDNetV3Config):
        super().__init__(config)
        # NOTE: PreTrainedModel.__init__ already stores ``config`` as
        # ``self.config``; the original redundant re-assignment was removed.
        # Submodule creation order below is preserved exactly so that
        # state-dict keys and RNG-dependent initialization stay compatible
        # with existing checkpoints.

        # --- encoder ---
        self.emb_label = nn.Embedding(config.num_labels, config.d_model)
        self.fc_bbox = nn.Linear(4, config.d_model)
        self.enc_fc_in = nn.Linear(config.d_model * 2, config.d_model)

        self.enc_transformer = TransformerWithToken(
            d_model=config.d_model,
            dim_feedforward=config.d_model // 2,
            nhead=config.nhead,
            num_layers=config.num_layers,
        )

        # Real/fake discriminator head applied to the summary-token feature.
        self.fc_out_disc = nn.Linear(config.d_model, 1)

        # --- decoder ---
        # Learned positional tokens, one per possible bbox slot.
        self.pos_token = nn.Parameter(torch.rand(config.max_bbox, 1, config.d_model))
        self.dec_fc_in = nn.Linear(config.d_model * 2, config.d_model)

        te = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_model // 2,
        )
        self.dec_transformer = nn.TransformerEncoder(te, num_layers=config.num_layers)

        self.fc_out_cls = nn.Linear(config.d_model, config.num_labels)
        self.fc_out_bbox = nn.Linear(config.d_model, 4)

    def extract_features(self, bbox, label, padding_mask):
        """Return the per-layout summary feature.

        Args:
            bbox: [B, N, 4] float box coordinates.
            label: [B, N] integer category ids.
            padding_mask: [B, N] bool, True for padded slots.

        Returns:
            [B, d_model] feature taken from the prepended summary token.
        """
        b = self.fc_bbox(bbox)
        l = self.emb_label(label)
        x = self.enc_fc_in(torch.cat([b, l], dim=-1))
        # Transformer modules expect sequence-first layout: [N, B, E].
        x = torch.relu(x).permute(1, 0, 2)
        x = self.enc_transformer(x, padding_mask)
        # Slot 0 is the summary token prepended by TransformerWithToken.
        return x[0]

    def forward(self, bbox, label, padding_mask):
        """Run the encoder and decoder heads.

        Args:
            bbox: [B, N, 4] float box coordinates.
            label: [B, N] integer category ids.
            padding_mask: [B, N] bool, True for padded slots.

        Returns:
            LayoutDmFIDNetV3Output with the discriminator logit, class
            logits [B, N, num_labels], and bbox predictions [B, N, 4].
        """
        B, N, _ = bbox.size()
        x = self.extract_features(bbox, label, padding_mask)

        logit_disc = self.fc_out_disc(x).squeeze(-1)

        # Broadcast the layout feature to every slot and pair it with the
        # slot's learned positional token.
        x = x.unsqueeze(0).expand(N, -1, -1)
        t = self.pos_token[:N].expand(-1, B, -1)
        x = torch.cat([x, t], dim=-1)
        x = torch.relu(self.dec_fc_in(x))

        x = self.dec_transformer(x, src_key_padding_mask=padding_mask)
        # Back to batch-first: [B, N, d_model].
        x = x.permute(1, 0, 2)

        # logit_cls: [B, N, L]   bbox_pred: [B, N, 4]
        logit_cls = self.fc_out_cls(x)
        bbox_pred = torch.sigmoid(self.fc_out_bbox(x))

        return LayoutDmFIDNetV3Output(
            logit_disc=logit_disc, logit_cls=logit_cls, bbox_pred=bbox_pred
        )