# NOTE(review): the original paste began with stray judge/log output
# ("Spaces:", "Runtime error" x2) — preserved here as a comment so the
# file parses; the runtime error itself is fixed in SeqClassifier.
from typing import Dict

import torch
import torch.nn as nn

# Target device for training/inference; CPU by default.
device = "cpu"
class SeqClassifier(nn.Module):
    """Sequence classifier: pretrained embeddings -> (bi)LSTM -> linear head."""

    def __init__(
        self,
        embeddings: torch.Tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        """
        Args:
            embeddings: pretrained embedding matrix, shape (vocab_size, embed_dim).
            hidden_size: LSTM hidden state size (per direction).
            num_layers: number of stacked LSTM layers.
            dropout: dropout probability (between LSTM layers and before the head).
            bidirectional: whether the LSTM runs in both directions.
            num_class: number of output classes.
        """
        super().__init__()
        # freeze=False: the embedding table is fine-tuned during training.
        self.embed = nn.Embedding.from_pretrained(embeddings, freeze=False)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.num_class = num_class
        # model architecture
        self.rnn = nn.LSTM(
            input_size=embeddings.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.dropout_layer = nn.Dropout(p=self.dropout)
        # BUG FIX: encoder_output_size is now a @property. The original passed
        # the bound method object itself to nn.Linear (missing call/property),
        # which raised a runtime error on construction.
        self.fc = nn.Linear(self.encoder_output_size, num_class)

    @property
    def encoder_output_size(self) -> int:
        """Feature size of the RNN output (doubled when bidirectional)."""
        return self.hidden_size * 2 if self.bidirectional else self.hidden_size

    def forward(self, batch) -> torch.Tensor:
        """Classify sequences of token indices.

        Args:
            batch: LongTensor of token indices, shape (batch, seq) or (seq,).

        Returns:
            Logits, shape (batch, num_class) — or (num_class,) for unbatched input.
        """
        # Map token indices into the embedding space.
        embedded = self.embed(batch)
        # With batch_first=True the output is (..., seq, num_directions * hidden).
        rnn_output, _ = self.rnn(embedded)
        rnn_output = self.dropout_layer(rnn_output)
        # BUG FIX: the original eval-mode branch indexed rnn_output[-1, ...],
        # silently treating the batch dimension as time and producing wrong
        # results for batched inputs. Training and eval now share one correct
        # path; the ellipsis index selects the time step for both batched
        # (batch, seq, feat) and unbatched (seq, feat) outputs.
        # Last time step of the forward direction.
        forward_last = rnn_output[..., -1, :self.hidden_size]
        if self.bidirectional:
            # First time step of the backward direction.
            backward_last = rnn_output[..., 0, self.hidden_size:]
            combined = torch.cat((forward_last, backward_last), dim=-1)
        else:
            combined = forward_last
        # Fully connected head -> class logits.
        return self.fc(combined)