skytnt committed
Commit 0db6e6a
1 Parent(s): 0d70f90
Files changed (2)
  1. app.py +10 -14
  2. midi_model.py +26 -9
app.py CHANGED
@@ -13,6 +13,7 @@ import torch
 import torch.nn.functional as F
 import tqdm
 from huggingface_hub import hf_hub_download
+from transformers import DynamicCache
 
 import MIDI
 from midi_model import MIDIModel, MIDIModelConfig
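Note: DynamicCache is the transformers KV-cache object this commit threads through both models. A minimal sketch of the prefill-then-step pattern it enables, using a deliberately tiny LlamaModel (all configuration numbers here are illustrative, not from this repo):

import torch
from transformers import LlamaConfig, LlamaModel, DynamicCache

# Tiny model purely for illustration; the real checkpoints are far larger.
config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, vocab_size=128)
model = LlamaModel(config).eval()

cache = DynamicCache()
prompt = torch.randint(0, 128, (1, 16))  # (batch, seq)
with torch.no_grad():
    # Prefill: the whole prompt runs once and populates the cache.
    model(input_ids=prompt, past_key_values=cache, use_cache=True)
    # Step: from now on each call needs only the newest token; the cache
    # supplies the keys/values of everything that came before.
    next_tok = torch.randint(0, 128, (1, 1))
    out = model(input_ids=next_tok, past_key_values=cache, use_cache=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 1, 64])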
@@ -51,12 +52,14 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
     input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
     cur_len = input_tensor.shape[1]
     bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
+    cache1 = DynamicCache()
     with bar:
         while cur_len < max_len:
             end = [False] * batch_size
-            hidden = model.forward(input_tensor)[:, -1]
+            hidden = model.forward(input_tensor[:, -1:], cache=cache1)[:, -1]
             next_token_seq = None
             event_names = [""] * batch_size
+            cache2 = DynamicCache()
             for i in range(max_token_seq):
                 mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
                 for b in range(batch_size):
@@ -81,7 +84,11 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
                     mask_ids = [i for i in mask_ids if i not in disable_channels]
                     mask[b, mask_ids] = 1
                 mask = mask.unsqueeze(1)
-                logits = model.forward_token(hidden, next_token_seq)[:, -1:]
+                x = next_token_seq
+                if i != 0:
+                    hidden = None
+                    x = x[:, -1:]
+                logits = model.forward_token(hidden, x, cache=cache2)[:, -1:]
                 scores = torch.softmax(logits / temp, dim=-1) * mask
                 samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
                 if i == 0:
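The rewritten inner loop is the same idea one level down: step 0 prefills cache2 with the event-level hidden state, and every later step feeds only the token just sampled, with hidden set to None. A hypothetical miniature of that pattern (toy sizes, greedy argmax instead of sample_top_p_k; net_token, embed and lm_head are stand-ins, not the repo's modules):

import torch
from transformers import LlamaConfig, LlamaModel, DynamicCache

cfg = LlamaConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=1,
                  num_attention_heads=4, vocab_size=64)
net_token = LlamaModel(cfg).eval()   # stand-in for model.net_token
embed = torch.nn.Embedding(64, 32)   # stand-in for net_token.embed_tokens
lm_head = torch.nn.Linear(32, 64)    # stand-in for model.lm_head

hidden = torch.randn(1, 1, 32)       # stand-in for model.forward(...)[:, -1]
cache2 = DynamicCache()
x, tokens = hidden, []               # i == 0: prefill with the hidden state
with torch.no_grad():
    for i in range(4):
        h = net_token(inputs_embeds=x, past_key_values=cache2,
                      use_cache=True).last_hidden_state[:, -1:]
        tok = lm_head(h).argmax(-1)  # greedy stand-in for sample_top_p_k
        tokens.append(tok)
        x = embed(tok)               # i > 0: only the newest token goes in
print(torch.cat(tokens, dim=1).shape)  # torch.Size([1, 4])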
@@ -118,21 +125,10 @@ def send_msgs(msgs):
     return json.dumps(msgs)
 
 
-def calc_time(x):
-    return 5.849e-5*x**2 + 0.04781*x + 0.1168
-
 def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
                  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
                  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
-    if tab == 0:
-        start_events = 1
-    elif tab == 1 and mid is not None:
-        start_events = midi_events
-    elif tab == 2 and mid_seq is not None:
-        start_events = len(mid_seq[0])
-    else:
-        start_events = 1
-    t = calc_time(start_events + gen_events) - calc_time(start_events) + 5
+    t = gen_events // 20 + 5
     if "large" in model_name:
         t *= 2
     return t
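The quadratic runtime fit goes away with the cache, presumably because per-event cost is now roughly constant, so a linear rule of thumb suffices. A quick comparison of the two estimates (illustrative only; start_events = 1 as in the removed tab == 0 branch):

def calc_time(x):  # the removed quadratic fit
    return 5.849e-5 * x ** 2 + 0.04781 * x + 0.1168

start_events, gen_events = 1, 512
old_t = calc_time(start_events + gen_events) - calc_time(start_events) + 5
new_t = gen_events // 20 + 5  # the new linear estimate
print(f"old: {old_t:.1f}s  new: {new_t}s")  # old: 44.9s  new: 30s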
 
midi_model.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import tqdm
 from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
-from transformers import LlamaModel, LlamaConfig
+from transformers import LlamaModel, LlamaConfig, DynamicCache
 from transformers.integrations import PeftAdapterMixin
 
 from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
@@ -83,30 +83,40 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
         set_peft_model_state_dict(self, adapter_state_dict, "default")
         return model.merge_and_unload()
 
-    def forward_token(self, hidden_state, x=None):
+    def forward_token(self, hidden_state=None, x=None, cache=None):
         """
 
         :param hidden_state: (batch_size, n_embd)
         :param x: (batch_size, token_sequence_length)
+        :param cache: Cache
         :return: (batch_size, 1 + token_sequence_length, vocab_size)
         """
-        hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
+        if hidden_state is not None:
+            # if you use cache, you don't need to pass in hidden_state
+            hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
         if x is not None:
             x = self.net_token.embed_tokens(x)
-            hidden_state = torch.cat([hidden_state, x], dim=1)
-        hidden_state = self.net_token.forward(inputs_embeds=hidden_state).last_hidden_state
+            if hidden_state is not None:
+                x = torch.cat([hidden_state, x], dim=1)
+            hidden_state = x
+        hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
+                                              past_key_values=cache,
+                                              use_cache=cache is not None).last_hidden_state
         return self.lm_head(hidden_state)
 
-    def forward(self, x):
+    def forward(self, x, cache=None):
         """
         :param x: (batch_size, midi_sequence_length, token_sequence_length)
+        :param cache: Cache
         :return: hidden (batch_size, midi_sequence_length, n_embd)
         """
 
         # merge token sequence
         x = self.net.embed_tokens(x)
         x = x.sum(dim=-2)
-        x = self.net.forward(inputs_embeds=x)
+        x = self.net.forward(inputs_embeds=x,
+                             past_key_values=cache,
+                             use_cache=cache is not None)
         return x.last_hidden_state
 
     def sample_top_p_k(self, probs, p, k, generator=None):
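Both methods now enable caching only when a cache object is passed (use_cache=cache is not None), so full-sequence training calls are untouched. A sanity sketch, not part of the commit, of the property the refactor relies on: feeding a sequence token by token through a DynamicCache reproduces the final hidden state of one full forward pass (tiny illustrative config):

import torch
from transformers import LlamaConfig, LlamaModel, DynamicCache

cfg = LlamaConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2,
                  num_attention_heads=4, vocab_size=100)
net = LlamaModel(cfg).eval()
ids = torch.randint(0, 100, (1, 8))

with torch.no_grad():
    full = net(input_ids=ids).last_hidden_state        # one uncached pass
    cache, step = DynamicCache(), None
    for t in range(ids.shape[1]):                      # one token at a time
        step = net(input_ids=ids[:, t:t + 1],
                   past_key_values=cache, use_cache=True).last_hidden_state

print(torch.allclose(full[:, -1:], step, atol=1e-5))   # True, up to float noise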
@@ -149,12 +159,14 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
 
         cur_len = input_tensor.shape[1]
         bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
+        cache1 = DynamicCache()
         with bar:
             while cur_len < max_len:
                 end = [False] * batch_size
-                hidden = self.forward(input_tensor)[:, -1]
+                hidden = self.forward(input_tensor[:, -1:], cache=cache1)[:, -1]
                 next_token_seq = None
                 event_names = [""] * batch_size
+                cache2 = DynamicCache()
                 for i in range(max_token_seq):
                     mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
                     for b in range(batch_size):
@@ -170,7 +182,12 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
                         continue
                     mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
                 mask = mask.unsqueeze(1)
-                logits = self.forward_token(hidden, next_token_seq)[:, -1:]
+                x = next_token_seq
+                if i != 0:
+                    # cached
+                    hidden = None
+                    x = x[:, -1:]
+                logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
                 scores = torch.softmax(logits / temp, dim=-1) * mask
                 samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
                 if i == 0:
 
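Taken together, generate now threads two caches: cache1 lives for the whole piece at the event level, while cache2 is recreated for each event's token sequence. A control-flow skeleton under stated assumptions (forward and forward_token are hypothetical stubs returning random tensors; masking, top-p/top-k sampling, padding and end handling are omitted):

import torch
from transformers import DynamicCache

max_len, max_token_seq, n_embd, vocab = 4, 3, 8, 16

def forward(x, cache):                    # stub for MIDIModel.forward
    return torch.randn(x.shape[0], x.shape[1], n_embd)

def forward_token(hidden, x, cache):      # stub for MIDIModel.forward_token
    n = (0 if hidden is None else 1) + (0 if x is None else x.shape[1])
    return torch.randn(1, n, vocab)

input_tensor = torch.zeros(1, 1, max_token_seq, dtype=torch.long)
cache1 = DynamicCache()
while input_tensor.shape[1] < max_len:
    # outer model: only the newest event goes through, cache1 has the rest
    hidden = forward(input_tensor[:, -1:], cache1)[:, -1]
    cache2 = DynamicCache()               # inner cache: fresh per event
    next_token_seq = None
    for i in range(max_token_seq):
        x = next_token_seq
        if i != 0:
            hidden = None                 # its keys/values sit in cache2 already
            x = x[:, -1:]                 # only the newest token
        logits = forward_token(hidden, x, cache=cache2)[:, -1:]
        sample = logits.argmax(-1)
        next_token_seq = sample if i == 0 else torch.cat([next_token_seq, sample], dim=1)
    input_tensor = torch.cat([input_tensor, next_token_seq.unsqueeze(1)], dim=1)
print(input_tensor.shape)  # torch.Size([1, 4, 3])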