""" Egregore Database - Message storage operations (PostgreSQL) """ import json import os import uuid from datetime import datetime from typing import Optional import asyncpg # Database connection URL - can be overridden via environment DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://egregore:egregore_db_pass@localhost/egregore") # Connection pool _pool: Optional[asyncpg.Pool] = None def set_db_url(url: str): """Set the database URL (call before init_db)""" global DATABASE_URL DATABASE_URL = url async def get_pool() -> asyncpg.Pool: """Get or create the connection pool""" global _pool if _pool is None: _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) return _pool async def close_pool(): """Close the connection pool""" global _pool if _pool: await _pool.close() _pool = None # Message type priorities for notifications MESSAGE_PRIORITIES = { "text": 2, # Regular messages - notify "tool_use": 0, # Tool invocation - no notify "tool_result": 0, # Tool output - no notify "question": 3, # Questions to user - urgent notify "mode_change": 1, # State transitions - silent "thinking": 0, # Reasoning process - no notify "error": 2, # Error messages - notify } def get_priority_for_type(msg_type: str, content: str = "") -> int: """Get priority for a message type, with question detection""" base_priority = MESSAGE_PRIORITIES.get(msg_type, 0) if msg_type == "text" and content.strip().endswith("?"): return MESSAGE_PRIORITIES["question"] return base_priority async def init_db(): """Initialize PostgreSQL database with messages table""" pool = await get_pool() async with pool.acquire() as conn: await conn.execute(""" CREATE TABLE IF NOT EXISTS messages ( id SERIAL PRIMARY KEY, role TEXT NOT NULL, type TEXT NOT NULL DEFAULT 'text', content TEXT NOT NULL, group_id TEXT, metadata JSONB, priority INTEGER DEFAULT 0, timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW() ) """) # Create indexes for efficient querying await conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_role ON messages(role)") await conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_type ON messages(type)") await conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_group_id ON messages(group_id)") await conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_priority ON messages(priority)") await conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp)") # Full-text search index on content await conn.execute(""" CREATE INDEX IF NOT EXISTS idx_messages_content_search ON messages USING gin(to_tsvector('english', content)) """) async def save_message( role: str, content: str, msg_type: str = "text", group_id: Optional[str] = None, metadata: Optional[dict] = None, priority: Optional[int] = None ) -> int: """Save a single message row""" pool = await get_pool() async with pool.acquire() as conn: if priority is None: priority = get_priority_for_type(msg_type, content) row = await conn.fetchrow( """INSERT INTO messages (role, type, content, group_id, metadata, priority) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id""", role, msg_type, content, group_id, json.dumps(metadata) if metadata else None, priority ) return row['id'] async def save_response_blocks(blocks: list, group_id: str) -> list: """ Save each response block as a separate row. Returns list of saved message dicts with IDs for frontend. """ saved_messages = [] pool = await get_pool() async with pool.acquire() as conn: timestamp = datetime.utcnow() for block in blocks: block_type = block.get("type", "text") content = "" metadata = None priority = MESSAGE_PRIORITIES.get(block_type, 0) if block_type == "text": content = block.get("content", "") if content.strip().endswith("?"): priority = MESSAGE_PRIORITIES["question"] elif block_type == "tool_use": content = json.dumps(block.get("input", {})) metadata = {"tool_name": block.get("name"), "tool_id": block.get("id")} elif block_type == "tool_result": content = block.get("content", "") metadata = {"tool_name": block.get("tool_name"), "tool_use_id": block.get("tool_use_id")} else: content = block.get("content", "") row = await conn.fetchrow( """INSERT INTO messages (role, type, content, group_id, metadata, priority, timestamp) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id""", "assistant", block_type, content, group_id, json.dumps(metadata) if metadata else None, priority, timestamp ) msg_id = row['id'] saved_messages.append({ "id": msg_id, "role": "assistant", "type": block_type, "content": content, "group_id": group_id, "metadata": metadata, "priority": priority, "timestamp": timestamp.isoformat() }) return saved_messages async def get_messages( limit: int = 50, before_id: int = None, msg_type: str = None ) -> tuple[list[dict], bool]: """Get messages with pagination. Returns (messages, has_more)""" pool = await get_pool() async with pool.acquire() as conn: params = [] where_clauses = [] param_idx = 1 if before_id: where_clauses.append(f"id < ${param_idx}") params.append(before_id) param_idx += 1 if msg_type: where_clauses.append(f"type = ${param_idx}") params.append(msg_type) param_idx += 1 where_sql = " AND ".join(where_clauses) if where_clauses else "TRUE" params.append(limit + 1) rows = await conn.fetch( f"""SELECT id, role, type, content, group_id, metadata, priority, timestamp FROM messages WHERE {where_sql} ORDER BY id DESC LIMIT ${param_idx}""", *params ) has_more = len(rows) > limit rows = rows[:limit] messages = [] for row in rows: metadata = None if row['metadata']: try: metadata = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata'] except: pass messages.append({ "id": row['id'], "role": row['role'], "type": row['type'], "content": row['content'], "group_id": row['group_id'], "metadata": metadata, "priority": row['priority'], "timestamp": row['timestamp'].isoformat() if row['timestamp'] else None }) return list(reversed(messages)), has_more async def get_conversation_history(limit: int = 100) -> list[dict]: """ Reconstruct Claude API message format from individual rows. Groups assistant messages by group_id to build proper content arrays. """ messages, _ = await get_messages(limit) api_messages = [] current_group = None current_assistant_content = [] for msg in messages: if msg["role"] == "user": if current_assistant_content: api_messages.append({ "role": "assistant", "content": current_assistant_content }) current_assistant_content = [] current_group = None api_messages.append({ "role": "user", "content": msg["content"] }) elif msg["role"] == "assistant": if msg["group_id"] != current_group: if current_assistant_content: api_messages.append({ "role": "assistant", "content": current_assistant_content }) current_assistant_content = [] current_group = msg["group_id"] if msg["type"] == "text": current_assistant_content.append({ "type": "text", "text": msg["content"] }) elif msg["type"] == "tool_use": tool_input = {} try: tool_input = json.loads(msg["content"]) except: pass metadata = msg.get("metadata") or {} current_assistant_content.append({ "type": "tool_use", "id": metadata.get("tool_id", str(uuid.uuid4())), "name": metadata.get("tool_name", "unknown"), "input": tool_input }) elif msg["type"] == "tool_result": if current_assistant_content: api_messages.append({ "role": "assistant", "content": current_assistant_content }) current_assistant_content = [] metadata = msg.get("metadata") or {} api_messages.append({ "role": "user", "content": [{ "type": "tool_result", "tool_use_id": metadata.get("tool_use_id", ""), "content": msg["content"] }] }) if current_assistant_content: api_messages.append({ "role": "assistant", "content": current_assistant_content }) return api_messages async def search_messages( query: str, limit: int = 20, msg_type: str = None ) -> list[dict]: """Search messages using PostgreSQL full-text search""" if len(query) < 2: return [] pool = await get_pool() async with pool.acquire() as conn: params = [query] param_idx = 2 type_filter = "" if msg_type: type_filter = f"AND type = ${param_idx}" params.append(msg_type) param_idx += 1 params.append(min(limit, 50)) # Use PostgreSQL full-text search with fallback to ILIKE rows = await conn.fetch( f"""SELECT id, role, type, content, group_id, metadata, priority, timestamp, ts_headline('english', content, plainto_tsquery('english', $1), 'StartSel=**, StopSel=**, MaxWords=50, MinWords=20') as snippet FROM messages WHERE (to_tsvector('english', content) @@ plainto_tsquery('english', $1) OR content ILIKE '%' || $1 || '%') {type_filter} ORDER BY id DESC LIMIT ${param_idx}""", *params ) results = [] for row in rows: metadata = None if row['metadata']: try: metadata = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata'] except: pass # Use ts_headline snippet, fallback to manual snippet snippet = row['snippet'] if row['snippet'] else row['content'][:100] results.append({ "id": row['id'], "role": row['role'], "type": row['type'], "content": row['content'], "group_id": row['group_id'], "metadata": metadata, "priority": row['priority'], "timestamp": row['timestamp'].isoformat() if row['timestamp'] else None, "snippet": snippet }) return results # Legacy compatibility - for migration def set_db_path(path: str): """Legacy function for SQLite compatibility - ignored for PostgreSQL""" pass