Update models.py
Browse files
models.py
CHANGED
@@ -2,9 +2,21 @@ import torch
|
|
2 |
from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
|
3 |
from transformers import BertForTokenClassification, BertForSequenceClassification,BertPreTrainedModel, BertModel
|
4 |
import torch.nn as nn
|
5 |
-
from .utils import *
|
6 |
import torch.nn.functional as F
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
class Model_Rational_Label(BertPreTrainedModel):
|
|
|
2 |
from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
|
3 |
from transformers import BertForTokenClassification, BertForSequenceClassification,BertPreTrainedModel, BertModel
|
4 |
import torch.nn as nn
|
|
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
+
class BertPooler(nn.Module):
|
8 |
+
def __init__(self, config):
|
9 |
+
super().__init__()
|
10 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
11 |
+
self.activation = nn.Tanh()
|
12 |
+
|
13 |
+
def forward(self, hidden_states):
|
14 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
15 |
+
# to the first token.
|
16 |
+
first_token_tensor = hidden_states[:, 0]
|
17 |
+
pooled_output = self.dense(first_token_tensor)
|
18 |
+
pooled_output = self.activation(pooled_output)
|
19 |
+
return pooled_output
|
20 |
|
21 |
|
22 |
class Model_Rational_Label(BertPreTrainedModel):
|