iakarshu commited on
Commit
9a6e24a
·
1 Parent(s): 3c5f133

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +128 -0
utils.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from modeling import LiLT
3
+ import torch
4
+ ## Defining pytorch lightning model
5
+ from sklearn.metrics import accuracy_score, confusion_matrix
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ import numpy as np
10
+ import torchmetrics
11
+ import pytorch_lightning as pl
12
+
13
+
14
+ id2label = ['scientific_report',
15
+ 'resume',
16
+ 'memo',
17
+ 'file_folder',
18
+ 'specification',
19
+ 'news_article',
20
+ 'letter',
21
+ 'form',
22
+ 'budget',
23
+ 'handwritten',
24
+ 'email',
25
+ 'invoice',
26
+ 'presentation',
27
+ 'scientific_publication',
28
+ 'questionnaire',
29
+ 'advertisement']
30
+
31
+ class LiLTForClassification(nn.Module):
32
+
33
+ def __init__(self, config):
34
+ super(LiLTForClassification, self).__init__()
35
+
36
+ self.lilt = LiLT(config)
37
+ self.config = config
38
+ self.dropout = nn.Dropout(config['hidden_dropout_prob'])
39
+ self.linear_layer = nn.Linear(in_features = config['hidden_size'] * 2, out_features = len(id2label)) ## Number of Classes
40
+
41
+ def forward(self, batch_dict):
42
+ encodings = self.lilt(batch_dict['input_words'], batch_dict['input_boxes'])
43
+ final_out = torch.cat([encodings['layout_hidden_states'][-1],
44
+ encodings['text_hidden_states'][-1]
45
+ ],
46
+ axis = -1)[:, 0, :]
47
+ final_out = self.linear_layer(final_out)
48
+ return final_out
49
+
50
+
51
+ class LiLTPL(pl.LightningModule):
52
+
53
+ def __init__(self, config , lr = 5e-5):
54
+ super(LiLTPL, self).__init__()
55
+
56
+ self.save_hyperparameters()
57
+ self.config = config
58
+ self.lilt = LiLTForClassification(config)
59
+
60
+ self.num_classes = len(id2label)
61
+ self.train_accuracy_metric = torchmetrics.Accuracy()
62
+ self.val_accuracy_metric = torchmetrics.Accuracy()
63
+ self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
64
+ self.precision_macro_metric = torchmetrics.Precision(
65
+ average="macro", num_classes=self.num_classes
66
+ )
67
+ self.recall_macro_metric = torchmetrics.Recall(
68
+ average="macro", num_classes=self.num_classes
69
+ )
70
+ self.precision_micro_metric = torchmetrics.Precision(average="micro")
71
+ self.recall_micro_metric = torchmetrics.Recall(average="micro")
72
+
73
+ def forward(self, batch_dict):
74
+ logits = self.lilt(batch_dict)
75
+ return logits
76
+
77
+ def training_step(self, batch, batch_idx):
78
+ logits = self.forward(batch)
79
+
80
+ loss = nn.CrossEntropyLoss()(logits, batch['label'])
81
+ preds = torch.argmax(logits, 1)
82
+
83
+ ## Calculating the accuracy score
84
+ train_acc = self.train_accuracy_metric(preds, batch["label"])
85
+
86
+ ## Logging
87
+ self.log('train/loss', loss,prog_bar = True, on_epoch=True, logger=True, on_step=True)
88
+ self.log('train/acc', train_acc, prog_bar = True, on_epoch=True, logger=True, on_step=True)
89
+
90
+ return loss
91
+
92
+ def validation_step(self, batch, batch_idx):
93
+ logits = self.forward(batch)
94
+ loss = nn.CrossEntropyLoss()(logits, batch['label'])
95
+ preds = torch.argmax(logits, 1)
96
+
97
+ labels = batch['label']
98
+ # Metrics
99
+ valid_acc = self.val_accuracy_metric(preds, labels)
100
+ precision_macro = self.precision_macro_metric(preds, labels)
101
+ recall_macro = self.recall_macro_metric(preds, labels)
102
+ precision_micro = self.precision_micro_metric(preds, labels)
103
+ recall_micro = self.recall_micro_metric(preds, labels)
104
+ f1 = self.f1_metric(preds, labels)
105
+
106
+ # Logging metrics
107
+ self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
108
+ self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
109
+ self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
110
+ self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
111
+ self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
112
+ self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
113
+ self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
114
+
115
+ return {"label": batch['label'], "logits": logits}
116
+
117
+ def validation_epoch_end(self, outputs):
118
+ labels = torch.cat([x["label"] for x in outputs])
119
+ logits = torch.cat([x["logits"] for x in outputs])
120
+ preds = torch.argmax(logits, 1)
121
+
122
+ wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})
123
+ self.logger.experiment.log(
124
+ {"roc": wandb.plot.roc_curve(labels.cpu().numpy(), logits.cpu().numpy())}
125
+ )
126
+
127
+ def configure_optimizers(self):
128
+ return torch.optim.AdamW(self.parameters(), lr = self.hparams['lr'])