RVC-Speakers / speakers /tasks /base_task.py
glide-the
Add large files to Git LFS
04ffec9
raw
history blame
3.91 kB
from abc import abstractmethod
from typing import List, Dict
from speakers.load.serializable import Serializable
from speakers.processors import ProcessorData, BaseProcessor
from speakers.server.model.flow_data import PayLoad
import logging
class FlowData(Serializable):
    """
    Task parameters carried by the current runner.

    Subclasses are expected to override the abstract ``type`` property,
    which supplies the tag used for serialization.
    """

    @property
    def lc_serializable(self) -> bool:
        """Always True: instances of this class may be serialized."""
        return True

    @property
    @abstractmethod
    def type(self) -> str:
        """Serialization type tag of this flow data; provided by subclasses."""
class Runner(Serializable):
    """
    Unit of work handed to a task: a task id paired with the flow data
    (task parameters) that the task should process.
    """

    # Id of the task this runner belongs to.
    task_id: str
    # Task parameters processed by this runner.
    flow_data: FlowData

    @property
    def lc_serializable(self) -> bool:
        """Always True: runners may be serialized."""
        return True

    @property
    def type(self) -> str:
        """Serialization type tag for runners (the constant ``'runner'``)."""
        return 'runner'
# Define a base class for tasks
class BaseTask:
    """
    Base task processor.

    Instances are created by the task manager to execute runner flow tasks;
    subclasses implement the concrete processing. This class defines the
    lifecycle of a runner task through three hooks that subclasses override:

    * ``prepare``  - pre-processing that builds a ``Runner`` from a ``PayLoad``;
    * ``dispatch`` - processes the runner's flow data;
    * ``complete`` - post-processing.

    Progress is broadcast to registered listeners via ``report_progress``.
    """

    def __init__(self, preprocess_dict: Dict[str, BaseProcessor]):
        """
        Args:
            preprocess_dict: mapping of string keys to the processors this
                task may use; stored as ``self._preprocess_dict``.
        """
        # Create the logger before registering any hook, so no hook can ever
        # observe an instance that lacks ``self.logger``.
        self.logger = logging.getLogger('base_task_runner')
        self._progress_hooks = []
        self._preprocess_dict = preprocess_dict
        self._add_logger_hook()

    @classmethod
    def from_config(cls, cfg=None):
        """
        Alternate constructor.

        ``cfg`` is accepted for interface compatibility but is currently
        ignored; the task is built with an empty processor mapping.
        """
        return cls(preprocess_dict={})

    def _add_logger_hook(self):
        """
        Register the default progress listener, which logs known states.

        States that have no entry in the table below are silently ignored.
        """
        # state -> (log level, message)
        log_table = {
            'dispatch_voice_task': (logging.INFO, 'dispatch_voice_task'),
            'saved': (logging.INFO, 'Saving results'),
            'skip-no-text': (logging.WARNING, 'No text regions with text! - Skipping'),
            'error': (logging.ERROR, 'task error'),
        }

        async def ph(task_id: str, runner_stat: str, state: str, finished: bool = False):
            entry = log_table.get(state)
            if entry is None:
                return
            level, message = entry
            # Logger.log(WARNING, ...) replaces the deprecated Logger.warn().
            self.logger.log(level, message)

        self.add_progress_hook(ph)

    def add_progress_hook(self, ph):
        """
        Register a progress listener.

        Args:
            ph: async callable, invoked as
                ``await ph(task_id, runner_stat, state, finished)``.
        """
        self._progress_hooks.append(ph)

    async def report_progress(self, task_id: str, runner_stat: str, state: str, finished: bool = False):
        """
        Notify every registered listener, in registration order.

        Listeners are awaited one at a time. An exception raised by a listener
        propagates to the caller and the remaining listeners are skipped.

        Args:
            task_id: id of the task.
            runner_stat: where in the runner flow the task currently is.
            state: state key (the default logging listener recognises
                'dispatch_voice_task', 'saved', 'skip-no-text' and 'error').
            finished: whether the task has finished.
        """
        for ph in self._progress_hooks:
            await ph(task_id, runner_stat, state, finished)

    @classmethod
    def prepare(cls, payload: PayLoad) -> Runner:
        """
        Pre-processing: build the runner for a request.

        Args:
            payload (PayLoad): runner flow data.

        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError

    @classmethod
    async def dispatch(cls, runner: Runner):
        """
        Process the concrete flow data of the current runner task.

        Args:
            runner (Runner): runner flow data.

        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError

    @classmethod
    def complete(cls, runner: Runner):
        """
        Post-processing after the runner has been dispatched.

        Args:
            runner (Runner): runner flow data.

        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError