Spaces:
Running
on
Zero
Running
on
Zero
add cache
Browse files- app.py +10 -14
- midi_model.py +26 -9
app.py
CHANGED
@@ -13,6 +13,7 @@ import torch
|
|
13 |
import torch.nn.functional as F
|
14 |
import tqdm
|
15 |
from huggingface_hub import hf_hub_download
|
|
|
16 |
|
17 |
import MIDI
|
18 |
from midi_model import MIDIModel, MIDIModelConfig
|
@@ -51,12 +52,14 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
|
|
51 |
input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
|
52 |
cur_len = input_tensor.shape[1]
|
53 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
|
|
54 |
with bar:
|
55 |
while cur_len < max_len:
|
56 |
end = [False] * batch_size
|
57 |
-
hidden = model.forward(input_tensor)[:, -1]
|
58 |
next_token_seq = None
|
59 |
event_names = [""] * batch_size
|
|
|
60 |
for i in range(max_token_seq):
|
61 |
mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
|
62 |
for b in range(batch_size):
|
@@ -81,7 +84,11 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
|
|
81 |
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
82 |
mask[b, mask_ids] = 1
|
83 |
mask = mask.unsqueeze(1)
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
scores = torch.softmax(logits / temp, dim=-1) * mask
|
86 |
samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
|
87 |
if i == 0:
|
@@ -118,21 +125,10 @@ def send_msgs(msgs):
|
|
118 |
return json.dumps(msgs)
|
119 |
|
120 |
|
121 |
-
def calc_time(x):
|
122 |
-
return 5.849e-5*x**2 + 0.04781*x + 0.1168
|
123 |
-
|
124 |
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
125 |
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
126 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
127 |
-
|
128 |
-
start_events = 1
|
129 |
-
elif tab == 1 and mid is not None:
|
130 |
-
start_events = midi_events
|
131 |
-
elif tab == 2 and mid_seq is not None:
|
132 |
-
start_events = len(mid_seq[0])
|
133 |
-
else:
|
134 |
-
start_events = 1
|
135 |
-
t = calc_time(start_events + gen_events) - calc_time(start_events) + 5
|
136 |
if "large" in model_name:
|
137 |
t *= 2
|
138 |
return t
|
|
|
13 |
import torch.nn.functional as F
|
14 |
import tqdm
|
15 |
from huggingface_hub import hf_hub_download
|
16 |
+
from transformers import DynamicCache
|
17 |
|
18 |
import MIDI
|
19 |
from midi_model import MIDIModel, MIDIModelConfig
|
|
|
52 |
input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
|
53 |
cur_len = input_tensor.shape[1]
|
54 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
55 |
+
cache1 = DynamicCache()
|
56 |
with bar:
|
57 |
while cur_len < max_len:
|
58 |
end = [False] * batch_size
|
59 |
+
hidden = model.forward(input_tensor[:, -1:], cache=cache1)[:, -1]
|
60 |
next_token_seq = None
|
61 |
event_names = [""] * batch_size
|
62 |
+
cache2 = DynamicCache()
|
63 |
for i in range(max_token_seq):
|
64 |
mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
|
65 |
for b in range(batch_size):
|
|
|
84 |
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
85 |
mask[b, mask_ids] = 1
|
86 |
mask = mask.unsqueeze(1)
|
87 |
+
x = next_token_seq
|
88 |
+
if i != 0:
|
89 |
+
hidden = None
|
90 |
+
x = x[:, -1:]
|
91 |
+
logits = model.forward_token(hidden, x, cache=cache2)[:, -1:]
|
92 |
scores = torch.softmax(logits / temp, dim=-1) * mask
|
93 |
samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
|
94 |
if i == 0:
|
|
|
125 |
return json.dumps(msgs)
|
126 |
|
127 |
|
|
|
|
|
|
|
128 |
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
129 |
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
130 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
131 |
+
t = gen_events // 20 + 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
if "large" in model_name:
|
133 |
t *= 2
|
134 |
return t
|
midi_model.py
CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
|
|
6 |
import torch.nn.functional as F
|
7 |
import tqdm
|
8 |
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
9 |
-
from transformers import LlamaModel, LlamaConfig
|
10 |
from transformers.integrations import PeftAdapterMixin
|
11 |
|
12 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
@@ -83,30 +83,40 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
83 |
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
84 |
return model.merge_and_unload()
|
85 |
|
86 |
-
def forward_token(self, hidden_state, x=None):
|
87 |
"""
|
88 |
|
89 |
:param hidden_state: (batch_size, n_embd)
|
90 |
:param x: (batch_size, token_sequence_length)
|
|
|
91 |
:return: (batch_size, 1 + token_sequence_length, vocab_size)
|
92 |
"""
|
93 |
-
|
|
|
|
|
94 |
if x is not None:
|
95 |
x = self.net_token.embed_tokens(x)
|
96 |
-
hidden_state
|
97 |
-
|
|
|
|
|
|
|
|
|
98 |
return self.lm_head(hidden_state)
|
99 |
|
100 |
-
def forward(self, x):
|
101 |
"""
|
102 |
:param x: (batch_size, midi_sequence_length, token_sequence_length)
|
|
|
103 |
:return: hidden (batch_size, midi_sequence_length, n_embd)
|
104 |
"""
|
105 |
|
106 |
# merge token sequence
|
107 |
x = self.net.embed_tokens(x)
|
108 |
x = x.sum(dim=-2)
|
109 |
-
x = self.net.forward(inputs_embeds=x
|
|
|
|
|
110 |
return x.last_hidden_state
|
111 |
|
112 |
def sample_top_p_k(self, probs, p, k, generator=None):
|
@@ -149,12 +159,14 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
149 |
|
150 |
cur_len = input_tensor.shape[1]
|
151 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
|
|
152 |
with bar:
|
153 |
while cur_len < max_len:
|
154 |
end = [False] * batch_size
|
155 |
-
hidden = self.forward(input_tensor)[:, -1]
|
156 |
next_token_seq = None
|
157 |
event_names = [""] * batch_size
|
|
|
158 |
for i in range(max_token_seq):
|
159 |
mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
|
160 |
for b in range(batch_size):
|
@@ -170,7 +182,12 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
170 |
continue
|
171 |
mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
|
172 |
mask = mask.unsqueeze(1)
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
174 |
scores = torch.softmax(logits / temp, dim=-1) * mask
|
175 |
samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
|
176 |
if i == 0:
|
|
|
6 |
import torch.nn.functional as F
|
7 |
import tqdm
|
8 |
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
9 |
+
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
10 |
from transformers.integrations import PeftAdapterMixin
|
11 |
|
12 |
from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
|
|
|
83 |
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
84 |
return model.merge_and_unload()
|
85 |
|
86 |
+
def forward_token(self, hidden_state=None, x=None, cache=None):
|
87 |
"""
|
88 |
|
89 |
:param hidden_state: (batch_size, n_embd)
|
90 |
:param x: (batch_size, token_sequence_length)
|
91 |
+
:param cache: Cache
|
92 |
:return: (batch_size, 1 + token_sequence_length, vocab_size)
|
93 |
"""
|
94 |
+
if hidden_state is not None:
|
95 |
+
#if you use cache, you don't need to pass in hidden_state
|
96 |
+
hidden_state = hidden_state.unsqueeze(1) # (batch_size, 1, n_embd)
|
97 |
if x is not None:
|
98 |
x = self.net_token.embed_tokens(x)
|
99 |
+
if hidden_state is not None:
|
100 |
+
x = torch.cat([hidden_state, x], dim=1)
|
101 |
+
hidden_state = x
|
102 |
+
hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
|
103 |
+
past_key_values=cache,
|
104 |
+
use_cache=cache is not None).last_hidden_state
|
105 |
return self.lm_head(hidden_state)
|
106 |
|
107 |
+
def forward(self, x, cache = None):
|
108 |
"""
|
109 |
:param x: (batch_size, midi_sequence_length, token_sequence_length)
|
110 |
+
:param cache: Cache
|
111 |
:return: hidden (batch_size, midi_sequence_length, n_embd)
|
112 |
"""
|
113 |
|
114 |
# merge token sequence
|
115 |
x = self.net.embed_tokens(x)
|
116 |
x = x.sum(dim=-2)
|
117 |
+
x = self.net.forward(inputs_embeds=x,
|
118 |
+
past_key_values=cache,
|
119 |
+
use_cache=cache is not None)
|
120 |
return x.last_hidden_state
|
121 |
|
122 |
def sample_top_p_k(self, probs, p, k, generator=None):
|
|
|
159 |
|
160 |
cur_len = input_tensor.shape[1]
|
161 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
162 |
+
cache1 = DynamicCache()
|
163 |
with bar:
|
164 |
while cur_len < max_len:
|
165 |
end = [False] * batch_size
|
166 |
+
hidden = self.forward(input_tensor[:,-1:], cache=cache1)[:, -1]
|
167 |
next_token_seq = None
|
168 |
event_names = [""] * batch_size
|
169 |
+
cache2 = DynamicCache()
|
170 |
for i in range(max_token_seq):
|
171 |
mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
|
172 |
for b in range(batch_size):
|
|
|
182 |
continue
|
183 |
mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
|
184 |
mask = mask.unsqueeze(1)
|
185 |
+
x = next_token_seq
|
186 |
+
if i != 0:
|
187 |
+
# cached
|
188 |
+
hidden = None
|
189 |
+
x = x[:, -1:]
|
190 |
+
logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
|
191 |
scores = torch.softmax(logits / temp, dim=-1) * mask
|
192 |
samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
|
193 |
if i == 0:
|