# base_generator.py
"""Abstract base class that wires up the services shared by generators."""

from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from logging import getLogger

from services.model_manager import ModelManager
from services.cache import ResponseCache
from services.batch_processor import BatchProcessor
from services.health_check import HealthCheck
from config.config import GenerationConfig, ModelConfig


class BaseGenerator(ABC):
    """Base class for all generator implementations.

    The constructor builds the helper services (model manager, response
    cache, batch processor, health check) and resolves the default
    generation/model configuration. Concrete subclasses must implement
    ``generate_stream``, ``_get_generation_kwargs`` and ``generate``.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        """Create the shared services and store the default configuration.

        Args:
            model_name: Name of the model this generator is bound to. Kept
                on the instance as ``self.model_name``.
            device: Forwarded to ``ModelManager``; ``None`` leaves the choice
                to the manager.
            default_generation_config: Generation settings used as the
                instance default. A fresh ``GenerationConfig()`` is created
                when omitted.
            model_config: Model settings for this instance. A fresh
                ``ModelConfig()`` is created when omitted.
            cache_size: Forwarded to ``ResponseCache``.
            max_batch_size: Forwarded to ``BatchProcessor``.
        """
        self.logger = getLogger(__name__)

        # Previously the argument was accepted and then dropped; keep it so
        # subclasses and callers can tell which model the instance targets.
        self.model_name = model_name

        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()

        # NOTE(review): tokenizer initialization was commented out in the
        # original, so no ``self.tokenizer`` is set here; presumably
        # subclasses load their own tokenizer -- confirm.

        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None,
    ) -> AsyncGenerator[str, None]:
        """Asynchronously yield pieces of generated text for ``prompt``.

        Args:
            prompt: Input text to generate from.
            config: Optional per-call generation settings. Presumably
                implementations fall back to ``self.default_config`` when
                this is ``None`` -- confirm against the subclasses.

        Yields:
            String chunks produced by the implementation.
        """
        # An ``async def`` without ``yield`` is a coroutine function, which
        # contradicts the declared ``AsyncGenerator`` return type and makes
        # type checkers reject async-generator overrides. The unreachable
        # ``yield`` turns this declaration into an async generator function
        # that yields nothing.
        if False:  # pragma: no cover
            yield ""

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Translate ``config`` into a keyword-argument mapping.

        Args:
            config: Generation settings to convert.

        Returns:
            A ``dict`` of keyword arguments; the consumer of this mapping is
            defined by the subclass.
        """

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """Generate a complete response for ``prompt``.

        Args:
            prompt: Input text to generate from.
            model_kwargs: Keyword arguments for the underlying model.
            strategy: Name of the generation strategy; defaults to
                ``"default"``. The accepted values are defined by the
                subclass.
            **kwargs: Additional implementation-specific options.

        Returns:
            The generated text.
        """