Plachta committed on
Commit
2072390
·
verified ·
1 Parent(s): e38da33

Update modules/cosyvoice_tokenizer/frontend.py

Browse files
modules/cosyvoice_tokenizer/frontend.py CHANGED
@@ -1,54 +1,52 @@
1
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from functools import partial
15
- import onnxruntime
16
- import torch
17
- import numpy as np
18
- import whisper
19
- import torchaudio.compliance.kaldi as kaldi
20
-
21
- class CosyVoiceFrontEnd:
22
-
23
- def __init__(self, speech_tokenizer_model: str, device: str = 'cuda', device_id: int = 0):
24
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
- option = onnxruntime.SessionOptions()
26
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
27
- option.intra_op_num_threads = 1
28
- self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if device == "cuda" else "CPUExecutionProvider"])
29
- if device == 'cuda':
30
- self.speech_tokenizer_session.set_providers(['CUDAExecutionProvider'], [ {'device_id': device_id}])
31
-
32
- def extract_speech_token(self, speech):
33
- feat = whisper.log_mel_spectrogram(speech, n_mels=128)
34
- speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
35
- self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
36
- speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
37
- speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
38
- return speech_token, speech_token_len
39
-
40
- def _extract_spk_embedding(self, speech):
41
- feat = kaldi.fbank(speech,
42
- num_mel_bins=80,
43
- dither=0,
44
- sample_frequency=16000)
45
- feat = feat - feat.mean(dim=0, keepdim=True)
46
- embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
47
- embedding = torch.tensor([embedding]).to(self.device)
48
- return embedding
49
-
50
- def _extract_speech_feat(self, speech):
51
- speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
52
- speech_feat = speech_feat.unsqueeze(dim=0)
53
- speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
54
  return speech_feat, speech_feat_len
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ import onnxruntime
16
+ import torch
17
+ import numpy as np
18
+ import whisper
19
+ import torchaudio.compliance.kaldi as kaldi
20
+
21
+ class CosyVoiceFrontEnd:
22
+
23
+ def __init__(self, speech_tokenizer_model: str, device: str = 'cuda', device_id: int = 0):
24
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+ option = onnxruntime.SessionOptions()
26
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
27
+ option.intra_op_num_threads = 1
28
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if device == "cuda" and torch.cuda.is_available() else "CPUExecutionProvider"])
29
+
30
+ def extract_speech_token(self, speech):
31
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
32
+ speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
33
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
34
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
35
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
36
+ return speech_token, speech_token_len
37
+
38
+ def _extract_spk_embedding(self, speech):
39
+ feat = kaldi.fbank(speech,
40
+ num_mel_bins=80,
41
+ dither=0,
42
+ sample_frequency=16000)
43
+ feat = feat - feat.mean(dim=0, keepdim=True)
44
+ embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
45
+ embedding = torch.tensor([embedding]).to(self.device)
46
+ return embedding
47
+
48
+ def _extract_speech_feat(self, speech):
49
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
50
+ speech_feat = speech_feat.unsqueeze(dim=0)
51
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
 
 
52
  return speech_feat, speech_feat_len