KhaldiAbderrhmane
committed on
Upload model
- model.py +198 -0
- model.safetensors +3 -0
model.py
ADDED
@@ -0,0 +1,198 @@
from transformers import PreTrainedModel, AutoModel
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv

from .config import BERTMultiGATAttentionConfig


class MultiHeadGATAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.03):
        super(MultiHeadGATAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)

        self.gat = GATConv(hidden_size, hidden_size, heads=num_heads, concat=False)
        self.alpha = nn.Parameter(torch.tensor(0.5))  # Learnable weight for combining attention outputs

        self.layer_norm_q = nn.LayerNorm(hidden_size)
        self.layer_norm_k = nn.LayerNorm(hidden_size)
        self.layer_norm_v = nn.LayerNorm(hidden_size)
        self.layer_norm_out = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, edge_index):
        batch_size = query.size(0)
        seq_length = query.size(1)

        query_orig = query
        query = self.layer_norm_q(self.query(query))
        key = self.layer_norm_k(self.key(key))
        value = self.layer_norm_v(self.value(value))

        query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = key.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = value.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attended_values_std = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
        attended_values_std = attended_values_std.view(batch_size, seq_length, self.hidden_size)

        query_gat = query.permute(0, 2, 1, 3).reshape(batch_size * seq_length, self.hidden_size)
        value_gat = value.permute(0, 2, 1, 3).reshape(batch_size * seq_length, self.hidden_size)

        attended_values_gat = self.gat(value_gat, edge_index).view(batch_size, seq_length, self.hidden_size)

        # Weighted combination of standard and GAT attention outputs
        attended_values = self.alpha * attended_values_std + (1 - self.alpha) * attended_values_gat

        attended_values = self.layer_norm_out(self.out(attended_values))
        attended_values = self.dropout(attended_values)

        return query_orig + attended_values  # Residual connection


class GNNPreProcessor(nn.Module):
    def __init__(self, input_dim, hidden_dim, gat_heads=8):
        super(GNNPreProcessor, self).__init__()
        self.gcn = GCNConv(input_dim, hidden_dim)
        self.gat = GATConv(hidden_dim, hidden_dim, heads=gat_heads, concat=False)
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, edge_index):
        batch_size, seq_len, feature_dim = x.size()
        x = x.view(batch_size * seq_len, feature_dim)
        edge_index = edge_index.view(2, -1)

        x_gcn = F.relu(self.gcn(x, edge_index))
        x_gat = F.relu(self.gat(x, edge_index))
        x = self.alpha * x_gcn + (1 - self.alpha) * x_gat

        x = x.view(batch_size, seq_len, -1)
        return x


class DEBERTAMultiGATAttentionModel(PreTrainedModel):
    config_class = BERTMultiGATAttentionConfig

    def __init__(self, config):
        super(DEBERTAMultiGATAttentionModel, self).__init__(config)
        self.config = config
        self.transformer = AutoModel.from_pretrained(config.transformer_model)

        self.gnn_preprocessor1 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)
        self.gnn_preprocessor2 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)

        self.fc_combine = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.layer_norm_combine = nn.LayerNorm(config.hidden_size)
        self.dropout_combine = nn.Dropout(config.dropout)

        self.self_attention1 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention2 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)

        self.self_attention3 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention4 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention_ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)

        self.self_attention5 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention6 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention__ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)

        self.fc1 = nn.Linear(config.hidden_size * 2, 256)
        self.fc2 = nn.Linear(config.hidden_size * 2, 256)
        self.fc3 = nn.Linear(config.hidden_size * 2, 256)

        self.layer_norm_fc1 = nn.LayerNorm(256)
        self.layer_norm_fc2 = nn.LayerNorm(256)
        self.layer_norm_fc3 = nn.LayerNorm(256)

        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)
        self.dropout3 = nn.Dropout(config.dropout)
        self.dropout4 = nn.Dropout(config.dropout)

        self.fc_proj = nn.Linear(256, 256)
        self.layer_norm_proj = nn.LayerNorm(256)
        self.fc_final = nn.Linear(256, 1)

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2, edge_index1, edge_index2):
        output1_bert = self.transformer(input_ids1, attention_mask1)[0]
        output2_bert = self.transformer(input_ids2, attention_mask2)[0]

        edge_index1 = edge_index1.view(2, -1)  # Flatten the batch dimension
        edge_index2 = edge_index2.view(2, -1)  # Flatten the batch dimension

        output1_gnn = self.gnn_preprocessor1(output1_bert, edge_index1)
        output2_gnn = self.gnn_preprocessor2(output2_bert, edge_index2)

        combined_output1 = torch.cat([output1_bert, output1_gnn], dim=2)
        combined_output2 = torch.cat([output2_bert, output2_gnn], dim=2)

        combined_output1 = self.layer_norm_combine(self.fc_combine(combined_output1))
        combined_output2 = self.layer_norm_combine(self.fc_combine(combined_output2))

        combined_output1 = self.dropout_combine(F.relu(combined_output1))
        combined_output2 = self.dropout_combine(F.relu(combined_output2))

        # First attention block
        output1 = self.self_attention1(combined_output1, combined_output1, combined_output1, edge_index1)
        output2 = self.self_attention2(combined_output2, combined_output2, combined_output2, edge_index2)
        attended_output = self.cross_attention(output1, output2, output2, edge_index1)

        combined_output = torch.cat([output1, attended_output], dim=2)
        combined_output, _ = torch.max(combined_output, dim=1)

        combined_output = self.layer_norm_fc1(self.fc2(combined_output))
        combined_output = self.dropout1(F.relu(combined_output))
        combined_output = combined_output.unsqueeze(1)

        # Second attention block
        output1 = self.self_attention3(combined_output1, combined_output1, combined_output1, edge_index1)
        output2 = self.self_attention4(combined_output2, combined_output2, combined_output2, edge_index2)
        attended_output = self.cross_attention_(output1, output2, output2, edge_index1)
        combined_output = torch.cat([output1, attended_output], dim=2)
        combined_output, _ = torch.max(combined_output, dim=1)

        combined_output = self.layer_norm_fc2(self.fc2(combined_output))
        combined_output = self.dropout2(F.relu(combined_output))
        combined_output = combined_output.unsqueeze(1)

        # Third attention block
        output1 = self.self_attention5(combined_output1, combined_output1, combined_output1, edge_index1)
        output2 = self.self_attention6(combined_output2, combined_output2, combined_output2, edge_index2)
        attended_output = self.cross_attention__(output1, output2, output2, edge_index1)

        combined_output = torch.cat([output1, attended_output], dim=2)
        combined_output, _ = torch.max(combined_output, dim=1)

        combined_output = self.layer_norm_fc1(self.fc3(combined_output))
        combined_output = self.dropout3(F.relu(combined_output))
        combined_output = combined_output.unsqueeze(1)

        hidden_state_proj = self.layer_norm_proj(self.fc_proj(combined_output))
        hidden_state_proj = self.dropout4(hidden_state_proj)
        final = self.fc_final(hidden_state_proj)

        return torch.sigmoid(final)


AutoModel.register(BERTMultiGATAttentionConfig, DEBERTAMultiGATAttentionModel)
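For reference, a minimal usage sketch of the classes above (not part of the uploaded files). It assumes model.py and config.py are importable as a local package, here called mymodel as a placeholder, and that BERTMultiGATAttentionConfig accepts the fields model.py reads (transformer_model, hidden_size, gnn_input_dim, gnn_hidden_dim, num_heads, dropout) as keyword arguments; the concrete values and the DeBERTa checkpoint are placeholders as well.

# Illustrative only; placeholder package name, config values, and backbone checkpoint.
import torch
from mymodel.config import BERTMultiGATAttentionConfig
from mymodel.model import DEBERTAMultiGATAttentionModel

config = BERTMultiGATAttentionConfig(
    transformer_model="microsoft/deberta-v3-base",  # placeholder encoder checkpoint
    hidden_size=768,
    gnn_input_dim=768,
    gnn_hidden_dim=768,
    num_heads=8,
    dropout=0.1,
)
model = DEBERTAMultiGATAttentionModel(config).eval()

batch_size, seq_len = 2, 16
input_ids = torch.randint(0, 1000, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# A toy chain graph over the flattened batch of tokens; forward() flattens
# edge_index with .view(2, -1), so node ids index into batch_size * seq_len rows.
src = torch.arange(batch_size * seq_len - 1)
edge_index = torch.stack([src, src + 1], dim=0)

with torch.no_grad():
    score = model(input_ids, attention_mask, input_ids, attention_mask, edge_index, edge_index)
print(score.shape)  # torch.Size([2, 1, 1])

The forward pass ends in a sigmoid over a single logit per input pair, so the head reads as a pairwise scoring/matching classifier; the upload itself does not state the task.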
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d91caf162fb0d5d4f415234e926c6bc2c7239fa2480e9347bda1453f4725c80
size 1043279080
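model.safetensors is the Git LFS pointer for the trained checkpoint (about 1.04 GB). A hedged loading sketch, assuming the repository's config.json exposes the custom classes to AutoModel through an auto_map entry (not shown in this commit); the repo id is a placeholder:

# Placeholder repo id; trust_remote_code=True is needed so the repo's model.py
# is executed and the custom architecture can be built.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "KhaldiAbderrhmane/<repo-name>",  # placeholder, replace with the actual repo id
    trust_remote_code=True,
)
# from_pretrained fetches model.safetensors and loads the weights once the class
# resolves; the backbone named in config.transformer_model is also downloaded
# inside __init__.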