Spaces:
Runtime error
Runtime error
init
Browse files- app.py +1 -1
- midi_model.py +1 -2
- requirements.txt +0 -1
app.py
CHANGED
@@ -179,7 +179,7 @@ if __name__ == "__main__":
|
|
179 |
parser = argparse.ArgumentParser()
|
180 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
181 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
182 |
-
parser.add_argument("--device", type=str, default="
|
183 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
184 |
model_path = hf_hub_download(repo_id="skytnt/midi-model", filename="model.ckpt")
|
185 |
opt = parser.parse_args()
|
|
|
179 |
parser = argparse.ArgumentParser()
|
180 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
181 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
182 |
+
parser.add_argument("--device", type=str, default="cpu", help="device to run model")
|
183 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
184 |
model_path = hf_hub_download(repo_id="skytnt/midi-model", filename="model.ckpt")
|
185 |
opt = parser.parse_args()
|
midi_model.py
CHANGED
@@ -3,13 +3,12 @@ import torch
|
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
5 |
import tqdm
|
6 |
-
import pytorch_lightning as pl
|
7 |
from transformers import LlamaModel, LlamaConfig
|
8 |
|
9 |
from midi_tokenizer import MIDITokenizer
|
10 |
|
11 |
|
12 |
-
class MIDIModel(
|
13 |
def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
|
14 |
*args, **kwargs):
|
15 |
super(MIDIModel, self).__init__()
|
|
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
5 |
import tqdm
|
|
|
6 |
from transformers import LlamaModel, LlamaConfig
|
7 |
|
8 |
from midi_tokenizer import MIDITokenizer
|
9 |
|
10 |
|
11 |
+
class MIDIModel(nn.Module):
|
12 |
def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
|
13 |
*args, **kwargs):
|
14 |
super(MIDIModel, self).__init__()
|
requirements.txt
CHANGED
@@ -2,6 +2,5 @@ Pillow
|
|
2 |
numpy
|
3 |
torch
|
4 |
transformers
|
5 |
-
pytorch_lightning
|
6 |
gradio
|
7 |
pyfluidsynth
|
|
|
2 |
numpy
|
3 |
torch
|
4 |
transformers
|
|
|
5 |
gradio
|
6 |
pyfluidsynth
|