# leaf-large / modeling_leaf.py
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, PreTrainedModel

from .configuration_leaf import LeafConfig
from .mappings import idx_to_ef, idx_to_classname


class LeafModel(PreTrainedModel):
    """
    LEAF model for text classification.

    Encodes the input with a pretrained transformer, mean-pools the token
    embeddings over non-padding positions, and classifies the pooled vector.
    """

    config_class = LeafConfig

    def __init__(self, config: LeafConfig):
        super().__init__(config)
        self._base_model = AutoModel.from_pretrained(config.model_name)
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        hidden_dim = self._base_model.config.hidden_size
        self.head = ClassificationHead(hidden_dim=hidden_dim, num_classes=2097,
                                       idx_to_ef=idx_to_ef, idx_to_classname=idx_to_classname,
                                       device=self._device)

    def forward(self, input_ids, attention_mask, **kwargs) -> dict:
        kwargs.setdefault("classes", None)
        outputs = self._base_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Masked mean pooling: zero out padding positions, then divide the
        # per-sequence sum by the number of real tokens. For a mask [[1, 1, 0]]
        # the pooled vector is (h_0 + h_1) / 2; padding contributes nothing.
        attention_mask = attention_mask.unsqueeze(-1)
        masked_outputs = outputs * attention_mask.type_as(outputs)
        nom = masked_outputs.sum(dim=1)
        denom = attention_mask.sum(dim=1)
        # Avoid division by zero for sequences that are entirely padding.
        denom = denom.masked_fill(denom == 0, 1)
        return self.head(nom / denom, **kwargs)


class ClassificationHead(nn.Module):
    """
    Model head to predict a categorical target variable.
    """

    def __init__(self, hidden_dim: int, num_classes: int, idx_to_ef: dict,
                 idx_to_classname: Optional[dict], device: str):
        super().__init__()
        self.linear = nn.Linear(in_features=hidden_dim, out_features=num_classes)
        self.loss = nn.CrossEntropyLoss()
        # Turn the dict into a tensor lookup table ordered by class index.
        self.idx_to_ef = torch.tensor([idx_to_ef[k] for k in sorted(idx_to_ef.keys())],
                                      dtype=torch.float32, device=device)
        self.idx_to_ef.requires_grad = False
        self.idx_to_classname = idx_to_classname

    def forward(self, activations: torch.Tensor, classes: Optional[torch.Tensor] = None, **kwargs) -> dict:
        return_dict = {}
        logits = self.linear(activations)
        return_dict["logits"] = logits
        # Targets are optional: the loss is only computed when `classes` is given.
        # A plain `if classes:` would raise on multi-element tensors, so compare to None.
        if classes is not None:
            return_dict["loss"] = self.loss(logits, classes)
        # Softmax is monotonic, so the argmax over probabilities equals the argmax over logits.
        predicted_classes = torch.argmax(F.softmax(logits, dim=1), dim=1)
        return_dict["class_idx"] = predicted_classes
        return_dict["ef_score"] = self.idx_to_ef[predicted_classes]
        if self.idx_to_classname:
            return_dict["class"] = [self.idx_to_classname[str(c)] for c in
                                    predicted_classes.cpu().numpy()]
        return return_dict
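

# Minimal usage sketch, not part of the original module. It assumes the model
# is hosted under the repo id "baskra/leaf-large" (inferred from the file
# header; adjust if it differs) and that `config.model_name` names a tokenizer
# checkpoint compatible with the base encoder. Because of the relative imports
# above, run it as a module (python -m <package>.modeling_leaf), not as a script.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "baskra/leaf-large"  # assumed repo id
    config = LeafConfig.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = LeafModel.from_pretrained(repo_id, config=config).eval()

    batch = tokenizer(["example input text"], padding=True, truncation=True,
                      return_tensors="pt")
    with torch.no_grad():
        out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    print(out["class_idx"], out["ef_score"], out.get("class"))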