import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Head for statement-level and function-level classification tasks."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Statement-level head: one logit per statement position.
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.out_proj = nn.Linear(hidden_dim, 1)
        # GRU that pools the statement sequence into one function-level vector.
        self.rnn_pool = nn.GRU(input_size=hidden_dim,
                               hidden_size=hidden_dim,
                               num_layers=1,
                               batch_first=True)
        # Function-level head: 2-class classifier.
        self.func_dense = nn.Linear(hidden_dim, hidden_dim)
        self.func_out_proj = nn.Linear(hidden_dim, 2)
    def forward(self, hidden):
        # hidden: (batch, num_statements, hidden_dim) encoder output.
        x = self.dropout(hidden)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)  # (batch, num_statements, 1)

        # Function-level path: pool across statements with the GRU and
        # classify from its final hidden state.
        _, func_x = self.rnn_pool(hidden)
        func_x = func_x.squeeze(0)  # (batch, hidden_dim)
        func_x = self.dropout(func_x)
        func_x = self.func_dense(func_x)
        func_x = torch.tanh(func_x)
        func_x = self.dropout(func_x)
        func_x = self.func_out_proj(func_x)  # (batch, 2)
        return x.squeeze(-1), func_x

class StatementT5(nn.Module):
    def __init__(self, t5, tokenizer, device, hidden_dim=768):
        super().__init__()
        self.max_num_statement = 155  # inputs are truncated to at most 155 statements
        # Reuse T5's shared token-embedding table.
        self.word_embedding = t5.shared
        # GRU that compresses each statement's token embeddings into one vector.
        self.rnn_statement_embedding = nn.GRU(input_size=hidden_dim,
                                              hidden_size=hidden_dim,
                                              num_layers=1,
                                              batch_first=True)
        # `t5` is expected to be an encoder-only model (e.g. T5EncoderModel),
        # since forward() reads `.last_hidden_state` without decoder inputs.
        self.t5 = t5
        self.tokenizer = tokenizer
        self.device = device
        # Statement- and function-level classification heads.
        self.classifier = ClassificationHead(hidden_dim=hidden_dim)
    def forward(self, input_ids, statement_mask, labels=None, func_labels=None):
        # input_ids: (batch, num_statements, num_tokens); statement_mask: (batch, num_statements)
        statement_mask = statement_mask[:, :self.max_num_statement]
        # Pool each statement's token embeddings into a single vector via the GRU.
        embed = self.word_embedding(input_ids)  # (batch, num_statements, num_tokens, hidden_dim)
        # Buffer is fully overwritten in the loop below, so empty() replaces the original randn().
        inputs_embeds = torch.empty(embed.shape[0], embed.shape[1], embed.shape[3], device=self.device)
        for i in range(len(embed)):
            # Treat one sample's statements as a batch of token sequences.
            statement_of_tokens = embed[i]  # (num_statements, num_tokens, hidden_dim)
            _, statement_embed = self.rnn_statement_embedding(statement_of_tokens)
            inputs_embeds[i, :, :] = statement_embed  # final GRU hidden state per statement
        inputs_embeds = inputs_embeds[:, :self.max_num_statement, :]
        # Contextualize the statement embeddings with the T5 encoder.
        rep = self.t5(inputs_embeds=inputs_embeds, attention_mask=statement_mask).last_hidden_state
        logits, func_logits = self.classifier(rep)
        if self.training:
            # One logit per statement, so binary cross-entropy (assuming
            # per-statement binary labels) matches the sigmoid used at
            # inference. The original called nn.CrossEntropyLoss on
            # (batch, num_statements) logits, which treats statements as
            # competing classes and is inconsistent with that sigmoid.
            statement_loss = nn.BCEWithLogitsLoss()(logits, labels.float())
            func_loss = nn.CrossEntropyLoss()(func_logits, func_labels)
            return statement_loss, func_loss
        # Inference: per-statement probabilities and a 2-class function-level distribution.
        probs = torch.sigmoid(logits)
        func_probs = torch.softmax(func_logits, dim=-1)
        return probs, func_probs
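

# --- Usage sketch (illustrative, not from the original file) -----------------
# A minimal smoke test, assuming a Hugging Face encoder-only checkpoint such as
# "Salesforce/codet5-base" (hidden size 768). The checkpoint name, batch size,
# and tensor shapes below are assumptions for demonstration only.
if __name__ == "__main__":
    from transformers import AutoTokenizer, T5EncoderModel

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    t5 = T5EncoderModel.from_pretrained("Salesforce/codet5-base")
    model = StatementT5(t5, tokenizer, device).to(device).eval()

    batch, num_statements, num_tokens = 2, 155, 20
    input_ids = torch.randint(0, tokenizer.vocab_size,
                              (batch, num_statements, num_tokens), device=device)
    statement_mask = torch.ones(batch, num_statements, dtype=torch.long, device=device)

    with torch.no_grad():
        probs, func_probs = model(input_ids, statement_mask)
    print(probs.shape, func_probs.shape)  # torch.Size([2, 155]) torch.Size([2, 2])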