import os
import sys
from pathlib import Path
import aiosqlite
import asyncio
from typing import Optional, Tuple, Dict
from contextlib import asynccontextmanager
import logging
import json  # Used for serializing/deserializing JSON columns
from .utils import ensure_content_dirs, generate_content_hash
from .models import CrawlResult, MarkdownGenerationResult
import xxhash
import aiofiles
from .config import NEED_MIGRATION
from .version_manager import VersionManager
from .async_logger import AsyncLogger
from .utils import get_error_context, create_box_message

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

base_directory = DB_PATH = os.path.join(
    os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(base_directory, "crawl4ai.db")


class AsyncDatabaseManager:
    def __init__(self, pool_size: int = 10, max_retries: int = 3):
        self.db_path = DB_PATH
        self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH))
        self.pool_size = pool_size
        self.max_retries = max_retries
        self.connection_pool: Dict[int, aiosqlite.Connection] = {}
        self.pool_lock = asyncio.Lock()
        self.init_lock = asyncio.Lock()
        self.connection_semaphore = asyncio.Semaphore(pool_size)
        self._initialized = False
        self.version_manager = VersionManager()
        self.logger = AsyncLogger(
            # base_directory already points at the ".crawl4ai" folder
            log_file=os.path.join(base_directory, "crawler_db.log"),
            verbose=False,
            tag_width=10,
        )

    async def initialize(self):
        """Initialize the database and connection pool"""
        try:
            self.logger.info("Initializing database", tag="INIT")
            # Ensure the database directory exists
            os.makedirs(os.path.dirname(self.db_path), exist_ok=True)

            # Check whether a version update is needed
            needs_update = self.version_manager.needs_update()

            # Always ensure the base table exists
            await self.ainit_db()

            # Verify the table exists
            async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
                async with db.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name='crawled_data'"
                ) as cursor:
                    result = await cursor.fetchone()
                    if not result:
                        raise Exception("crawled_data table was not created")

            # If the version changed or this is a fresh install, run updates
            if needs_update:
                self.logger.info("New version detected, running updates", tag="INIT")
                await self.update_db_schema()
                from .migrations import run_migration  # Imported here to avoid circular imports
                await run_migration()
                # Update the stored version only after a successful migration
                self.version_manager.update_version()
                self.logger.success("Version update completed successfully", tag="COMPLETE")
            else:
                self.logger.success("Database initialization completed successfully", tag="COMPLETE")

        except Exception as e:
            self.logger.error(
                message="Database initialization error: {error}",
                tag="ERROR",
                params={"error": str(e)},
            )
            self.logger.info(
                message="Database will be initialized on first use",
                tag="INIT",
            )
            raise

    async def cleanup(self):
        """Close all pooled connections when shutting down"""
        async with self.pool_lock:
            for conn in self.connection_pool.values():
                await conn.close()
            self.connection_pool.clear()

    @asynccontextmanager
    async def get_connection(self):
        """Connection pool manager with enhanced error handling"""
        if not self._initialized:
            async with self.init_lock:
                if not self._initialized:
                    try:
                        await self.initialize()
                        self._initialized = True
                    except Exception as e:
                        error_context = get_error_context(sys.exc_info())
                        self.logger.error(
                            message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
                            tag="ERROR",
                            force_verbose=True,
                            params={
                                "error": str(e),
                                "context": error_context["code_context"],
                                "traceback": error_context["full_traceback"],
                            },
                        )
                        raise

        await self.connection_semaphore.acquire()
        task_id = id(asyncio.current_task())
        try:
            async with self.pool_lock:
                if task_id not in self.connection_pool:
                    try:
                        conn = await aiosqlite.connect(
                            self.db_path,
                            timeout=30.0,
                        )
                        await conn.execute('PRAGMA journal_mode = WAL')
                        await conn.execute('PRAGMA busy_timeout = 5000')

                        # Verify database structure
                        async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
                            columns = await cursor.fetchall()
                            column_names = [col[1] for col in columns]
                            expected_columns = {
                                'url', 'html', 'cleaned_html', 'markdown',
                                'extracted_content', 'success', 'media', 'links',
                                'metadata', 'screenshot', 'response_headers',
                                'downloaded_files'
                            }
                            missing_columns = expected_columns - set(column_names)
                            if missing_columns:
                                raise ValueError(f"Database missing columns: {missing_columns}")

                        self.connection_pool[task_id] = conn
                    except Exception as e:
                        error_context = get_error_context(sys.exc_info())
                        error_message = (
                            f"Unexpected error in db get_connection at line {error_context['line_no']} "
                            f"in {error_context['function']} ({error_context['filename']}):\n"
                            f"Error: {str(e)}\n\n"
                            f"Code context:\n{error_context['code_context']}"
                        )
                        self.logger.error(
                            message=create_box_message(error_message, type="error"),
                        )
                        raise

            yield self.connection_pool[task_id]

        except Exception as e:
            error_context = get_error_context(sys.exc_info())
            error_message = (
                f"Unexpected error in db get_connection at line {error_context['line_no']} "
                f"in {error_context['function']} ({error_context['filename']}):\n"
                f"Error: {str(e)}\n\n"
                f"Code context:\n{error_context['code_context']}"
            )
            self.logger.error(
                message=create_box_message(error_message, type="error"),
            )
            raise
        finally:
            async with self.pool_lock:
                if task_id in self.connection_pool:
                    await self.connection_pool[task_id].close()
                    del self.connection_pool[task_id]
            self.connection_semaphore.release()

    async def execute_with_retry(self, operation, *args):
        """Execute database operations with retry logic"""
        for attempt in range(self.max_retries):
            try:
                async with self.get_connection() as db:
                    result = await operation(db, *args)
                    await db.commit()
                    return result
            except Exception as e:
                if attempt == self.max_retries - 1:
                    self.logger.error(
                        message="Operation failed after {retries} attempts: {error}",
                        tag="ERROR",
                        force_verbose=True,
                        params={
                            "retries": self.max_retries,
                            "error": str(e),
                        },
                    )
                    raise
                await asyncio.sleep(1 * (attempt + 1))  # Linear backoff between retries
    async def ainit_db(self):
        """Initialize the database schema"""
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            await db.execute('''
                CREATE TABLE IF NOT EXISTS crawled_data (
                    url TEXT PRIMARY KEY,
                    html TEXT,
                    cleaned_html TEXT,
                    markdown TEXT,
                    extracted_content TEXT,
                    success BOOLEAN,
                    media TEXT DEFAULT "{}",
                    links TEXT DEFAULT "{}",
                    metadata TEXT DEFAULT "{}",
                    screenshot TEXT DEFAULT "",
                    response_headers TEXT DEFAULT "{}",
                    downloaded_files TEXT DEFAULT "{}"  -- New column added
                )
            ''')
            await db.commit()

    async def update_db_schema(self):
        """Update the database schema if needed"""
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            cursor = await db.execute("PRAGMA table_info(crawled_data)")
            columns = await cursor.fetchall()
            column_names = [column[1] for column in columns]

            # List of new columns to add
            new_columns = ['media', 'links', 'metadata', 'screenshot',
                           'response_headers', 'downloaded_files']

            for column in new_columns:
                if column not in column_names:
                    await self.aalter_db_add_column(column, db)
            await db.commit()

    async def aalter_db_add_column(self, new_column: str, db):
        """Add a new column to the database"""
        if new_column == 'response_headers':
            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
        else:
            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
        self.logger.info(
            message="Added column '{column}' to the database",
            tag="INIT",
            params={"column": new_column},
        )

    async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
        """Retrieve cached URL data as a CrawlResult"""
        async def _get(db):
            async with db.execute(
                'SELECT * FROM crawled_data WHERE url = ?', (url,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return None

                # Get column names
                columns = [description[0] for description in cursor.description]
                # Create a dict from the row data
                row_dict = dict(zip(columns, row))

                # Load content from files using the stored hashes
                content_fields = {
                    'html': row_dict['html'],
                    'cleaned_html': row_dict['cleaned_html'],
                    'markdown': row_dict['markdown'],
                    'extracted_content': row_dict['extracted_content'],
                    'screenshot': row_dict['screenshot'],
                    'screenshots': row_dict['screenshot'],
                }

                for field, hash_value in content_fields.items():
                    if hash_value:
                        content = await self._load_content(
                            hash_value,
                            field.split('_')[0]  # Get content type from field name
                        )
                        row_dict[field] = content or ""
                    else:
                        row_dict[field] = ""

                # Parse JSON fields
                json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
                for field in json_fields:
                    try:
                        row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
                    except json.JSONDecodeError:
                        row_dict[field] = {}

                if isinstance(row_dict['markdown'], dict):
                    row_dict['markdown_v2'] = row_dict['markdown']
                    if row_dict['markdown'].get('raw_markdown'):
                        row_dict['markdown'] = row_dict['markdown']['raw_markdown']

                # Parse downloaded_files
                try:
                    row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
                except json.JSONDecodeError:
                    row_dict['downloaded_files'] = []

                # Remove any fields not in the CrawlResult model
                valid_fields = CrawlResult.__annotations__.keys()
                filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}
                return CrawlResult(**filtered_dict)

        try:
            return await self.execute_with_retry(_get)
        except Exception as e:
            self.logger.error(
                message="Error retrieving cached URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)},
            )
            return None

    async def acache_url(self, result: CrawlResult):
        """Cache CrawlResult data"""
        # Store content files and collect their hashes
        content_map = {
            'html': (result.html, 'html'),
            'cleaned_html': (result.cleaned_html or "", 'cleaned'),
            'markdown': None,
            'extracted_content': (result.extracted_content or "", 'extracted'),
            'screenshot': (result.screenshot or "", 'screenshots'),
        }

        try:
            if isinstance(result.markdown, MarkdownGenerationResult):
                content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown')
            elif hasattr(result, 'markdown_v2'):
                content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown')
            elif isinstance(result.markdown, str):
                markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
                content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown')
            else:
                content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
        except Exception as e:
            self.logger.warning(
                message=f"Error processing markdown content: {str(e)}",
                tag="WARNING",
            )
            # Fall back to an empty markdown result
            content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')

        content_hashes = {}
        for field, (content, content_type) in content_map.items():
            content_hashes[field] = await self._store_content(content, content_type)

        async def _cache(db):
            await db.execute('''
                INSERT INTO crawled_data (
                    url, html, cleaned_html, markdown,
                    extracted_content, success, media, links, metadata,
                    screenshot, response_headers, downloaded_files
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(url) DO UPDATE SET
                    html = excluded.html,
                    cleaned_html = excluded.cleaned_html,
                    markdown = excluded.markdown,
                    extracted_content = excluded.extracted_content,
                    success = excluded.success,
                    media = excluded.media,
                    links = excluded.links,
                    metadata = excluded.metadata,
                    screenshot = excluded.screenshot,
                    response_headers = excluded.response_headers,
                    downloaded_files = excluded.downloaded_files
            ''', (
                result.url,
                content_hashes['html'],
                content_hashes['cleaned_html'],
                content_hashes['markdown'],
                content_hashes['extracted_content'],
                result.success,
                json.dumps(result.media),
                json.dumps(result.links),
                json.dumps(result.metadata or {}),
                content_hashes['screenshot'],
                json.dumps(result.response_headers or {}),
                json.dumps(result.downloaded_files or []),
            ))

        try:
            await self.execute_with_retry(_cache)
        except Exception as e:
            self.logger.error(
                message="Error caching URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)},
            )

    async def aget_total_count(self) -> int:
        """Get the total number of cached URLs"""
        async def _count(db):
            async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
                result = await cursor.fetchone()
                return result[0] if result else 0

        try:
            return await self.execute_with_retry(_count)
        except Exception as e:
            self.logger.error(
                message="Error getting total count: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)},
            )
            return 0

    async def aclear_db(self):
        """Clear all data from the database"""
        async def _clear(db):
            await db.execute('DELETE FROM crawled_data')

        try:
            await self.execute_with_retry(_clear)
        except Exception as e:
            self.logger.error(
                message="Error clearing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)},
            )

    async def aflush_db(self):
        """Drop the entire table"""
        async def _flush(db):
            await db.execute('DROP TABLE IF EXISTS crawled_data')

        try:
            await self.execute_with_retry(_flush)
        except Exception as e:
            self.logger.error(
                message="Error flushing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)},
            )

    async def _store_content(self, content: str, content_type: str) -> str:
        """Store content in the filesystem and return its hash"""
        if not content:
            return ""

        content_hash = generate_content_hash(content)
        file_path = os.path.join(self.content_paths[content_type], content_hash)

        # Only write if the file doesn't exist yet (content-addressed storage)
        if not os.path.exists(file_path):
            async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
                await f.write(content)

        return content_hash

    async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
        """Load content from the filesystem by hash"""
        if not content_hash:
            return None

        file_path = os.path.join(self.content_paths[content_type], content_hash)
        try:
            async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except Exception:
            self.logger.error(
                message="Failed to load content: {file_path}",
                tag="ERROR",
                force_verbose=True,
                params={"file_path": file_path},
            )
            return None


# Create a singleton instance
async_db_manager = AsyncDatabaseManager()
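
# Example usage (illustrative sketch, not part of the library API): a minimal
# round trip against the cache. The URL below is hypothetical; aget_cached_url
# simply returns None for pages that have not been cached yet.
async def _demo() -> None:
    await async_db_manager.initialize()
    cached = await async_db_manager.aget_cached_url("https://example.com")
    print("cache hit" if cached else "cache miss")
    print("total cached URLs:", await async_db_manager.aget_total_count())
    await async_db_manager.cleanup()


if __name__ == "__main__":
    asyncio.run(_demo())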