skytnt commited on
Commit
1079729
1 Parent(s): 48914a6

merge lora into model

Browse files
Files changed (2) hide show
  1. app.py +13 -15
  2. midi_model.py +8 -0
app.py CHANGED
@@ -142,12 +142,7 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
142
  def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
143
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
144
  seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
145
- model, lora_name = models[model_name]
146
- if lora_name is None and model.peft_loaded():
147
- model.disable_adapters()
148
- elif lora_name is not None:
149
- model.enable_adapters()
150
- model.set_adapter(lora_name)
151
  model.to(device=opt.device)
152
  tokenizer = model.tokenizer
153
  bpm = int(bpm)
@@ -258,7 +253,7 @@ def finish_run(model_name, mid_seq):
258
  if mid_seq is None:
259
  outputs = [None] * OUTPUT_BATCH_SIZE
260
  return *outputs, []
261
- tokenizer = models[model_name][0].tokenizer
262
  outputs = []
263
  end_msgs = [create_msg("progress", [0, 0])]
264
  if not os.path.exists("outputs"):
@@ -282,7 +277,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
282
  if (not should_render_audio) or mid_seq is None:
283
  outputs = [None] * OUTPUT_BATCH_SIZE
284
  return tuple(outputs)
285
- tokenizer = models[model_name][0].tokenizer
286
  outputs = []
287
  if not os.path.exists("outputs"):
288
  os.mkdir("outputs")
@@ -293,13 +288,15 @@ def render_audio(model_name, mid_seq, should_render_audio):
293
  audio_futures.append(audio_future)
294
  for future in audio_futures:
295
  outputs.append((44100, future.result()))
 
 
296
  return tuple(outputs)
297
 
298
 
299
  def undo_continuation(model_name, mid_seq, continuation_state):
300
  if mid_seq is None or len(continuation_state) < 2:
301
  return mid_seq, continuation_state, send_msgs([])
302
- tokenizer = models[model_name][0].tokenizer
303
  if isinstance(continuation_state[-1], list):
304
  mid_seq = continuation_state[-1]
305
  else:
@@ -399,14 +396,15 @@ if __name__ == "__main__":
399
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
400
  state_dict = ckpt.get("state_dict", ckpt)
401
  model.load_state_dict(state_dict, strict=False)
402
- for lora_name, lora_repo in loras.items():
403
- model.load_adapter(lora_repo, lora_name)
404
- if loras:
405
- model.disable_adapters()
406
  model.to(device="cpu", dtype=torch.float32).eval()
407
- models[name] = model, None
408
  for lora_name, lora_repo in loras.items():
409
- models[f"{name} with {lora_name} lora"] = model, lora_name
 
 
 
 
 
410
 
411
  load_javascript()
412
  app = gr.Blocks()
 
142
  def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
143
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
144
  seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
145
+ model = models[model_name]
 
 
 
 
 
146
  model.to(device=opt.device)
147
  tokenizer = model.tokenizer
148
  bpm = int(bpm)
 
253
  if mid_seq is None:
254
  outputs = [None] * OUTPUT_BATCH_SIZE
255
  return *outputs, []
256
+ tokenizer = models[model_name].tokenizer
257
  outputs = []
258
  end_msgs = [create_msg("progress", [0, 0])]
259
  if not os.path.exists("outputs"):
 
277
  if (not should_render_audio) or mid_seq is None:
278
  outputs = [None] * OUTPUT_BATCH_SIZE
279
  return tuple(outputs)
280
+ tokenizer = models[model_name].tokenizer
281
  outputs = []
282
  if not os.path.exists("outputs"):
283
  os.mkdir("outputs")
 
288
  audio_futures.append(audio_future)
289
  for future in audio_futures:
290
  outputs.append((44100, future.result()))
291
+ if OUTPUT_BATCH_SIZE == 1:
292
+ return outputs[0]
293
  return tuple(outputs)
294
 
295
 
296
  def undo_continuation(model_name, mid_seq, continuation_state):
297
  if mid_seq is None or len(continuation_state) < 2:
298
  return mid_seq, continuation_state, send_msgs([])
299
+ tokenizer = models[model_name].tokenizer
300
  if isinstance(continuation_state[-1], list):
301
  mid_seq = continuation_state[-1]
302
  else:
 
396
  ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
397
  state_dict = ckpt.get("state_dict", ckpt)
398
  model.load_state_dict(state_dict, strict=False)
 
 
 
 
399
  model.to(device="cpu", dtype=torch.float32).eval()
400
+ models[name] = model
401
  for lora_name, lora_repo in loras.items():
402
+ model = MIDIModel(config=MIDIModelConfig.from_name(config))
403
+ model.load_state_dict(state_dict, strict=False)
404
+ print(f"loading lora {lora_repo} for {name}")
405
+ model = model.load_merge_lora(lora_repo)
406
+ model.to(device="cpu", dtype=torch.float32).eval()
407
+ models[f"{name} with {lora_name} lora"] = model
408
 
409
  load_javascript()
410
  app = gr.Blocks()
midi_model.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import tqdm
 
8
  from transformers import LlamaModel, LlamaConfig
9
  from transformers.integrations import PeftAdapterMixin
10
 
@@ -75,6 +76,13 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
75
  def peft_loaded(self):
76
  return self._hf_peft_config_loaded
77
 
 
 
 
 
 
 
 
78
  def forward_token(self, hidden_state, x=None):
79
  """
80
 
 
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import tqdm
8
+ from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
9
  from transformers import LlamaModel, LlamaConfig
10
  from transformers.integrations import PeftAdapterMixin
11
 
 
76
  def peft_loaded(self):
77
  return self._hf_peft_config_loaded
78
 
79
+ def load_merge_lora(self, model_id):
80
+ peft_config = PeftConfig.from_pretrained(model_id)
81
+ model = LoraModel(self, peft_config, adapter_name="default")
82
+ adapter_state_dict = load_peft_weights(model_id, device=self.device)
83
+ set_peft_model_state_dict(self, adapter_state_dict, "default")
84
+ return model.merge_and_unload()
85
+
86
  def forward_token(self, hidden_state, x=None):
87
  """
88