relay/main.py

380 lines
10 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""
Egregore Relay Service - API Gateway
Relays authenticated API requests to backend services.
Listens on port 8085 by default (override with the GATEWAY_PORT environment variable).
"""
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from dotenv import load_dotenv
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import httpx
import jwt
from api_keys import (
init_db as init_api_keys_db,
close_pool as close_api_keys_pool,
get_client_by_api_key,
create_client,
get_client_by_id,
list_clients,
update_client,
disable_client,
regenerate_api_key,
)
# Load environment variables from a hard-coded dotenv file before any of the
# os.getenv() lookups below are evaluated.
load_dotenv("/home/admin/.env")
# NOTE(review): when JWT_SECRET is unset, a fresh random secret is generated at
# import time, so tokens stop verifying after a restart and are not shared
# between processes — confirm JWT_SECRET is always set in deployment.
JWT_SECRET = os.getenv("JWT_SECRET", secrets.token_hex(32))
# Token lifetime in seconds (feeds the "exp" claim and the "expires_in" field).
JWT_EXPIRY = int(os.getenv("JWT_EXPIRY", "3600"))
# Port used when this module is run directly as a script.
GATEWAY_PORT = int(os.getenv("GATEWAY_PORT", "8085"))
# Backend service URLs
REASON_URL = os.getenv("REASON_URL", "http://127.0.0.1:8081")
RECALL_URL = os.getenv("RECALL_URL", "http://127.0.0.1:8082")
# NOTE(review): CONVERSE_URL is not referenced in the visible part of this file.
CONVERSE_URL = os.getenv("CONVERSE_URL", "http://127.0.0.1:8080")
# Rate limiting, keyed by the caller's remote address
limiter = Limiter(key_func=get_remote_address)
app = FastAPI(
title="Egregore Relay Service",
version="1.0.0",
docs_url="/docs"
)
# Register the limiter and its handler for RateLimitExceeded on the app.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Security: auto_error=False lets get_current_client raise its own 401 payload
# when the Authorization header is missing.
bearer_scheme = HTTPBearer(auto_error=False)
# Shared HTTP client for all backend calls (120 s default timeout)
http_client = httpx.AsyncClient(timeout=120.0)
@app.on_event("startup")
async def startup():
    """Run the API-keys database initialisation when the application starts."""
    await init_api_keys_db()
@app.on_event("shutdown")
async def shutdown():
    """Release shared resources when the application stops.

    Closes the module-level HTTP client (its connection pool was previously
    never closed) and then the API-keys database pool.
    """
    try:
        await http_client.aclose()
    finally:
        # Always release the database pool, even if closing the client fails.
        await close_api_keys_pool()
# Request/Response models
class TokenRequest(BaseModel):
    # Request body for POST /auth/token.
    api_key: str  # API key to exchange for a JWT
class TokenResponse(BaseModel):
    # Response body for POST /auth/token.
    token: str  # signed JWT
    expires_in: int  # token lifetime in seconds (populated from JWT_EXPIRY)
    token_type: str = "Bearer"
class ChatRequest(BaseModel):
    # Request body for POST /v1/chat.
    message: str  # user message appended to the conversation history
    model: str = "claude-sonnet-4-20250514"  # forwarded to the reason service
    max_iterations: int = 10  # forwarded to the reason service unchanged
class ErrorResponse(BaseModel):
    # Error payload shape; the "error"/"message" keys match the detail dicts
    # raised by the handlers in this file.
    # NOTE(review): not referenced by any route in the visible part of the file.
    error: str  # machine-readable error code
    message: str  # human-readable description
    details: Optional[dict] = None  # optional extra context
# Helper functions
def create_jwt(client_id: str, scopes: list) -> str:
    """Return a signed HS256 JWT for ``client_id``.

    The token carries the client id as ``sub``, the fixed issuer
    ``"egregore"``, the issue time, an expiry ``JWT_EXPIRY`` seconds later,
    and the granted ``scopes`` under the ``scope`` claim.
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(seconds=JWT_EXPIRY)
    claims = dict(
        sub=client_id,
        iss="egregore",
        iat=issued_at,
        exp=expires_at,
        scope=scopes,
    )
    return jwt.encode(claims, JWT_SECRET, algorithm="HS256")
def verify_jwt(token: str) -> Optional[dict]:
    """Decode ``token`` with the service secret.

    Returns the decoded claims, or ``None`` when the token is expired or
    otherwise invalid.
    """
    try:
        return jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
    except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
        # Both failure modes are reported to the caller the same way.
        return None
async def get_current_client(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
) -> dict:
    """FastAPI dependency resolving the bearer token to its JWT claims.

    Raises:
        HTTPException: 401 when the Authorization header is missing or the
            token fails verification.
    """

    def unauthorized(message: str) -> HTTPException:
        # Both rejection paths share the same status and error code.
        return HTTPException(
            status_code=401,
            detail={"error": "invalid_token", "message": message},
        )

    if not credentials:
        raise unauthorized("Missing authorization header")

    claims = verify_jwt(credentials.credentials)
    if not claims:
        raise unauthorized("Invalid or expired token")
    return claims
def require_scope(required: str):
    """Build a dependency that admits only clients holding ``required``.

    The returned dependency yields the authenticated client's claims when
    its ``scope`` claim contains ``required`` or the wildcard ``"*"``, and
    raises a 403 ``HTTPException`` otherwise.
    """

    async def check_scope(client: dict = Depends(get_current_client)):
        granted = client.get("scope", [])
        if "*" not in granted and required not in granted:
            raise HTTPException(
                status_code=403,
                detail={"error": "insufficient_scope", "message": f"Requires '{required}' scope"}
            )
        return client

    return check_scope
# Public endpoints
async def _backend_is_healthy(base_url: str) -> bool:
    """Return True when ``{base_url}/health`` answers 200 within 2 seconds."""
    try:
        resp = await http_client.get(f"{base_url}/health", timeout=2.0)
    except Exception:
        # Best-effort probe: any failure means "unavailable". Catching
        # Exception (not a bare except) lets task cancellation and
        # KeyboardInterrupt propagate instead of being swallowed.
        return False
    return resp.status_code == 200


@app.get("/health")
async def health():
    """Report gateway health together with the reason and recall backends.

    Returns:
        A dict whose ``status`` is ``"ok"`` only when both backends pass
        their health probe and ``"degraded"`` otherwise, plus a per-backend
        ``"ok"``/``"unavailable"`` breakdown.
    """
    reason_ok = await _backend_is_healthy(REASON_URL)
    recall_ok = await _backend_is_healthy(RECALL_URL)
    return {
        "status": "ok" if (reason_ok and recall_ok) else "degraded",
        "service": "relay",
        "backends": {
            "reason": "ok" if reason_ok else "unavailable",
            "recall": "ok" if recall_ok else "unavailable"
        }
    }
@app.post("/auth/token", response_model=TokenResponse)
async def get_token(req: TokenRequest):
    """Issue a JWT in exchange for a valid API key.

    Raises:
        HTTPException: 401 when no client is found for the API key.
    """
    record = await get_client_by_api_key(req.api_key)
    if record:
        signed = create_jwt(record["client_id"], record["scopes"])
        return TokenResponse(token=signed, expires_in=JWT_EXPIRY)
    raise HTTPException(
        status_code=401,
        detail={"error": "invalid_api_key", "message": "API key not found or revoked"}
    )
# Protected endpoints
@app.post("/v1/chat")
@limiter.limit("10/minute")
async def chat(
    request: Request,
    req: ChatRequest,
    client: dict = Depends(require_scope("chat"))
):
    """Send a message and return the reason service's response.

    Fetches the conversation history from the recall service, appends the
    user's message, and forwards the result to the reason service.

    Args:
        request: Raw request; required by the rate limiter decorator.
        req: Chat payload (message, model, max_iterations).
        client: Claims of the authenticated client holding the "chat" scope.

    Raises:
        HTTPException: 502 when a backend answers with a non-200 status or
            an unusable body; 503 when a backend cannot be reached.
    """
    try:
        # Get conversation history
        history_resp = await http_client.get(f"{RECALL_URL}/messages/history")
        if history_resp.status_code != 200:
            # Do not treat an error body as conversation history.
            raise HTTPException(
                status_code=502,
                detail={"error": "backend_error", "message": "Recall service error"}
            )
        history_payload = history_resp.json()
        if not isinstance(history_payload, dict):
            raise HTTPException(
                status_code=502,
                detail={"error": "backend_error", "message": "Recall service error"}
            )
        # Tolerate a missing or null "history" field.
        history = list(history_payload.get("history") or [])
        # Add user message to history
        history.append({"role": "user", "content": req.message})
        # Process with reason service
        reason_resp = await http_client.post(
            f"{REASON_URL}/process",
            json={
                "model": req.model,
                "history": history,
                "max_iterations": req.max_iterations
            }
        )
        if reason_resp.status_code != 200:
            raise HTTPException(
                status_code=502,
                detail={"error": "backend_error", "message": "Reason service error"}
            )
        return reason_resp.json()
    except httpx.RequestError as e:
        raise HTTPException(
            status_code=503,
            detail={"error": "backend_unavailable", "message": str(e)}
        ) from e
    except (ValueError, TypeError) as e:
        # A body that is not valid JSON, or a "history" value that is not
        # iterable, is a backend fault rather than an internal server error.
        raise HTTPException(
            status_code=502,
            detail={"error": "backend_error", "message": "Invalid response from backend service"}
        ) from e
@app.get("/v1/history")
async def get_history(
    limit: int = 50,
    before: Optional[int] = None,
    client: dict = Depends(require_scope("history"))
):
    """Get message history with pagination.

    Args:
        limit: Maximum number of messages; clamped to the range 0..100 so a
            negative value cannot slip past the cap.
        before: Optional pagination cursor, forwarded whenever it is
            provided (including 0).
        client: Claims of the authenticated client holding the "history" scope.

    Raises:
        HTTPException: 503 when the recall service cannot be reached.
    """
    params = {"limit": max(0, min(limit, 100))}
    # Compare against None: a cursor of 0 is falsy but still a real value.
    if before is not None:
        params["before"] = before
    try:
        resp = await http_client.get(f"{RECALL_URL}/messages", params=params)
        return resp.json()
    except httpx.RequestError as e:
        raise HTTPException(
            status_code=503,
            detail={"error": "backend_unavailable", "message": str(e)}
        ) from e
@app.get("/v1/history/search")
async def search_history(
    q: str,
    limit: int = 20,
    client: dict = Depends(require_scope("history"))
):
    """Search message history.

    Args:
        q: Search query forwarded to the recall service.
        limit: Maximum number of results; clamped to the range 0..100, the
            same cap the history listing endpoint applies.
        client: Claims of the authenticated client holding the "history" scope.

    Raises:
        HTTPException: 503 when the recall service cannot be reached.
    """
    bounded_limit = max(0, min(limit, 100))
    try:
        resp = await http_client.get(
            f"{RECALL_URL}/messages/search",
            params={"q": q, "limit": bounded_limit}
        )
        return resp.json()
    except httpx.RequestError as e:
        raise HTTPException(
            status_code=503,
            detail={"error": "backend_unavailable", "message": str(e)}
        ) from e
@app.get("/v1/tools")
async def get_tools(client: dict = Depends(require_scope("tools"))):
    """Return the reason service's ``/tools`` response body.

    Raises:
        HTTPException: 503 when the reason service cannot be reached.
    """
    tools_url = f"{REASON_URL}/tools"
    try:
        backend_reply = await http_client.get(tools_url)
        return backend_reply.json()
    except httpx.RequestError as exc:
        unavailable = {"error": "backend_unavailable", "message": str(exc)}
        raise HTTPException(status_code=503, detail=unavailable)
# Admin endpoints for API key management
class CreateClientRequest(BaseModel):
    # Request body for POST /admin/clients.
    name: str  # client name passed to create_client
    scopes: list[str] = ["chat", "history"]  # scopes granted when none are given
    # NOTE(review): passed to create_client; units and enforcement are not
    # visible in this file — confirm against the api_keys module.
    rate_limit: int = 100
class UpdateClientRequest(BaseModel):
    # Request body for PATCH /admin/clients/{client_id}.
    # Every field is optional; omitted fields reach update_client as None —
    # presumably meaning "leave unchanged"; verify against api_keys.update_client.
    name: Optional[str] = None
    scopes: Optional[list[str]] = None
    rate_limit: Optional[int] = None
    enabled: Optional[bool] = None
@app.post("/admin/clients")
async def admin_create_client(req: CreateClientRequest):
    """Register a new API client and return its credentials.

    The response includes the API key; the original author notes this is
    the only time it is shown.
    """
    # NOTE(review): no authentication dependency is declared on this route —
    # confirm that access to /admin is restricted elsewhere.
    new_id, new_key = await create_client(
        name=req.name,
        scopes=req.scopes,
        rate_limit=req.rate_limit,
    )
    return dict(
        client_id=new_id,
        api_key=new_key,
        name=req.name,
        scopes=req.scopes,
    )
@app.get("/admin/clients")
async def admin_list_clients(include_disabled: bool = False):
    """Return the API clients, passing ``include_disabled`` to ``list_clients``."""
    # NOTE(review): no authentication dependency is declared on this route —
    # confirm that access to /admin is restricted elsewhere.
    return {"clients": await list_clients(include_disabled=include_disabled)}
@app.get("/admin/clients/{client_id}")
async def admin_get_client(client_id: str):
    """Look up one client by id.

    Raises:
        HTTPException: 404 when no client is found for ``client_id``.
    """
    # NOTE(review): no authentication dependency is declared on this route —
    # confirm that access to /admin is restricted elsewhere.
    record = await get_client_by_id(client_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Client not found")
@app.patch("/admin/clients/{client_id}")
async def admin_update_client(client_id: str, req: UpdateClientRequest):
    """Forward the request's fields to ``update_client`` for ``client_id``.

    Raises:
        HTTPException: 404 when ``update_client`` reports failure.
    """
    # NOTE(review): no authentication dependency is declared on this route —
    # confirm that access to /admin is restricted elsewhere.
    updated = await update_client(
        client_id=client_id,
        name=req.name,
        scopes=req.scopes,
        rate_limit=req.rate_limit,
        enabled=req.enabled,
    )
    if updated:
        return {"status": "updated"}
    raise HTTPException(status_code=404, detail="Client not found")
@app.delete("/admin/clients/{client_id}")
async def admin_disable_client(client_id: str):
    """Disable the client via ``disable_client`` (a soft delete).

    Raises:
        HTTPException: 404 when ``disable_client`` reports failure.
    """
    # NOTE(review): no authentication dependency is declared on this route —
    # confirm that access to /admin is restricted elsewhere.
    disabled = await disable_client(client_id)
    if disabled:
        return {"status": "disabled"}
    raise HTTPException(status_code=404, detail="Client not found")
@app.post("/admin/clients/{client_id}/regenerate-key")
async def admin_regenerate_key(client_id: str):
    """Issue a replacement API key for ``client_id`` and return it.

    Raises:
        HTTPException: 404 when ``regenerate_api_key`` returns nothing.
    """
    # NOTE(review): no authentication dependency is declared on this route —
    # confirm that access to /admin is restricted elsewhere.
    replacement = await regenerate_api_key(client_id)
    if replacement:
        return {"api_key": replacement}
    raise HTTPException(status_code=404, detail="Client not found")
if __name__ == "__main__":
    # Running the module directly serves the app on the loopback interface,
    # on the configured gateway port.
    from uvicorn import run as serve

    serve(app, port=GATEWAY_PORT, host="127.0.0.1")