shunk031 commited on
Commit
6e906e7
1 Parent(s): 7455fc8

Upload LayoutFIDNetV3

Browse files
Files changed (2) hide show
  1. config.json +8 -8
  2. modeling_layout_fidnet_v3.py +14 -3
config.json CHANGED
@@ -8,16 +8,16 @@
8
  },
9
  "d_model": 256,
10
  "id2label": {
11
- "0": "LABEL_0",
12
- "1": "LABEL_1",
13
- "2": "LABEL_2",
14
- "3": "LABEL_3"
15
  },
16
  "label2id": {
17
- "LABEL_0": 0,
18
- "LABEL_1": 1,
19
- "LABEL_2": 2,
20
- "LABEL_3": 3
21
  },
22
  "max_bbox": 10,
23
  "model_type": "layoutdm_fidnet_v3",
 
8
  },
9
  "d_model": 256,
10
  "id2label": {
11
+ "0": "logo",
12
+ "1": "text",
13
+ "2": "underlay",
14
+ "3": "embellishment"
15
  },
16
  "label2id": {
17
+ "embellishment": 3,
18
+ "logo": 0,
19
+ "text": 1,
20
+ "underlay": 2
21
  },
22
  "max_bbox": 10,
23
  "model_type": "layoutdm_fidnet_v3",
modeling_layout_fidnet_v3.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
  from dataclasses import dataclass
3
- from typing import Optional
4
 
5
  import torch
6
  import torch.nn as nn
@@ -95,7 +95,9 @@ class LayoutFIDNetV3(PreTrainedModel):
95
  self.fc_out_cls = nn.Linear(config.d_model, config.num_labels)
96
  self.fc_out_bbox = nn.Linear(config.d_model, 4)
97
 
98
- def extract_features(self, bbox, label, padding_mask):
 
 
99
  b = self.fc_bbox(bbox)
100
  l = self.emb_label(label)
101
  x = self.enc_fc_in(torch.cat([b, l], dim=-1))
@@ -103,7 +105,13 @@ class LayoutFIDNetV3(PreTrainedModel):
103
  x = self.enc_transformer(x, padding_mask)
104
  return x[0]
105
 
106
- def forward(self, bbox, label, padding_mask):
 
 
 
 
 
 
107
  B, N, _ = bbox.size()
108
  x = self.extract_features(bbox, label, padding_mask)
109
 
@@ -122,6 +130,9 @@ class LayoutFIDNetV3(PreTrainedModel):
122
  logit_cls = self.fc_out_cls(x)
123
  bbox_pred = torch.sigmoid(self.fc_out_bbox(x))
124
 
 
 
 
125
  return LayoutFIDNetV3Output(
126
  logit_disc=logit_disc, logit_cls=logit_cls, bbox_pred=bbox_pred
127
  )
 
1
  import logging
2
  from dataclasses import dataclass
3
+ from typing import Optional, Tuple, Union
4
 
5
  import torch
6
  import torch.nn as nn
 
95
  self.fc_out_cls = nn.Linear(config.d_model, config.num_labels)
96
  self.fc_out_bbox = nn.Linear(config.d_model, 4)
97
 
98
+ def extract_features(
99
+ self, bbox: torch.Tensor, label: torch.Tensor, padding_mask: torch.Tensor
100
+ ) -> torch.Tensor:
101
  b = self.fc_bbox(bbox)
102
  l = self.emb_label(label)
103
  x = self.enc_fc_in(torch.cat([b, l], dim=-1))
 
105
  x = self.enc_transformer(x, padding_mask)
106
  return x[0]
107
 
108
+ def forward(
109
+ self,
110
+ bbox: torch.Tensor,
111
+ label: torch.Tensor,
112
+ padding_mask: torch.Tensor,
113
+ return_dict: Optional[bool] = None,
114
+ ) -> Union[Tuple, LayoutFIDNetV3Output]:
115
  B, N, _ = bbox.size()
116
  x = self.extract_features(bbox, label, padding_mask)
117
 
 
130
  logit_cls = self.fc_out_cls(x)
131
  bbox_pred = torch.sigmoid(self.fc_out_bbox(x))
132
 
133
+ if not return_dict:
134
+ return logit_disc, logit_cls, bbox_pred
135
+
136
  return LayoutFIDNetV3Output(
137
  logit_disc=logit_disc, logit_cls=logit_cls, bbox_pred=bbox_pred
138
  )