mini-omni committed on
Commit
2541285
·
1 Parent(s): e37ec4f

fix device

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. inference.py +1 -1
  3. utils/snac_utils.py +13 -10
.gitignore CHANGED
@@ -4,6 +4,7 @@
4
  checkpoint/
5
  checkpoint_bak/
6
  output/
 
7
 
8
  __pycache__/
9
  *.py[cod]
 
4
  checkpoint/
5
  checkpoint_bak/
6
  output/
7
+ .DS_Store
8
 
9
  __pycache__/
10
  *.py[cod]
inference.py CHANGED
@@ -494,7 +494,7 @@ class OmniInference:
494
  if current_index == nums_generate:
495
  current_index = 0
496
  snac = get_snac(list_output, index, nums_generate)
497
- audio_stream = generate_audio_data(snac, self.snacmodel)
498
  yield audio_stream
499
 
500
  input_pos = input_pos.add_(1)
 
494
  if current_index == nums_generate:
495
  current_index = 0
496
  snac = get_snac(list_output, index, nums_generate)
497
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
498
  yield audio_stream
499
 
500
  input_pos = input_pos.add_(1)
utils/snac_utils.py CHANGED
@@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000):
21
  return input_id + shift + layer * stride
22
 
23
 
24
- def generate_audio_data(snac_tokens, snacmodel):
25
- audio = reconstruct_tensors(snac_tokens)
26
  with torch.inference_mode():
27
  audio_hat = snacmodel.decode(audio)
28
  audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
@@ -55,9 +55,12 @@ def reconscruct_snac(output_list):
55
  return output
56
 
57
 
58
- def reconstruct_tensors(flattened_output):
59
  """Reconstructs the list of tensors from the flattened output."""
60
 
 
 
 
61
  def count_elements_between_hashes(lst):
62
  try:
63
  # Find the index of the first '#'
@@ -107,9 +110,9 @@ def reconstruct_tensors(flattened_output):
107
  tensor3.append(flattened_output[i + 6])
108
  tensor3.append(flattened_output[i + 7])
109
  codes = [
110
- list_to_torch_tensor(tensor1).cuda(),
111
- list_to_torch_tensor(tensor2).cuda(),
112
- list_to_torch_tensor(tensor3).cuda(),
113
  ]
114
 
115
  if n_tensors == 15:
@@ -133,10 +136,10 @@ def reconstruct_tensors(flattened_output):
133
  tensor4.append(flattened_output[i + 15])
134
 
135
  codes = [
136
- list_to_torch_tensor(tensor1).cuda(),
137
- list_to_torch_tensor(tensor2).cuda(),
138
- list_to_torch_tensor(tensor3).cuda(),
139
- list_to_torch_tensor(tensor4).cuda(),
140
  ]
141
 
142
  return codes
 
21
  return input_id + shift + layer * stride
22
 
23
 
24
+ def generate_audio_data(snac_tokens, snacmodel, device=None):
25
+ audio = reconstruct_tensors(snac_tokens, device)
26
  with torch.inference_mode():
27
  audio_hat = snacmodel.decode(audio)
28
  audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
 
55
  return output
56
 
57
 
58
+ def reconstruct_tensors(flattened_output, device=None):
59
  """Reconstructs the list of tensors from the flattened output."""
60
 
61
+ if device is None:
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+
64
  def count_elements_between_hashes(lst):
65
  try:
66
  # Find the index of the first '#'
 
110
  tensor3.append(flattened_output[i + 6])
111
  tensor3.append(flattened_output[i + 7])
112
  codes = [
113
+ list_to_torch_tensor(tensor1).to(device),
114
+ list_to_torch_tensor(tensor2).to(device),
115
+ list_to_torch_tensor(tensor3).to(device),
116
  ]
117
 
118
  if n_tensors == 15:
 
136
  tensor4.append(flattened_output[i + 15])
137
 
138
  codes = [
139
+ list_to_torch_tensor(tensor1).to(device),
140
+ list_to_torch_tensor(tensor2).to(device),
141
+ list_to_torch_tensor(tensor3).to(device),
142
+ list_to_torch_tensor(tensor4).to(device),
143
  ]
144
 
145
  return codes