Spaces:
Runtime error
Runtime error
File size: 2,361 Bytes
71bd54f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import torch
import torch.nn as nn
class RNN(nn.Module):
    """LSTM sequence classifier: last-timestep LSTM output -> Linear/ReLU/Linear head.

    Args:
        input_dim: Features per timestep (LSTM ``input_size``).
        hidden_dim: LSTM hidden size; also the width of the intermediate MLP layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        cuda: Unused; accepted only so existing call sites keep working.
        device: Stored as ``self.device`` for backward compatibility. ``forward``
            places its tensors on the device the LSTM weights live on, so the
            model also runs on CPU-only hosts and after ``model.to(...)``.
    """

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5,
                 cuda=True, device='cuda'):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def _initial_state(self, batch_size, device):
        """Return a randomly initialised ``(h0, c0)`` pair moved to *device*.

        Both tensors have shape ``(num_layers, batch_size, hidden_dim)`` and are
        filled with ``nn.init.xavier_normal_``, ``h0`` first and then ``c0``.
        They are allocated on torch's default device and then moved.
        """
        # NOTE(review): xavier_normal_ on a 3-D tensor derives its std from
        # batch_size * hidden_dim, so the initial-state scale depends on the
        # batch size, and outputs are random even in eval mode. Kept as-is to
        # preserve the trained model's behaviour; confirm this is intended.
        shape = (self.num_layers, batch_size, self.hidden_dim)
        h = torch.zeros(shape)
        nn.init.xavier_normal_(h)
        c = torch.zeros(shape)
        nn.init.xavier_normal_(c)
        return h.to(device), c.to(device)

    def forward(self, x, notes):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Batch-first sequence tensor ``(batch, seq_len, input_dim)``.
            notes: Ignored. Presumably accepted so this model shares the
                ``(x, note)`` call signature of ``MMRNN``; confirm with callers.
        """
        # Follow the weights' actual placement instead of a hard-coded device,
        # which would otherwise mismatch after .to() or on CPU-only machines.
        device = self.lstm.weight_ih_l0.device
        h, c = self._initial_state(x.size(0), device)
        x = x.to(device)
        output, _ = self.lstm(x, (h, c))
        # Classify from the top LSTM layer's output at the final timestep.
        last_step = output[:, -1, :]
        return self.fc2(self.relu(self.fc1(last_step)))
class MMRNN(nn.Module):
    """Fuse an LSTM sequence encoding with an external embedding, then classify.

    The top-layer LSTM output at the final timestep is projected to
    ``embed_size`` and layer-normed. It is added to the layer-normed ``note``
    embedding, and the sum is mapped to ``num_classes`` logits.

    Args:
        input_dim: Features per timestep (LSTM ``input_size``).
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        embed_size: Width of the shared embedding space that ``note`` must match.
        device: Stored as ``self.device`` for backward compatibility. ``forward``
            places its tensors on the device the LSTM weights live on, so the
            model also runs on CPU-only hosts and after ``model.to(...)``.
    """

    # A plain nn.Module: this is a model, not a container of modules, so the
    # nn.ModuleList API (len/iter/index/append) must not leak onto it.
    # Submodule names, and therefore state_dict keys, are unchanged.

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5,
                 embed_size=768, device="cuda"):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, embed_size)
        self.fc2 = nn.Linear(embed_size, num_classes)
        self.lnorm_out = nn.LayerNorm(embed_size)
        self.lnorm_embed = nn.LayerNorm(embed_size)

    def _initial_state(self, batch_size, device):
        """Return a randomly initialised ``(h0, c0)`` pair moved to *device*.

        Both tensors have shape ``(num_layers, batch_size, hidden_dim)`` and are
        filled with ``nn.init.xavier_normal_``, ``h0`` first and then ``c0``.
        They are allocated on torch's default device and then moved.
        """
        # NOTE(review): xavier_normal_ on a 3-D tensor derives its std from
        # batch_size * hidden_dim, so the initial-state scale depends on the
        # batch size, and outputs are random even in eval mode. Kept as-is to
        # preserve the trained model's behaviour; confirm this is intended.
        shape = (self.num_layers, batch_size, self.hidden_dim)
        h = torch.zeros(shape)
        nn.init.xavier_normal_(h)
        c = torch.zeros(shape)
        nn.init.xavier_normal_(c)
        return h.to(device), c.to(device)

    def forward(self, x, note):
        """Return class logits for a batch.

        Args:
            x: Batch-first sequence tensor ``(batch, seq_len, input_dim)``.
            note: Tensor whose last dimension is ``embed_size`` (required by
                ``lnorm_embed``) and which broadcasts against
                ``(batch, embed_size)``. Presumably a precomputed text
                embedding; confirm with callers.

        Returns:
            Logits from ``fc2``. Dimension 1 is squeezed away when it has
            size 1 (e.g. ``num_classes == 1``), giving ``(batch,)``.
        """
        # Follow the weights' actual placement instead of a hard-coded device.
        device = self.lstm.weight_ih_l0.device
        h, c = self._initial_state(x.size(0), device)
        x = x.to(device)
        note = note.to(device)
        output, _ = self.lstm(x, (h, c))
        # Project the last timestep into the embedding space so it can be
        # summed with the (normalised) note embedding.
        seq_embed = self.lnorm_out(self.fc1(output[:, -1, :]))
        fused = self.lnorm_embed(note) + seq_embed
        return self.fc2(fused).squeeze(1)
|