import copy
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import LiltModel, LiltPreTrainedModel
from transformers.utils import ModelOutput


class BiaffineAttention(torch.nn.Module):
    """Implements a biaffine attention operator for binary relation classification.

    PyTorch implementation of the biaffine attention operator from "End-to-end neural
    relation extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275),
    which can be used as a classifier for binary relation classification.

    The score is ``bilinear(x_1, x_2) + linear(concat(x_1, x_2))``.

    Args:
        in_features (int): The size of the feature dimension of the inputs.
        out_features (int): The size of the feature dimension of the output.

    Shape:
        - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any
          number of additional dimensions.
        - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any
          number of additional dimensions.
        - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means
          any number of additional dimensions.

    Examples:
        >>> batch_size, in_features, out_features = 32, 100, 4
        >>> biaffine_attention = BiaffineAttention(in_features, out_features)
        >>> x_1 = torch.randn(batch_size, in_features)
        >>> x_2 = torch.randn(batch_size, in_features)
        >>> output = biaffine_attention(x_1, x_2)
        >>> print(output.size())
        torch.Size([32, 4])
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Bilinear interaction term x_1^T W x_2; no bias here because `linear` carries it.
        self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
        # Linear term over the concatenated pair [x_1; x_2], with bias.
        self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)
        self.reset_parameters()

    def forward(self, x_1, x_2):
        """Return the biaffine score of each (x_1, x_2) pair."""
        return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))

    def reset_parameters(self):
        """Re-initialize the bilinear and linear sub-layers with their PyTorch defaults."""
        self.bilinear.reset_parameters()
        self.linear.reset_parameters()


class REDecoder(nn.Module):
    """Classify candidate (head, tail) entity pairs as related (1) or unrelated (0).

    Each entity is represented by the encoder hidden state at its ``start`` token,
    concatenated with a learned embedding of its label. Head and tail representations go
    through separate feed-forward projections and are scored by a 2-way
    ``BiaffineAttention``.

    Args:
        config: Model config; ``hidden_size`` and ``hidden_dropout_prob`` are read.
        input_size (int): Feature size of the incoming hidden states. It is also used as
            the label-embedding size, so the projection input is ``2 * input_size``.
    """

    def __init__(self, config, input_size):
        super().__init__()
        # Only 3 embedding rows, so entity labels must be in {0, 1, 2}.
        self.entity_emb = nn.Embedding(3, input_size, scale_grad_by_freq=True)
        projection = nn.Sequential(
            nn.Linear(input_size * 2, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
        )
        # Head and tail each get an independent copy (separate parameters) of the projection.
        self.ffnn_head = copy.deepcopy(projection)
        self.ffnn_tail = copy.deepcopy(projection)
        self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
        self.loss_fct = CrossEntropyLoss()

    def build_relation(self, relations, entities):
        """Build the labelled candidate-pair set for every document in the batch.

        Candidates are all ordered pairs ``(i, j)`` with ``label[i] == 1`` and
        ``label[j] == 2``. Gold pairs that are also candidates are labelled 1 and placed
        first. The remaining candidates follow, labelled 0.

        Args:
            relations: Per-document dicts with ``"head"`` and ``"tail"`` entity-index lists.
            entities: Per-document dicts with ``"start"``, ``"end"`` and ``"label"`` lists.
                Modified in place: a document with at most 2 entities is replaced by a
                2-entity dummy placeholder.

        Returns:
            ``(new_relations, entities)``. Each ``new_relations[b]`` has ``"head"``,
            ``"tail"`` and ``"label"`` lists of equal, non-zero length.
        """
        new_relations = []
        for b in range(len(relations)):
            if len(entities[b]["start"]) <= 2:
                # Too few entities: substitute a dummy pair so the document still yields
                # a candidate.
                # NOTE(review): relations[b] is not reset here, so a gold pair (0, 1) would
                # be labelled positive against the dummy entities — confirm this is intended.
                entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
            labels = entities[b]["label"]
            n_entities = len(labels)
            all_possible_relations = {
                (i, j)
                for i in range(n_entities)
                for j in range(n_entities)
                if labels[i] == 1 and labels[j] == 2
            }
            if not all_possible_relations:
                # Guarantee at least one (negative) candidate so per-document tensors are
                # never empty.
                all_possible_relations = {(0, 1)}
            gold_relations = set(zip(relations[b]["head"], relations[b]["tail"]))
            negative_relations = all_possible_relations - gold_relations
            # Gold pairs that are not valid candidates are dropped.
            positive_relations = {r for r in gold_relations if r in all_possible_relations}
            reordered_relations = list(positive_relations) + list(negative_relations)
            n_positive = len(positive_relations)
            relation_per_doc = {
                "head": [head for head, _ in reordered_relations],
                "tail": [tail for _, tail in reordered_relations],
                "label": [1] * n_positive + [0] * (len(reordered_relations) - n_positive),
            }
            assert len(relation_per_doc["head"]) != 0
            new_relations.append(relation_per_doc)
        return new_relations, entities

    def get_predicted_relations(self, logits, relations, entities):
        """Turn one document's pair logits into a list of predicted-relation dicts.

        Args:
            logits: ``(num_pairs, 2)`` scores. Row ``i`` corresponds to the pair
                ``(relations["head"][i], relations["tail"][i])``.
            relations: One document's candidate dict from ``build_relation``.
            entities: One document's entity dict (``"start"``, ``"end"``, ``"label"``).

        Returns:
            One dict per pair whose argmax class is 1, with the head/tail entity ids,
            ``(start, end)`` spans and labels, and ``"type": 1``.
        """
        pred_relations = []
        # A single device->host transfer instead of one per candidate pair.
        for i, pred_label in enumerate(logits.argmax(-1).tolist()):
            if pred_label != 1:
                continue
            head_id = relations["head"][i]
            tail_id = relations["tail"][i]
            pred_relations.append(
                {
                    "head_id": head_id,
                    "head": (entities["start"][head_id], entities["end"][head_id]),
                    "head_type": entities["label"][head_id],
                    "tail_id": tail_id,
                    "tail": (entities["start"][tail_id], entities["end"][tail_id]),
                    "tail_type": entities["label"][tail_id],
                    "type": 1,
                }
            )
        return pred_relations

    def forward(self, hidden_states, entities, relations):
        """Score all candidate pairs in the batch and compute the relation loss.

        Args:
            hidden_states: 3-D encoder output; dim 0 is the batch and dim 1 is indexed by
                entity start positions.
            entities: Per-document entity dicts (modified in place by ``build_relation``).
            relations: Per-document gold relation dicts.

        Returns:
            ``(loss, all_pred_relations)``. ``loss`` is the cross-entropy over every
            candidate pair in the batch. ``all_pred_relations[b]`` is the output of
            ``get_predicted_relations`` for document ``b``.
        """
        batch_size, _, _ = hidden_states.size()
        device = hidden_states.device
        relations, entities = self.build_relation(relations, entities)

        all_pred_relations = []
        all_logits = []
        all_labels = []
        for b in range(batch_size):
            head_entities = torch.tensor(relations[b]["head"], device=device)
            tail_entities = torch.tensor(relations[b]["tail"], device=device)
            relation_labels = torch.tensor(relations[b]["label"], device=device)
            entities_start_index = torch.tensor(entities[b]["start"], device=device)
            entities_labels = torch.tensor(entities[b]["label"], device=device)

            # Entity representation = hidden state at its start token + label embedding.
            head_index = entities_start_index[head_entities]
            head_label_repr = self.entity_emb(entities_labels[head_entities])
            tail_index = entities_start_index[tail_entities]
            tail_label_repr = self.entity_emb(entities_labels[tail_entities])

            head_repr = torch.cat((hidden_states[b][head_index], head_label_repr), dim=-1)
            tail_repr = torch.cat((hidden_states[b][tail_index], tail_label_repr), dim=-1)
            heads = self.ffnn_head(head_repr)
            tails = self.ffnn_tail(tail_repr)
            logits = self.rel_classifier(heads, tails)

            all_pred_relations.append(
                self.get_predicted_relations(logits, relations[b], entities[b])
            )
            all_logits.append(logits)
            all_labels.append(relation_labels)

        loss = self.loss_fct(torch.cat(all_logits, 0), torch.cat(all_labels, 0))
        return loss, all_pred_relations


@dataclass
class ReOutput(ModelOutput):
    """Output of ``REHead``.

    ``REHead`` fills ``loss``, ``entities``, ``relations`` and ``pred_relations``.
    ``logits``, ``hidden_states`` and ``attentions`` are declared but left as ``None``
    by the code in this module.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    entities: Optional[Dict] = None
    relations: Optional[Dict] = None
    pred_relations: Optional[Dict] = None


class REHead(nn.Module):
    """Dropout followed by ``REDecoder``, packaged into a ``ReOutput``."""

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.extractor = REDecoder(config, config.hidden_size)

    def forward(self, sequence_output, entities, relations):
        """Run relation extraction on encoder output.

        The returned ``entities`` is the same list object the caller passed in, after
        ``REDecoder.build_relation`` may have replaced some entries in place.
        """
        sequence_output = self.dropout(sequence_output)
        loss, pred_relations = self.extractor(sequence_output, entities, relations)
        return ReOutput(
            loss=loss,
            entities=entities,
            relations=relations,
            pred_relations=pred_relations,
        )


class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
    """LiLT encoder (without pooler) topped with a biaffine relation-extraction head."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.lilt = LiltModel(config, add_pooling_layer=False)
        self.rehead = REHead(config)
        # Standard Transformers post-construction hook (supersedes calling init_weights()).
        self.post_init()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        entities=None,
        relations=None,
    ):
        """Encode the document with LiLT and run the relation-extraction head.

        ``entities`` and ``relations`` are per-document lists as consumed by
        ``REDecoder.build_relation``. ``labels`` is accepted for interface compatibility
        but not used. ``return_dict`` is only forwarded to the encoder: a ``ReOutput`` is
        always returned, and encoder hidden states/attentions are not copied into it.
        Either ``input_ids`` or ``inputs_embeds`` may be supplied.
        """
        outputs = self.lilt(
            input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the sequence output for both tuple and ModelOutput encoder returns.
        sequence_output = outputs[0]
        return self.rehead(sequence_output, entities, relations)