Upload LayoutFIDNetV3

- config.json +8 -8
- modeling_layout_fidnet_v3.py +14 -3
config.json CHANGED

@@ -8,16 +8,16 @@
   },
   "d_model": 256,
   "id2label": {
-    "0": "
-    "1": "
-    "2": "
-    "3": "
+    "0": "logo",
+    "1": "text",
+    "2": "underlay",
+    "3": "embellishment"
   },
   "label2id": {
-    "
-    "
-    "
-    "
+    "embellishment": 3,
+    "logo": 0,
+    "text": 1,
+    "underlay": 2
   },
   "max_bbox": 10,
   "model_type": "layoutdm_fidnet_v3",
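The added mappings give the four layout element classes human-readable names (logo, text, underlay, embellishment). A minimal sketch of reading them back, assuming the config.json shown above is available as a local file (the file path is the only assumption):

import json

# Minimal sketch: recover the class-name mappings added in this commit
# from a local copy of the config.json shown above.
with open("config.json") as f:
    config = json.load(f)

id2label = {int(i): name for i, name in config["id2label"].items()}
label2id = config["label2id"]

print(id2label[2])       # underlay
print(label2id["logo"])  # 0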
modeling_layout_fidnet_v3.py CHANGED

@@ -1,6 +1,6 @@
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -95,7 +95,9 @@ class LayoutFIDNetV3(PreTrainedModel):
         self.fc_out_cls = nn.Linear(config.d_model, config.num_labels)
         self.fc_out_bbox = nn.Linear(config.d_model, 4)
 
-    def extract_features(
+    def extract_features(
+        self, bbox: torch.Tensor, label: torch.Tensor, padding_mask: torch.Tensor
+    ) -> torch.Tensor:
         b = self.fc_bbox(bbox)
         l = self.emb_label(label)
         x = self.enc_fc_in(torch.cat([b, l], dim=-1))
@@ -103,7 +105,13 @@ class LayoutFIDNetV3(PreTrainedModel):
         x = self.enc_transformer(x, padding_mask)
         return x[0]
 
-    def forward(
+    def forward(
+        self,
+        bbox: torch.Tensor,
+        label: torch.Tensor,
+        padding_mask: torch.Tensor,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, LayoutFIDNetV3Output]:
         B, N, _ = bbox.size()
         x = self.extract_features(bbox, label, padding_mask)
 
@@ -122,6 +130,9 @@ class LayoutFIDNetV3(PreTrainedModel):
         logit_cls = self.fc_out_cls(x)
         bbox_pred = torch.sigmoid(self.fc_out_bbox(x))
 
+        if not return_dict:
+            return logit_disc, logit_cls, bbox_pred
+
         return LayoutFIDNetV3Output(
             logit_disc=logit_disc, logit_cls=logit_cls, bbox_pred=bbox_pred
         )
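The new return_dict flag lets callers choose between a plain tuple and the LayoutFIDNetV3Output dataclass. A rough usage sketch, assuming a LayoutFIDNetV3 instance has already been instantiated elsewhere; the tensor shapes and mask convention below follow the forward() signature above but are assumptions, not taken from the repo:

import torch

# Rough usage sketch (not from the repo). `model` is assumed to be an
# already-loaded LayoutFIDNetV3; shapes follow the forward() signature above:
# B layouts, N = max_bbox elements, 4 box coordinates per element.
def run_fidnet(model, bbox, label, padding_mask):
    # New in this commit: return_dict=False yields a plain tuple ...
    logit_disc, logit_cls, bbox_pred = model(
        bbox, label, padding_mask, return_dict=False
    )
    # ... while return_dict=True keeps the LayoutFIDNetV3Output dataclass.
    out = model(bbox, label, padding_mask, return_dict=True)
    return out.logit_disc, out.logit_cls, out.bbox_pred

B, N = 2, 10
bbox = torch.rand(B, N, 4)                          # assumed normalized boxes
label = torch.randint(0, 4, (B, N))                 # 0=logo, 1=text, 2=underlay, 3=embellishment
padding_mask = torch.ones(B, N, dtype=torch.bool)   # per-element mask (convention assumed)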