from abc import abstractmethod
from typing import List, Dict
from speakers.load.serializable import Serializable
from speakers.processors import ProcessorData, BaseProcessor
from speakers.server.model.flow_data import PayLoad
import logging


class FlowData(Serializable):
    """Task parameters carried by the current runner.

    Subclasses must implement the ``type`` property.
    """

    @property
    @abstractmethod
    def type(self) -> str:
        """Type of the Message, used for serialization."""

    @property
    def lc_serializable(self) -> bool:
        """Whether this class is Processor serializable."""
        return True


class Runner(Serializable):
    """A runner task, identified by ``task_id`` and carrying its ``flow_data``."""

    task_id: str
    flow_data: FlowData

    @property
    def type(self) -> str:
        """Type of the Runner Message, used for serialization."""
        return 'runner'

    @property
    def lc_serializable(self) -> bool:
        """Whether this class is Processor serializable."""
        return True


# Define a base class for tasks
class BaseTask:
    """Base task processor used to execute a runner flow.

    Instances are created by the task manager. Subclasses implement the
    concrete processing flow by overriding ``prepare``, ``dispatch`` and
    ``complete``. This class defines the lifecycle of a runner task and a
    simple progress-hook (listener) mechanism.
    """

    def __init__(self, preprocess_dict: Dict[str, BaseProcessor]):
        """
        :param preprocess_dict: processors keyed by name, stored for use by subclasses
        """
        # Create the logger before any hook is registered, so a hook (or a
        # subclass override of _add_logger_hook) can safely use it right away.
        self.logger = logging.getLogger('base_task_runner')
        self._progress_hooks = []
        self._preprocess_dict = preprocess_dict
        self._add_logger_hook()

    @classmethod
    def from_config(cls, cfg=None):
        """Build a task with an empty processor dict; ``cfg`` is currently ignored."""
        return cls(preprocess_dict={})

    def _add_logger_hook(self) -> None:
        """
        Register the default progress listener, which logs known states.

        States found in LOG_MESSAGES are logged at INFO, states in
        LOG_MESSAGES_SKIP at WARNING, and states in LOG_MESSAGES_ERROR at
        ERROR. Unknown states are ignored.
        :return:
        """
        LOG_MESSAGES = {
            'dispatch_voice_task': 'dispatch_voice_task',
            'saved': 'Saving results',
        }
        LOG_MESSAGES_SKIP = {
            'skip-no-text': 'No text regions with text! - Skipping',
        }
        LOG_MESSAGES_ERROR = {
            'error': 'task error',
        }

        async def ph(task_id: str, runner_stat: str, state: str, finished: bool = False) -> None:
            if state in LOG_MESSAGES:
                self.logger.info(LOG_MESSAGES[state])
            elif state in LOG_MESSAGES_SKIP:
                # Logger.warn is a deprecated alias; warning is the supported API.
                self.logger.warning(LOG_MESSAGES_SKIP[state])
            elif state in LOG_MESSAGES_ERROR:
                self.logger.error(LOG_MESSAGES_ERROR[state])

        self.add_progress_hook(ph)

    def add_progress_hook(self, ph) -> None:
        """
        Register a progress listener.

        :param ph: an async callable taking
            ``(task_id, runner_stat, state, finished)``; it is awaited by
            ``report_progress``.
        :return:
        """
        self._progress_hooks.append(ph)

    async def report_progress(self, task_id: str, runner_stat: str, state: str, finished: bool = False) -> None:
        """
        Notify every registered listener, awaiting each in registration order.

        :param task_id: task id
        :param runner_stat: where in the runner the task currently is
        :param state: state key (e.g. one of the keys handled by the default logger hook)
        :param finished: whether the task has finished
        :return:
        """
        for ph in self._progress_hooks:
            await ph(task_id, runner_stat, state, finished)

    @classmethod
    def prepare(cls, payload: PayLoad) -> Runner:
        """
        Pre-processing: build a Runner from an incoming payload.

        Args:
            payload (PayLoad): runner flow data

        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError

    @classmethod
    async def dispatch(cls, runner: Runner):
        """
        Execute the concrete flow for the current runner task.

        Args:
            runner (Runner): runner flow data

        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError

    @classmethod
    def complete(cls, runner: Runner):
        """
        Post-processing after the runner task has been dispatched.

        Args:
            runner (Runner): runner flow data

        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError