skytnt commited on
Commit
5c5b5ca
1 Parent(s): ff5739c

Update midi_model.py

Browse files
Files changed (1) hide show
  1. midi_model.py +2 -1
midi_model.py CHANGED
@@ -4,11 +4,12 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  import tqdm
6
  from transformers import LlamaModel, LlamaConfig
 
7
 
8
  from midi_tokenizer import MIDITokenizer
9
 
10
 
11
- class MIDIModel(nn.Module):
12
  def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
13
  *args, **kwargs):
14
  super(MIDIModel, self).__init__()
 
4
  import torch.nn.functional as F
5
  import tqdm
6
  from transformers import LlamaModel, LlamaConfig
7
+ from transformers.modeling_utils import ModuleUtilsMixin
8
 
9
  from midi_tokenizer import MIDITokenizer
10
 
11
 
12
+ class MIDIModel(nn.Module, ModuleUtilsMixin):
13
  def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
14
  *args, **kwargs):
15
  super(MIDIModel, self).__init__()