From 78ee93dbc69fef98fcd145fab6d978b8491db478 Mon Sep 17 00:00:00 2001 From: egregore Date: Mon, 2 Feb 2026 19:54:40 +0000 Subject: [PATCH] Add PostgreSQL-backed API key storage - api_keys.py: Database operations for API clients - bcrypt hashing for API keys - CRUD operations with full PostgreSQL support - Indexes for efficient client_id lookup - Soft delete (disable) and hard delete - Key regeneration support - main.py: Wire up database storage - Startup/shutdown handlers for DB pool - Full admin CRUD endpoints - Token exchange uses DB lookup Co-Authored-By: Claude Opus 4.5 --- api_keys.py | 292 +++++++++++++++++++++++++++++++++++++++++++++++ main.py | 125 +++++++++++++++----- requirements.txt | 1 + 3 files changed, 388 insertions(+), 30 deletions(-) create mode 100644 api_keys.py diff --git a/api_keys.py b/api_keys.py new file mode 100644 index 0000000..38bd155 --- /dev/null +++ b/api_keys.py @@ -0,0 +1,292 @@ +""" +Egregore Relay Service - API Key storage and management (PostgreSQL) +""" + +import os +import secrets +from datetime import datetime, timezone +from typing import Optional + +import asyncpg +import bcrypt + +# Database connection URL +DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://egregore:egregore_db_pass@localhost/egregore") + +# Connection pool +_pool: Optional[asyncpg.Pool] = None + + +async def get_pool() -> asyncpg.Pool: + """Get or create the connection pool""" + global _pool + if _pool is None: + _pool = await asyncpg.create_pool(DATABASE_URL, min_size=1, max_size=5) + return _pool + + +async def close_pool(): + """Close the connection pool""" + global _pool + if _pool: + await _pool.close() + _pool = None + + +async def init_db(): + """Initialize the API clients table""" + pool = await get_pool() + async with pool.acquire() as conn: + await conn.execute(""" + CREATE TABLE IF NOT EXISTS api_clients ( + id SERIAL PRIMARY KEY, + client_id VARCHAR(64) UNIQUE NOT NULL, + api_key_hash VARCHAR(128) NOT NULL, + name VARCHAR(255), + scopes TEXT[] DEFAULT ARRAY['chat', 'history'], + rate_limit INT DEFAULT 100, + enabled BOOLEAN DEFAULT TRUE, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ DEFAULT NOW(), + last_used_at TIMESTAMPTZ + ) + """) + + # Index for fast client_id lookups + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_api_clients_client_id + ON api_clients(client_id) + """) + + # Index for enabled status filtering + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_api_clients_enabled + ON api_clients(enabled) + """) + + +def generate_api_key() -> str: + """Generate a new API key with 'eg_' prefix""" + return f"eg_{secrets.token_hex(20)}" + + +def hash_api_key(api_key: str) -> str: + """Hash an API key using bcrypt""" + return bcrypt.hashpw(api_key.encode(), bcrypt.gensalt()).decode() + + +def verify_api_key_hash(api_key: str, key_hash: str) -> bool: + """Verify an API key against its hash""" + return bcrypt.checkpw(api_key.encode(), key_hash.encode()) + + +async def create_client( + name: str, + scopes: list[str] = None, + rate_limit: int = 100, + expires_at: datetime = None +) -> tuple[str, str]: + """ + Create a new API client. + Returns (client_id, api_key) - api_key is only returned once! + """ + pool = await get_pool() + + client_id = f"client_{secrets.token_hex(8)}" + api_key = generate_api_key() + key_hash = hash_api_key(api_key) + + if scopes is None: + scopes = ["chat", "history"] + + async with pool.acquire() as conn: + await conn.execute( + """INSERT INTO api_clients (client_id, api_key_hash, name, scopes, rate_limit, expires_at) + VALUES ($1, $2, $3, $4, $5, $6)""", + client_id, key_hash, name, scopes, rate_limit, expires_at + ) + + return client_id, api_key + + +async def get_client_by_api_key(api_key: str) -> Optional[dict]: + """ + Look up a client by API key. + Returns client info if key is valid and enabled, None otherwise. + Updates last_used_at on successful lookup. + """ + if not api_key or not api_key.startswith("eg_"): + return None + + pool = await get_pool() + async with pool.acquire() as conn: + # Get all enabled clients and check each hash + # (We can't query by hash directly since bcrypt salts are unique) + rows = await conn.fetch( + """SELECT id, client_id, api_key_hash, name, scopes, rate_limit, expires_at + FROM api_clients + WHERE enabled = TRUE""" + ) + + for row in rows: + if verify_api_key_hash(api_key, row['api_key_hash']): + # Check expiry + if row['expires_at'] and row['expires_at'] < datetime.now(timezone.utc): + return None + + # Update last_used_at + await conn.execute( + "UPDATE api_clients SET last_used_at = NOW() WHERE id = $1", + row['id'] + ) + + return { + "id": row['id'], + "client_id": row['client_id'], + "name": row['name'], + "scopes": list(row['scopes']) if row['scopes'] else [], + "rate_limit": row['rate_limit'] + } + + return None + + +async def get_client_by_id(client_id: str) -> Optional[dict]: + """Get client info by client_id""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + """SELECT id, client_id, name, scopes, rate_limit, enabled, expires_at, created_at, last_used_at + FROM api_clients + WHERE client_id = $1""", + client_id + ) + + if not row: + return None + + return { + "id": row['id'], + "client_id": row['client_id'], + "name": row['name'], + "scopes": list(row['scopes']) if row['scopes'] else [], + "rate_limit": row['rate_limit'], + "enabled": row['enabled'], + "expires_at": row['expires_at'].isoformat() if row['expires_at'] else None, + "created_at": row['created_at'].isoformat() if row['created_at'] else None, + "last_used_at": row['last_used_at'].isoformat() if row['last_used_at'] else None + } + + +async def list_clients(include_disabled: bool = False) -> list[dict]: + """List all API clients""" + pool = await get_pool() + async with pool.acquire() as conn: + if include_disabled: + rows = await conn.fetch( + """SELECT id, client_id, name, scopes, rate_limit, enabled, expires_at, created_at, last_used_at + FROM api_clients + ORDER BY created_at DESC""" + ) + else: + rows = await conn.fetch( + """SELECT id, client_id, name, scopes, rate_limit, enabled, expires_at, created_at, last_used_at + FROM api_clients + WHERE enabled = TRUE + ORDER BY created_at DESC""" + ) + + return [{ + "id": row['id'], + "client_id": row['client_id'], + "name": row['name'], + "scopes": list(row['scopes']) if row['scopes'] else [], + "rate_limit": row['rate_limit'], + "enabled": row['enabled'], + "expires_at": row['expires_at'].isoformat() if row['expires_at'] else None, + "created_at": row['created_at'].isoformat() if row['created_at'] else None, + "last_used_at": row['last_used_at'].isoformat() if row['last_used_at'] else None + } for row in rows] + + +async def update_client( + client_id: str, + name: str = None, + scopes: list[str] = None, + rate_limit: int = None, + enabled: bool = None, + expires_at: datetime = None +) -> bool: + """Update client settings. Returns True if client was found and updated.""" + pool = await get_pool() + + updates = [] + params = [] + param_idx = 1 + + if name is not None: + updates.append(f"name = ${param_idx}") + params.append(name) + param_idx += 1 + if scopes is not None: + updates.append(f"scopes = ${param_idx}") + params.append(scopes) + param_idx += 1 + if rate_limit is not None: + updates.append(f"rate_limit = ${param_idx}") + params.append(rate_limit) + param_idx += 1 + if enabled is not None: + updates.append(f"enabled = ${param_idx}") + params.append(enabled) + param_idx += 1 + if expires_at is not None: + updates.append(f"expires_at = ${param_idx}") + params.append(expires_at) + param_idx += 1 + + if not updates: + return False + + params.append(client_id) + update_sql = ", ".join(updates) + + async with pool.acquire() as conn: + result = await conn.execute( + f"UPDATE api_clients SET {update_sql} WHERE client_id = ${param_idx}", + *params + ) + return result == "UPDATE 1" + + +async def disable_client(client_id: str) -> bool: + """Disable a client (soft delete). Returns True if found.""" + return await update_client(client_id, enabled=False) + + +async def regenerate_api_key(client_id: str) -> Optional[str]: + """Regenerate API key for a client. Returns new key or None if not found.""" + pool = await get_pool() + + new_api_key = generate_api_key() + key_hash = hash_api_key(new_api_key) + + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE api_clients SET api_key_hash = $1 WHERE client_id = $2", + key_hash, client_id + ) + if result == "UPDATE 1": + return new_api_key + return None + + +async def delete_client(client_id: str) -> bool: + """Permanently delete a client. Returns True if found.""" + pool = await get_pool() + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM api_clients WHERE client_id = $1", + client_id + ) + return result == "DELETE 1" diff --git a/main.py b/main.py index 063c8ad..292e46f 100644 --- a/main.py +++ b/main.py @@ -20,7 +20,18 @@ from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded import httpx import jwt -import bcrypt + +from api_keys import ( + init_db as init_api_keys_db, + close_pool as close_api_keys_pool, + get_client_by_api_key, + create_client, + get_client_by_id, + list_clients, + update_client, + disable_client, + regenerate_api_key, +) # Load environment load_dotenv("/home/admin/.env") @@ -52,9 +63,16 @@ bearer_scheme = HTTPBearer(auto_error=False) http_client = httpx.AsyncClient(timeout=120.0) -# In-memory API key store (replace with DB in production) -# Format: api_key_hash -> {client_id, scopes, rate_limit} -API_CLIENTS = {} +@app.on_event("startup") +async def startup(): + """Initialize database on startup""" + await init_api_keys_db() + + +@app.on_event("shutdown") +async def shutdown(): + """Clean up on shutdown""" + await close_api_keys_pool() # Request/Response models @@ -81,14 +99,6 @@ class ErrorResponse(BaseModel): # Helper functions -def verify_api_key(api_key: str) -> Optional[dict]: - """Verify API key and return client info""" - for key_hash, client in API_CLIENTS.items(): - if bcrypt.checkpw(api_key.encode(), key_hash.encode()): - return client - return None - - def create_jwt(client_id: str, scopes: list) -> str: """Create a JWT token for the client""" now = datetime.now(timezone.utc) @@ -178,7 +188,7 @@ async def health(): @app.post("/auth/token", response_model=TokenResponse) async def get_token(req: TokenRequest): """Exchange API key for JWT token""" - client = verify_api_key(req.api_key) + client = await get_client_by_api_key(req.api_key) if not client: raise HTTPException( status_code=401, @@ -285,30 +295,85 @@ async def get_tools(client: dict = Depends(require_scope("tools"))): ) -# Admin endpoint to register API keys (temporary, move to proper admin) +# Admin endpoints for API key management +class CreateClientRequest(BaseModel): + name: str + scopes: list[str] = ["chat", "history"] + rate_limit: int = 100 + + +class UpdateClientRequest(BaseModel): + name: Optional[str] = None + scopes: Optional[list[str]] = None + rate_limit: Optional[int] = None + enabled: Optional[bool] = None + + @app.post("/admin/clients") -async def register_client( - client_id: str, - scopes: list = ["chat", "history"], -): - """Register a new API client (temporary admin endpoint)""" - # Generate API key - api_key = f"eg_{secrets.token_hex(20)}" - key_hash = bcrypt.hashpw(api_key.encode(), bcrypt.gensalt()).decode() - - API_CLIENTS[key_hash] = { - "client_id": client_id, - "scopes": scopes, - "rate_limit": 100 - } - +async def admin_create_client(req: CreateClientRequest): + """Create a new API client""" + client_id, api_key = await create_client( + name=req.name, + scopes=req.scopes, + rate_limit=req.rate_limit + ) return { "client_id": client_id, "api_key": api_key, # Only shown once! - "scopes": scopes + "name": req.name, + "scopes": req.scopes } +@app.get("/admin/clients") +async def admin_list_clients(include_disabled: bool = False): + """List all API clients""" + clients = await list_clients(include_disabled=include_disabled) + return {"clients": clients} + + +@app.get("/admin/clients/{client_id}") +async def admin_get_client(client_id: str): + """Get a specific client by ID""" + client = await get_client_by_id(client_id) + if not client: + raise HTTPException(status_code=404, detail="Client not found") + return client + + +@app.patch("/admin/clients/{client_id}") +async def admin_update_client(client_id: str, req: UpdateClientRequest): + """Update a client's settings""" + success = await update_client( + client_id=client_id, + name=req.name, + scopes=req.scopes, + rate_limit=req.rate_limit, + enabled=req.enabled + ) + if not success: + raise HTTPException(status_code=404, detail="Client not found") + return {"status": "updated"} + + +@app.delete("/admin/clients/{client_id}") +async def admin_disable_client(client_id: str): + """Disable a client (soft delete)""" + success = await disable_client(client_id) + if not success: + raise HTTPException(status_code=404, detail="Client not found") + return {"status": "disabled"} + + +@app.post("/admin/clients/{client_id}/regenerate-key") +async def admin_regenerate_key(client_id: str): + """Regenerate API key for a client""" + new_key = await regenerate_api_key(client_id) + if not new_key: + raise HTTPException(status_code=404, detail="Client not found") + return {"api_key": new_key} + + if __name__ == "__main__": import uvicorn uvicorn.run(app, host="127.0.0.1", port=GATEWAY_PORT) diff --git a/requirements.txt b/requirements.txt index 5b6177f..38d3745 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ pyjwt>=2.8.0 bcrypt>=4.0.0 slowapi>=0.1.9 pydantic>=2.0.0 +asyncpg>=0.29.0