import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, device='cuda'):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x, notes):
        # notes is accepted for interface parity with MMRNN but is unused here.
        # Initial hidden and cell states, Xavier-initialized rather than zeros.
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        c = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        h = h.to(self.device)
        c = c.to(self.device)
        x = x.to(self.device)
        output, _ = self.lstm(x, (h, c))
        # Classify from the last time step's hidden state.
        out = self.fc2(self.relu(self.fc1(output[:, -1, :])))
        return out
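
# Usage sketch for RNN (illustrative, not from the original source; the batch
# size of 8, sequence length of 100, and CPU device below are assumptions):
#
#     >>> model = RNN(device='cpu')
#     >>> x = torch.randn(8, 100, 12)   # (batch, seq_len, input_dim)
#     >>> model(x, None).shape          # the notes argument is ignored by RNN
#     torch.Size([8, 5])
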
class MMRNN(nn.Module):
    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, embed_size=768, device='cuda'):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, embed_size)
        self.fc2 = nn.Linear(embed_size, num_classes)
        self.lnorm_out = nn.LayerNorm(embed_size)
        self.lnorm_embed = nn.LayerNorm(embed_size)

    def forward(self, x, note):
        # Initial hidden and cell states, Xavier-initialized rather than zeros.
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        c = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        h = h.to(self.device)
        c = c.to(self.device)
        x = x.to(self.device)
        note = note.to(self.device)
        output, _ = self.lstm(x, (h, c))
        # Take the last time step's hidden state and project it into the
        # note-embedding space.
        out = self.fc1(output[:, -1, :])
        # Normalize both modalities, then fuse by elementwise addition.
        note = self.lnorm_embed(note)
        out = self.lnorm_out(out)
        out = note + out
        out = self.fc2(out)
        # squeeze(1) only collapses the class dimension when num_classes == 1.
        return out.squeeze(1)
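
# A minimal smoke test for both models (an illustrative sketch, not part of
# the original file; the batch size of 8 and sequence length of 100 are
# assumptions). Runs on CPU so no GPU is required.
if __name__ == '__main__':
    x = torch.randn(8, 100, 12)    # (batch, seq_len, input_dim)
    note = torch.randn(8, 768)     # (batch, embed_size), e.g. a text-encoder output
    rnn = RNN(device='cpu')
    mmrnn = MMRNN(device='cpu')
    print(rnn(x, None).shape)      # torch.Size([8, 5])
    print(mmrnn(x, note).shape)    # torch.Size([8, 5])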