skytnt commited on
Commit
ff5739c
1 Parent(s): cd1d5fd
Files changed (3) hide show
  1. app.py +1 -1
  2. midi_model.py +1 -2
  3. requirements.txt +0 -1
app.py CHANGED
@@ -179,7 +179,7 @@ if __name__ == "__main__":
179
  parser = argparse.ArgumentParser()
180
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
181
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
182
- parser.add_argument("--device", type=str, default="cuda", help="device to run model")
183
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
184
  model_path = hf_hub_download(repo_id="skytnt/midi-model", filename="model.ckpt")
185
  opt = parser.parse_args()
 
179
  parser = argparse.ArgumentParser()
180
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
181
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
182
+ parser.add_argument("--device", type=str, default="cpu", help="device to run model")
183
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
184
  model_path = hf_hub_download(repo_id="skytnt/midi-model", filename="model.ckpt")
185
  opt = parser.parse_args()
midi_model.py CHANGED
@@ -3,13 +3,12 @@ import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import tqdm
6
- import pytorch_lightning as pl
7
  from transformers import LlamaModel, LlamaConfig
8
 
9
  from midi_tokenizer import MIDITokenizer
10
 
11
 
12
- class MIDIModel(pl.LightningModule):
13
  def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
14
  *args, **kwargs):
15
  super(MIDIModel, self).__init__()
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import tqdm
 
6
  from transformers import LlamaModel, LlamaConfig
7
 
8
  from midi_tokenizer import MIDITokenizer
9
 
10
 
11
+ class MIDIModel(nn.Module):
12
  def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
13
  *args, **kwargs):
14
  super(MIDIModel, self).__init__()
requirements.txt CHANGED
@@ -2,6 +2,5 @@ Pillow
2
  numpy
3
  torch
4
  transformers
5
- pytorch_lightning
6
  gradio
7
  pyfluidsynth
 
2
  numpy
3
  torch
4
  transformers
 
5
  gradio
6
  pyfluidsynth