"""
Egregore Relay Service - API Key storage and management (PostgreSQL)

API keys are stored only as bcrypt hashes in the ``api_clients`` table; the
plaintext key is returned exactly once, when it is created or regenerated.
bcrypt is intentionally slow, so every hash/verify performed from a coroutine
is pushed to a worker thread to keep the event loop responsive.
"""

import asyncio
import os
import secrets
from datetime import datetime, timezone
from typing import Optional

import asyncpg
import bcrypt

# Database connection URL
# NOTE(review): the fallback embeds a password in source; consider requiring
# DATABASE_URL to be set in production deployments.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://egregore:egregore_db_pass@localhost/egregore")

# Connection pool (created on first use by get_pool()).
_pool: Optional[asyncpg.Pool] = None
# Serializes first-time pool creation. Created lazily inside a coroutine so it
# is bound to the running event loop rather than whatever loop exists at import.
_pool_lock: Optional[asyncio.Lock] = None

# Columns returned by the admin-facing lookups (deliberately excludes the hash).
_CLIENT_COLUMNS = (
    "id, client_id, name, scopes, rate_limit, enabled, "
    "expires_at, created_at, last_used_at"
)

# bcrypt only considers the first 72 bytes of its input (newer bcrypt releases
# raise on longer input), and keys issued by generate_api_key() are far shorter.
_BCRYPT_MAX_BYTES = 72


async def get_pool() -> asyncpg.Pool:
    """Get or create the shared connection pool.

    ``asyncpg.create_pool`` awaits, so without the lock two concurrent first
    callers could both see ``_pool is None`` and create two pools, leaking one.
    """
    global _pool, _pool_lock
    if _pool is not None:
        return _pool
    if _pool_lock is None:
        _pool_lock = asyncio.Lock()
    async with _pool_lock:
        # Re-check: another coroutine may have created the pool while we waited.
        if _pool is None:
            _pool = await asyncpg.create_pool(DATABASE_URL, min_size=1, max_size=5)
    return _pool


async def close_pool():
    """Close the connection pool (no-op if it was never created)."""
    global _pool, _pool_lock
    # Detach first so concurrent get_pool() callers never receive a closing pool.
    pool, _pool = _pool, None
    # Drop the lock too, so a later event loop (e.g. a fresh asyncio.run) gets its own.
    _pool_lock = None
    if pool is not None:
        await pool.close()


async def init_db():
    """Initialize the API clients table and its indexes (idempotent)."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS api_clients (
                id SERIAL PRIMARY KEY,
                client_id VARCHAR(64) UNIQUE NOT NULL,
                api_key_hash VARCHAR(128) NOT NULL,
                name VARCHAR(255),
                scopes TEXT[] DEFAULT ARRAY['chat', 'history'],
                rate_limit INT DEFAULT 100,
                enabled BOOLEAN DEFAULT TRUE,
                expires_at TIMESTAMPTZ,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                last_used_at TIMESTAMPTZ
            )
        """)
        # Index for fast client_id lookups.
        # NOTE(review): PostgreSQL already builds a unique index for the UNIQUE
        # constraint above, so this index is likely redundant; kept so existing
        # deployments see no schema change.
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_api_clients_client_id
            ON api_clients(client_id)
        """)
        # Index for enabled status filtering
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_api_clients_enabled
            ON api_clients(enabled)
        """)


def generate_api_key() -> str:
    """Generate a new API key: 'eg_' followed by 40 hex chars (20 random bytes)."""
    return f"eg_{secrets.token_hex(20)}"


def hash_api_key(api_key: str) -> str:
    """Hash an API key using bcrypt with a fresh salt (blocking, CPU-bound)."""
    return bcrypt.hashpw(api_key.encode(), bcrypt.gensalt()).decode()


def verify_api_key_hash(api_key: str, key_hash: str) -> bool:
    """Verify an API key against its bcrypt hash (blocking, CPU-bound)."""
    return bcrypt.checkpw(api_key.encode(), key_hash.encode())


def _find_matching_row(api_key: str, rows: list) -> Optional[asyncpg.Record]:
    """Return the first row whose ``api_key_hash`` matches ``api_key``, or None."""
    return next(
        (row for row in rows if verify_api_key_hash(api_key, row["api_key_hash"])),
        None,
    )


def _row_to_client(row: asyncpg.Record) -> dict:
    """Serialize a row selected with ``_CLIENT_COLUMNS`` into a JSON-friendly dict."""

    def iso(value: Optional[datetime]) -> Optional[str]:
        return value.isoformat() if value is not None else None

    return {
        "id": row["id"],
        "client_id": row["client_id"],
        "name": row["name"],
        "scopes": list(row["scopes"]) if row["scopes"] else [],
        "rate_limit": row["rate_limit"],
        "enabled": row["enabled"],
        "expires_at": iso(row["expires_at"]),
        "created_at": iso(row["created_at"]),
        "last_used_at": iso(row["last_used_at"]),
    }


async def create_client(
    name: str,
    scopes: Optional[list[str]] = None,
    rate_limit: int = 100,
    expires_at: Optional[datetime] = None,
) -> tuple[str, str]:
    """
    Create a new API client.

    Args:
        name: Human-readable client name.
        scopes: Granted scopes; defaults to ["chat", "history"].
        rate_limit: Value stored in the rate_limit column.
        expires_at: Optional expiry timestamp; None means the key never expires.

    Returns (client_id, api_key) - api_key is only returned once!
    """
    pool = await get_pool()

    client_id = f"client_{secrets.token_hex(8)}"
    api_key = generate_api_key()
    # bcrypt is deliberately slow; hash off the event loop.
    key_hash = await asyncio.to_thread(hash_api_key, api_key)

    if scopes is None:
        scopes = ["chat", "history"]

    async with pool.acquire() as conn:
        await conn.execute(
            """INSERT INTO api_clients (client_id, api_key_hash, name, scopes, rate_limit, expires_at)
               VALUES ($1, $2, $3, $4, $5, $6)""",
            client_id, key_hash, name, scopes, rate_limit, expires_at,
        )

    return client_id, api_key


async def get_client_by_api_key(api_key: str) -> Optional[dict]:
    """
    Look up a client by API key.

    Returns client info (id, client_id, name, scopes, rate_limit) if the key
    belongs to an enabled, unexpired client, None otherwise.
    Updates last_used_at on successful lookup.
    """
    if not api_key or not api_key.startswith("eg_"):
        return None
    # Over-long input cannot be an issued key and may make bcrypt raise.
    if len(api_key.encode()) > _BCRYPT_MAX_BYTES:
        return None

    pool = await get_pool()

    # Expired and disabled clients are excluded in SQL so no bcrypt work is
    # spent on them. Release the connection before the slow hash checks.
    async with pool.acquire() as conn:
        rows = await conn.fetch(
            """SELECT id, client_id, api_key_hash, name, scopes, rate_limit
               FROM api_clients
               WHERE enabled = TRUE
                 AND (expires_at IS NULL OR expires_at >= NOW())"""
        )

    # We can't query by hash directly since bcrypt salts are unique, so each
    # candidate is checked; run that CPU-bound loop in a worker thread.
    row = await asyncio.to_thread(_find_matching_row, api_key, rows)
    if row is None:
        return None

    async with pool.acquire() as conn:
        await conn.execute(
            "UPDATE api_clients SET last_used_at = NOW() WHERE id = $1",
            row["id"],
        )

    return {
        "id": row["id"],
        "client_id": row["client_id"],
        "name": row["name"],
        "scopes": list(row["scopes"]) if row["scopes"] else [],
        "rate_limit": row["rate_limit"],
    }


async def get_client_by_id(client_id: str) -> Optional[dict]:
    """Get client info by client_id, or None if no such client exists."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            f"SELECT {_CLIENT_COLUMNS} FROM api_clients WHERE client_id = $1",
            client_id,
        )
    return _row_to_client(row) if row is not None else None


async def list_clients(include_disabled: bool = False) -> list[dict]:
    """List API clients, newest first; disabled clients only if requested."""
    pool = await get_pool()
    # Constant SQL fragment (not caller input), so interpolation is safe.
    where = "" if include_disabled else "WHERE enabled = TRUE "
    async with pool.acquire() as conn:
        rows = await conn.fetch(
            f"SELECT {_CLIENT_COLUMNS} FROM api_clients {where}ORDER BY created_at DESC"
        )
    return [_row_to_client(row) for row in rows]


async def update_client(
    client_id: str,
    name: Optional[str] = None,
    scopes: Optional[list[str]] = None,
    rate_limit: Optional[int] = None,
    enabled: Optional[bool] = None,
    expires_at: Optional[datetime] = None,
    clear_expires_at: bool = False,
) -> bool:
    """Update client settings.

    Only arguments that are not None are written. Because None means "leave
    unchanged", pass ``clear_expires_at=True`` to remove an expiry (set
    ``expires_at`` to NULL).

    Returns True if client was found and updated; False if it was not found or
    no fields were supplied.

    Raises:
        ValueError: if both ``expires_at`` and ``clear_expires_at`` are given.
    """
    requested = {
        "name": name,
        "scopes": scopes,
        "rate_limit": rate_limit,
        "enabled": enabled,
        "expires_at": expires_at,
    }
    # Column names come from the fixed keys above, never from caller input, so
    # interpolating them into the SQL text is safe; values stay parameterized.
    changes = {column: value for column, value in requested.items() if value is not None}
    if clear_expires_at:
        if expires_at is not None:
            raise ValueError("pass either expires_at or clear_expires_at, not both")
        changes["expires_at"] = None
    if not changes:
        return False

    assignments = ", ".join(
        f"{column} = ${idx}" for idx, column in enumerate(changes, start=1)
    )
    params = [*changes.values(), client_id]

    pool = await get_pool()
    async with pool.acquire() as conn:
        result = await conn.execute(
            f"UPDATE api_clients SET {assignments} WHERE client_id = ${len(params)}",
            *params,
        )
    return result == "UPDATE 1"


async def disable_client(client_id: str) -> bool:
    """Disable a client (soft delete). Returns True if found."""
    return await update_client(client_id, enabled=False)


async def regenerate_api_key(client_id: str) -> Optional[str]:
    """Regenerate API key for a client. Returns new key or None if not found.

    The previous key stops working as soon as the new hash is stored.
    """
    pool = await get_pool()
    new_api_key = generate_api_key()
    # bcrypt is deliberately slow; hash off the event loop.
    key_hash = await asyncio.to_thread(hash_api_key, new_api_key)

    async with pool.acquire() as conn:
        result = await conn.execute(
            "UPDATE api_clients SET api_key_hash = $1 WHERE client_id = $2",
            key_hash, client_id,
        )

    return new_api_key if result == "UPDATE 1" else None


async def delete_client(client_id: str) -> bool:
    """Permanently delete a client. Returns True if found."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        result = await conn.execute(
            "DELETE FROM api_clients WHERE client_id = $1",
            client_id,
        )
    return result == "DELETE 1"