Spaces:
Runtime error
Runtime error
"""Central AI Hub for coordinating all AI agents and operations."""
import asyncio
import os
import uuid
from typing import Any, Dict, Optional

import torch
from ctransformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from loguru import logger

from utils.llm_orchestrator import LLMOrchestrator
class CentralAIHub:
    """Coordinate LLM start-up, agent activation and task dispatch.

    The hub keeps three pieces of state:

    * ``agents`` maps a task ``type`` string to the name of the agent that
      handles it.
    * ``active_agents`` records which agent names have been initialized.
    * ``tasks`` maps a generated task id to a dict with the keys
      ``status``, ``task``, ``agent`` and ``result`` (plus ``error`` when
      processing failed).
    """

    def __init__(self, api_key: Optional[str] = None,
                 model_path: Optional[str] = None) -> None:
        """Initialize the Central AI Hub.

        Args:
            api_key: Stored on the instance; not used by the hub itself.
            model_path: Path of a local model. When falsy, the model is
                fetched from the Hugging Face Hub on start-up instead.
        """
        self.api_key = api_key
        self.model_path = model_path
        self.cache_dir = os.path.join(os.getcwd(), ".cache")
        os.makedirs(self.cache_dir, exist_ok=True)
        self.llm = None
        self.llm_orchestrator = None
        self.agents = {
            'code_analysis': 'CodeAnalysisAgent',
            'code_generation': 'CodeGenerationAgent',
            'error_fixing': 'ErrorFixingAgent',
        }
        self.tasks = {}
        self.active_agents = {}
        # Strong references to in-flight asyncio tasks. The event loop only
        # keeps weak references, so without this set a task could be garbage
        # collected before it finishes, and shutdown() could not cancel it.
        self._background_tasks = set()
        self._initialized = False
        self.max_retries = 3
        self.retry_delay = 2

    def _load_model(self):
        """Load the model once (no retry) and return it.

        Uses ``self.model_path`` when set; otherwise resolves the GGUF file
        from ``self.cache_dir`` or downloads it from the Hugging Face Hub.
        Any exception from the underlying libraries propagates to the caller.
        """
        if self.model_path:
            logger.info(f"Loading local model from {self.model_path}")
            # NOTE(review): `cache_dir` and model_type="qwen" are forwarded
            # unchanged from the original call; confirm that the installed
            # ctransformers version accepts both.
            llm = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                model_type="qwen",
                cache_dir=self.cache_dir,
                local_files_only=True,
            )
            logger.info("Local model loaded successfully")
            return llm

        # NOTE(review): verify that this filename exists in this repo.
        model_name = "Qwen/Qwen2.5-14B-Instruct-GGUF"
        model_filename = "Qwen2.5-14B_Uncensored_Instruct-Q8_0.gguf"
        model_file = os.path.join(self.cache_dir, model_filename)
        if os.path.exists(model_file):
            logger.info(f"Using cached model at {model_file}")
        else:
            logger.info(f"Downloading model {model_filename} from Hugging Face Hub")
            # hf_hub_download stores the file inside its own cache layout
            # below `cache_dir` and returns the real location, so the
            # returned path must be used for loading rather than
            # `cache_dir/<filename>`.
            model_file = hf_hub_download(
                repo_id=model_name,
                filename=model_filename,
                cache_dir=self.cache_dir,
                local_files_only=False,
            )
            logger.info(f"Model downloaded to {model_file}")

        llm = AutoModelForCausalLM.from_pretrained(
            model_file,
            model_type="qwen",
            local_files_only=True,
        )
        logger.info("Model loaded successfully")
        return llm

    async def _initialize_llm_client(self) -> bool:
        """Initialize the LLM client, retrying on failure.

        Makes up to ``self.max_retries`` attempts (at least one), sleeping
        ``self.retry_delay`` seconds between attempts.

        Returns:
            True when the model and orchestrator were created, False when
            every attempt failed. Exceptions are logged, not raised.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                # NOTE: model loading is synchronous and blocks the event
                # loop for its whole duration.
                self.llm = self._load_model()
                self.llm_orchestrator = LLMOrchestrator(self.llm)
                return True
            except Exception as e:
                if attempt == attempts:
                    logger.error(f"Failed to initialize LLM client: {e}")
                    return False
                logger.warning(
                    f"LLM client initialization attempt {attempt}/{attempts} "
                    f"failed: {e}; retrying in {self.retry_delay}s"
                )
                await asyncio.sleep(self.retry_delay)
        return False

    async def start(self) -> None:
        """Start the hub: initialize the LLM client first, then every agent.

        Calling this on an already started hub is a no-op.

        Raises:
            RuntimeError: If the LLM client could not be initialized.
            Exception: Whatever ``initialize_agent`` raised; start-up halts
                at the first agent that fails.
        """
        if self._initialized:
            return
        logger.info("Starting Central AI Hub...")
        # Agents are only brought up after the LLM connection succeeded.
        if not await self._initialize_llm_client():
            raise RuntimeError("Failed to initialize LLM client.")
        for agent_class in self.agents.values():
            try:
                await self.initialize_agent(agent_class)
                logger.info(f"Initialized {agent_class}")
            except Exception as e:
                logger.error(f"Failed to initialize agent {agent_class}: {e}")
                raise  # Re-raise the exception to halt the startup
        self._initialized = True
        logger.info("Central AI Hub initialization complete.")

    async def delegate_task(self, task: Dict[str, Any]) -> str:
        """Register a task, schedule its processing and return its id.

        Must be called from a running event loop because the processing is
        scheduled with ``asyncio.create_task``.

        Args:
            task: Mapping describing the task; its ``'type'`` entry selects
                the agent.

        Returns:
            The generated task id (a UUID4 string).

        Raises:
            ValueError: If ``task`` is empty/None, or no agent is registered
                for its type (including a missing ``'type'`` entry).
        """
        if not task:
            raise ValueError("Task cannot be None")
        agent_type = await self.select_agent(task)
        if not agent_type:
            # .get() keeps the error message itself from raising KeyError
            # when the task has no 'type' entry.
            raise ValueError(
                f"No suitable agent found for task type: {task.get('type')}"
            )
        task_id = str(uuid.uuid4())
        self.tasks[task_id] = {
            'status': 'active',
            'task': task,
            'agent': agent_type,
            'result': None,
        }
        # Process the task in the background; keep a reference until done.
        handle = asyncio.create_task(self._process_task(task_id))
        self._background_tasks.add(handle)
        handle.add_done_callback(self._background_tasks.discard)
        return task_id

    async def _process_task(self, task_id: str) -> None:
        """Process one registered task and record the outcome in ``tasks``.

        Sets ``status`` to ``'completed'``, ``'failed'`` (with ``error``) or
        ``'cancelled'``. Cancellation is re-raised so asyncio sees it.
        """
        task_info = self.tasks[task_id]
        try:
            # Simulate task processing
            await asyncio.sleep(2)  # Simulated work
            task_info['status'] = 'completed'
            task_info['result'] = "Task processed successfully"
            logger.info(f"Task {task_id} completed")
        except asyncio.CancelledError:
            task_info['status'] = 'cancelled'
            raise
        except Exception as e:
            task_info['status'] = 'failed'
            task_info['error'] = str(e)
            logger.error(f"Error processing task {task_id}: {str(e)}")

    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the record stored for ``task_id``.

        Returns ``{'status': 'not_found'}`` for an unknown id. For a known
        id the live internal dict is returned, not a copy.
        """
        return self.tasks.get(task_id, {'status': 'not_found'})

    async def select_agent(self, task: Dict[str, Any]) -> Optional[str]:
        """Return the agent name registered for ``task['type']``.

        Returns None when the task has no ``'type'`` entry or no agent is
        registered for that type.
        """
        return self.agents.get(task.get('type'))

    async def initialize_agent(self, agent_id: str) -> None:
        """Mark the agent named ``agent_id`` as active.

        Raises:
            ValueError: If ``agent_id`` is not one of the registered agent
                names.
        """
        if agent_id not in self.agents.values():
            raise ValueError(f"Agent {agent_id} not found")
        self.active_agents[agent_id] = True

    async def shutdown(self) -> None:
        """Shut the hub down: deactivate agents and cancel running tasks.

        Background tasks are cancelled and awaited so none of them can
        overwrite a ``'cancelled'`` status afterwards. The hub is marked as
        not initialized so ``start()`` can bring it up again.
        """
        logger.info("Shutting down Central AI Hub...")
        # Clean up active agents
        self.active_agents.clear()
        # Cancel any pending tasks and wait until they have really stopped.
        pending = [handle for handle in self._background_tasks
                   if not handle.done()]
        for handle in pending:
            handle.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        self._background_tasks.clear()
        # Covers tasks that were cancelled before their coroutine ever ran.
        for task in self.tasks.values():
            if task['status'] == 'active':
                task['status'] = 'cancelled'
        self._initialized = False