KhaldiAbderrhmane committed
Commit 4975ae2 (verified)
1 Parent(s): 981cfb7

Upload model

Files changed (2)
  1. model.py +198 -0
  2. model.safetensors +3 -0
model.py ADDED
@@ -0,0 +1,198 @@
+ from transformers import PreTrainedModel, AutoModel
+ import torch
+ import torch.nn as nn
+ import math
+ import torch.nn.functional as F
+ from torch_geometric.nn import GCNConv, GATConv
+
+ from .config import BERTMultiGATAttentionConfig
+
+
+
+ class MultiHeadGATAttention(nn.Module):
+     def __init__(self, hidden_size, num_heads, dropout=0.03):
+         super(MultiHeadGATAttention, self).__init__()
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         self.head_dim = hidden_size // num_heads
+
+         self.query = nn.Linear(hidden_size, hidden_size)
+         self.key = nn.Linear(hidden_size, hidden_size)
+         self.value = nn.Linear(hidden_size, hidden_size)
+         self.out = nn.Linear(hidden_size, hidden_size)
+
+         self.gat = GATConv(hidden_size, hidden_size, heads=num_heads, concat=False)
+         self.alpha = nn.Parameter(torch.tensor(0.5))  # Learnable weight for combining attention outputs
+
+         self.layer_norm_q = nn.LayerNorm(hidden_size)
+         self.layer_norm_k = nn.LayerNorm(hidden_size)
+         self.layer_norm_v = nn.LayerNorm(hidden_size)
+         self.layer_norm_out = nn.LayerNorm(hidden_size)
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, query, key, value, edge_index):
+         batch_size = query.size(0)
+         seq_length = query.size(1)
+
+         query_orig = query
+         query = self.layer_norm_q(self.query(query))
+         key = self.layer_norm_k(self.key(key))
+         value = self.layer_norm_v(self.value(value))
+
+         query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         key = key.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         value = value.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+         attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         attention_weights = F.softmax(attention_scores, dim=-1)
+         attention_weights = self.dropout(attention_weights)
+         attended_values_std = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
+         attended_values_std = attended_values_std.view(batch_size, seq_length, self.hidden_size)
+
+         query_gat = query.permute(0, 2, 1, 3).reshape(batch_size * seq_length, self.hidden_size)
+         value_gat = value.permute(0, 2, 1, 3).reshape(batch_size * seq_length, self.hidden_size)
+
+         attended_values_gat = self.gat(value_gat, edge_index).view(batch_size, seq_length, self.hidden_size)
+
+         # Weighted combination of the standard attention and GAT outputs
+         attended_values = self.alpha * attended_values_std + (1 - self.alpha) * attended_values_gat
+
+         attended_values = self.layer_norm_out(self.out(attended_values))
+         attended_values = self.dropout(attended_values)
+
+         return query_orig + attended_values  # Residual connection
+
+
+
+
+ class GNNPreProcessor(nn.Module):
+     def __init__(self, input_dim, hidden_dim, gat_heads=8):
+         super(GNNPreProcessor, self).__init__()
+         self.gcn = GCNConv(input_dim, hidden_dim)
+         self.gat = GATConv(hidden_dim, hidden_dim, heads=gat_heads, concat=False)
+         self.alpha = nn.Parameter(torch.tensor(0.5))
+
+     def forward(self, x, edge_index):
+         batch_size, seq_len, feature_dim = x.size()
+         x = x.view(batch_size * seq_len, feature_dim)
+         edge_index = edge_index.view(2, -1)
+
+         x_gcn = F.relu(self.gcn(x, edge_index))
+         x_gat = F.relu(self.gat(x, edge_index))
+         x = self.alpha * x_gcn + (1 - self.alpha) * x_gat
+
+         x = x.view(batch_size, seq_len, -1)
+         return x
+
+
+ class DEBERTAMultiGATAttentionModel(PreTrainedModel):
+     config_class = BERTMultiGATAttentionConfig
+     def __init__(self, config):
+         super(DEBERTAMultiGATAttentionModel, self).__init__(config)
+         self.config = config
+         self.transformer = AutoModel.from_pretrained(config.transformer_model)
+
+         self.gnn_preprocessor1 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)
+         self.gnn_preprocessor2 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)
+
+         self.fc_combine = nn.Linear(config.hidden_size * 2, config.hidden_size)
+         self.layer_norm_combine = nn.LayerNorm(config.hidden_size)
+         self.dropout_combine = nn.Dropout(config.dropout)
+
+         self.self_attention1 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+         self.self_attention2 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+         self.cross_attention = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+
+         self.self_attention3 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+         self.self_attention4 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+         self.cross_attention_ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+
+         self.self_attention5 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+         self.self_attention6 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+         self.cross_attention__ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
+
+         self.fc1 = nn.Linear(config.hidden_size * 2, 256)
+         self.fc2 = nn.Linear(config.hidden_size * 2, 256)
+         self.fc3 = nn.Linear(config.hidden_size * 2, 256)
+
+         self.layer_norm_fc1 = nn.LayerNorm(256)
+         self.layer_norm_fc2 = nn.LayerNorm(256)
+         self.layer_norm_fc3 = nn.LayerNorm(256)
+
+         self.dropout1 = nn.Dropout(config.dropout)
+         self.dropout2 = nn.Dropout(config.dropout)
+         self.dropout3 = nn.Dropout(config.dropout)
+         self.dropout4 = nn.Dropout(config.dropout)
+
+
+         self.fc_proj = nn.Linear(256, 256)
+         self.layer_norm_proj = nn.LayerNorm(256)
+         self.fc_final = nn.Linear(256, 1)
+
+     def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2, edge_index1, edge_index2):
+
+         output1_bert = self.transformer(input_ids1, attention_mask1)[0]
+         output2_bert = self.transformer(input_ids2, attention_mask2)[0]
+
+         edge_index1 = edge_index1.view(2, -1)  # Flatten the batch dimension
+         edge_index2 = edge_index2.view(2, -1)  # Flatten the batch dimension
+
+         output1_gnn = self.gnn_preprocessor1(output1_bert, edge_index1)
+         output2_gnn = self.gnn_preprocessor2(output2_bert, edge_index2)
+
+         combined_output1 = torch.cat([output1_bert, output1_gnn], dim=2)
+         combined_output2 = torch.cat([output2_bert, output2_gnn], dim=2)
+
+         combined_output1 = self.layer_norm_combine(self.fc_combine(combined_output1))
+         combined_output2 = self.layer_norm_combine(self.fc_combine(combined_output2))
+
+         combined_output1 = self.dropout_combine(F.relu(combined_output1))
+         combined_output2 = self.dropout_combine(F.relu(combined_output2))
+
+         # Block 1: graph-augmented self-attention on each input, then cross-attention
+         output1 = self.self_attention1(combined_output1, combined_output1, combined_output1, edge_index1)
+         output2 = self.self_attention2(combined_output2, combined_output2, combined_output2, edge_index2)
+         attended_output = self.cross_attention(output1, output2, output2, edge_index1)
+
+         combined_output = torch.cat([output1, attended_output], dim=2)
+         combined_output, _ = torch.max(combined_output, dim=1)
+
+         combined_output = self.layer_norm_fc1(self.fc2(combined_output))
+         combined_output = self.dropout1(F.relu(combined_output))
+         combined_output = combined_output.unsqueeze(1)
+
+         # Block 2: second round of self- and cross-attention
+         output1 = self.self_attention3(combined_output1, combined_output1, combined_output1, edge_index1)
+         output2 = self.self_attention4(combined_output2, combined_output2, combined_output2, edge_index2)
+         attended_output = self.cross_attention_(output1, output2, output2, edge_index1)
+         combined_output = torch.cat([output1, attended_output], dim=2)
+         combined_output, _ = torch.max(combined_output, dim=1)
+
+         combined_output = self.layer_norm_fc2(self.fc2(combined_output))
+         combined_output = self.dropout2(F.relu(combined_output))
+         combined_output = combined_output.unsqueeze(1)
+
+         # Block 3: third round of self- and cross-attention
+         output1 = self.self_attention5(combined_output1, combined_output1, combined_output1, edge_index1)
+         output2 = self.self_attention6(combined_output2, combined_output2, combined_output2, edge_index2)
+         attended_output = self.cross_attention__(output1, output2, output2, edge_index1)
+
+         combined_output = torch.cat([output1, attended_output], dim=2)
+         combined_output, _ = torch.max(combined_output, dim=1)
+
+         combined_output = self.layer_norm_fc1(self.fc3(combined_output))
+         combined_output = self.dropout3(F.relu(combined_output))
+         combined_output = combined_output.unsqueeze(1)
+
+
+
+         hidden_state_proj = self.layer_norm_proj(self.fc_proj(combined_output))
+         hidden_state_proj = self.dropout4(hidden_state_proj)
+         final = self.fc_final(hidden_state_proj)
+
+         return torch.sigmoid(final)
+
+
+
+ AutoModel.register(BERTMultiGATAttentionConfig, DEBERTAMultiGATAttentionModel)
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6d91caf162fb0d5d4f415234e926c6bc2c7239fa2480e9347bda1453f4725c80
+ size 1043279080
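
A minimal usage sketch follows. Since model.py registers the architecture via AutoModel.register, the checkpoint should be loadable through the Auto classes (with trust_remote_code=True, assuming the hub config is wired for remote code). The repository id, tokenizer settings, max_length, and the fully connected edge_index below are illustrative assumptions, not part of this commit.

    # Usage sketch; repo id, tokenizer settings, and edge_index construction are hypothetical.
    import torch
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    repo_id = "KhaldiAbderrhmane/<repo-name>"  # hypothetical placeholder, replace with the actual repo id
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.transformer_model)

    enc1 = tokenizer("first text", padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    enc2 = tokenizer("second text", padding="max_length", truncation=True, max_length=128, return_tensors="pt")

    # The forward pass takes one edge_index per input; a fully connected graph over
    # the token positions is used here purely as a placeholder.
    seq_len = enc1["input_ids"].size(1)
    src, dst = torch.meshgrid(torch.arange(seq_len), torch.arange(seq_len), indexing="ij")
    edge_index = torch.stack([src.reshape(-1), dst.reshape(-1)], dim=0)

    with torch.no_grad():
        score = model(
            enc1["input_ids"], enc1["attention_mask"],
            enc2["input_ids"], enc2["attention_mask"],
            edge_index, edge_index,
        )  # sigmoid output of shape (batch, 1, 1)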