Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,172 Bytes
bc752b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import yaml
import os
import re
from vita.model.vita_tts.utils import init_encoder_llm, load_checkpoint
class inferencePipeline():
    """Speech-dialogue inference pipeline around the VITA audio-LLM.

    On construction this loads the training-time YAML config from
    ``args.model_path``, patches in runtime paths, builds the encoder+LLM,
    restores the checkpoint, and puts the model on GPU in eval mode.
    """

    def __init__(self, args):
        self.args = args

        # Read the training-time configuration shipped with the checkpoint.
        with open(self.args.model_path + "/audiollm/train.yaml", 'r') as config_file:
            configs = yaml.safe_load(config_file)

        # Patch in paths that are only known at runtime.
        configs['cmvn_file'] = self.args.model_path + "/audiollm/global_cmvn"
        configs['model_conf']['llm_path'] = self.args.llm_path

        # Build the ASR encoder + LLM and restore the trained weights.
        self.model = init_encoder_llm(configs)
        load_checkpoint(self.model, self.args.model_path + "/audiollm/final.pt")

        # Inference is GPU-only; eval() disables dropout etc.
        self.model = self.model.to(torch.device('cuda'))
        self.model.eval()
def speech_dialogue(self,
audio: tuple,
role: str=None,
stat: str='sl',
past_key_values=None,
last_id=None,
past_tokens=None,
adapter_cache=None,
encoder_cache=None,
pe_index=0):
with torch.no_grad():
## input fbank
feats = audio
if feats is not None:
feats = feats.to('cuda')
feats_lengths = torch.tensor([feats.size(1)]).to('cuda')
else:
feats_lengths = None
extra_inputs = {}
extra_inputs['top_p'] = self.args.top_p
extra_inputs['top_k'] = self.args.top_k
extra_inputs['temperature'] = self.args.temperature
extra_inputs['past_key_values'] = past_key_values
extra_inputs['stat'] = stat
extra_inputs['last_id'] = last_id
extra_inputs['adapter_cache'] = adapter_cache
extra_inputs['encoder_cache'] = encoder_cache
extra_inputs['pe_index'] = pe_index
if role is not None and past_key_values is None:
# add <|im_end|> in chat_prefix
extra_inputs['role'] = '<|im_start|>system\n' + role # + '<|im_end|>'
with torch.autocast(device_type="cuda",
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32):
# preprocess system role first
if stat == 'pre':
past_key_values = self.model.set_system_role(extra_inputs)
stat = 'sl'
else:
(last_id, stat, past_key_values, adapter_cache,
encoder_cache, pe_index, hidden_state) = self.model.recognize(
feats,
feats_lengths,
extra_inputs=extra_inputs)
outputs = dict(
past_key_values=past_key_values,
stat=stat,
last_id=last_id,
adapter_cache=adapter_cache,
encoder_cache=encoder_cache,
pe_index=pe_index,
)
if stat == 'cs':
if past_tokens is None:
past_tokens = []
past_tokens.append(last_id[0][0])
text = self.model.tokenizer.decode(past_tokens, skip_special_tokens=True)
outputs['hidden_state'] = hidden_state
outputs['text'] = text
outputs['past_tokens'] = past_tokens
return outputs
def post_process(self, text):
"""
Post-processes the input text to standardize various characters and formatting.
Parameters:
- text (str): The input text string to be post-processed.
Actions:
1. Replaces various Chinese and English punctuation marks with standardized ones.
2. Removes newline, tab, and other unwanted whitespace characters.
3. Removes special characters like asterisks, underscores, backticks, and tildes.
4. Condenses whitespace following periods and colons.
5. Adjusts the format of numbered lists to use appropriate separators
6. Ensures the text ends with an appropriate punctuation mark
Returns:
- str: The post-processed text string.
"""
text = text.replace('、', ',')
text = text.replace('(', ',')
text = text.replace(')', ',')
text = text.replace('(', ',')
text = text.replace(')', ',')
text = re.sub(r'[\n\r\t]', '', text)
text = re.sub(r'[*_`~]', '', text)
text = re.sub(r'(\.|\:)\s+', r'\1', text)
if re.search(r'[\u4e00-\u9fa5]', text):
text = re.sub(r'(\d+)\.\s*([\u4e00-\u9fa5A-Za-z])', r'\1:\2', text)
else:
text = re.sub(r'(\d+)\.\s*([\w])', r'\1:\2', text)
if text and text[-1] not in ["。", "?", "!", ".", "?", "!"]:
if text[-1] in [",", ",", ";", ";", ":", ":", "、"]:
text = text[:-1] + "。"
else:
text += "。"
return text
|