wanderor committed
Commit 1ceb0cb · verified · Parent: 21e853e

Supports MacOS in MiniCPM-o 2.6


Decouples device placement from CUDA: tensors that were hard-coded to "cuda" / .cuda() now follow self.device, so the model can run on macOS.

Note: verified with torch 2.5.1. It does not work with torch 2.3.1 (the version pinned in requirements) on macOS.
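For context on the change: instead of hard-coding "cuda", the modeling code now moves tensors to self.device, i.e. wherever the model was placed. A minimal sketch of how a caller might pick that device on macOS (illustrative only, not part of this commit):

```python
import torch

def pick_device() -> torch.device:
    # Prefer CUDA when available, fall back to Apple's MPS backend on macOS,
    # and finally to CPU; the model (and hence self.device) follows this choice.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```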

Files changed (1)
  1. modeling_minicpmo.py +6 -6
modeling_minicpmo.py CHANGED
@@ -184,7 +184,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
             args=(),
             init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
         )
-        vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32)
+        vocos = Vocos(feature_extractor, backbone, head).to(self.device).eval().to(torch.float32)
         vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True))
         return vocos
 
@@ -1185,7 +1185,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
 
         terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
         generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
-        input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
+        input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device)
 
         spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
         spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
@@ -1289,7 +1289,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
         tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[
             "input_ids"
-        ].cuda()
+        ].to(self.device)
         return tts_input_ids
 
     def _build_streaming_mask(self, tts_tokens_len):
@@ -1320,7 +1320,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         gen_text = text.split("<|tts_eos|>")[0]
         tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
         tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
-        tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
+        tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
         streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
 
         logits_warpers, logits_processors = gen_logits(
@@ -1617,7 +1617,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
 
         tts_input_ids = self.tts_processor.text_tokenizer(
             tts_text, return_tensors="pt", add_special_tokens=False
-        )["input_ids"].cuda()
+        )["input_ids"].to(self.device)
         text_input_ids = tts_input_ids[:, begin:end]
         streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
         position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
@@ -1726,7 +1726,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         if end > begin:
             tts_input_ids = self.tts_processor.text_tokenizer(
                 tts_text, return_tensors="pt", add_special_tokens=False
-            )["input_ids"].cuda()
+            )["input_ids"].to(self.device)
            text_input_ids = tts_input_ids[:, begin:end]
            streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
            position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
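With the hard-coded .cuda() calls removed, all of the tensors above land on whatever device the model itself lives on. A hedged usage sketch for macOS (the repo id, dtype, and loading flags are assumptions here; follow the model card for the exact call):

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo = "openbmb/MiniCPM-o-2_6"  # assumed repo id for MiniCPM-o 2.6
device = "mps" if torch.backends.mps.is_available() else "cpu"

# trust_remote_code pulls in modeling_minicpmo.py; torch >= 2.5.1 per the note above.
model = AutoModel.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.float16)
model = model.eval().to(device)  # self.device now resolves to "mps" instead of "cuda"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
```

If half precision trips over an unsupported MPS op, loading with torch.float32 is the safe fallback.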