315 lines
8.5 KiB
Python
315 lines
8.5 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
Egregore Relay Service - API Gateway

Relays authenticated API requests to the backend services (reason, recall).

When run directly, listens on 127.0.0.1, port 8085 by default (the port can be
overridden with the GATEWAY_PORT environment variable).
"""
|
||
|
|
|
||
|
|
import asyncio
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import FastAPI, HTTPException, Depends, Header, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from dotenv import load_dotenv
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import httpx
import jwt
import bcrypt
|
||
|
|
|
||
|
|
# Load environment. EGREGORE_ENV_FILE (read from the real process environment)
# overrides the location of the .env file; the default is /home/admin/.env.
load_dotenv(os.getenv("EGREGORE_ENV_FILE", "/home/admin/.env"))

# HMAC key used to sign and verify JWTs. `or` (instead of a getenv default)
# also covers JWT_SECRET being present but empty, which would otherwise sign
# tokens with an empty key that anyone can reproduce.
# NOTE: the random fallback is regenerated on every start (and independently in
# every worker process), so tokens stop validating; set JWT_SECRET explicitly.
JWT_SECRET = os.getenv("JWT_SECRET") or secrets.token_hex(32)
# Token lifetime in seconds.
JWT_EXPIRY = int(os.getenv("JWT_EXPIRY", "3600"))
# Port used when this module is run directly.
GATEWAY_PORT = int(os.getenv("GATEWAY_PORT", "8085"))

# Backend service URLs
REASON_URL = os.getenv("REASON_URL", "http://127.0.0.1:8081")
RECALL_URL = os.getenv("RECALL_URL", "http://127.0.0.1:8082")
# NOTE(review): CONVERSE_URL is not referenced anywhere else in the visible code.
CONVERSE_URL = os.getenv("CONVERSE_URL", "http://127.0.0.1:8080")
|
||
|
|
|
||
|
|
# Rate limiting (slowapi): buckets are keyed on the caller's remote address.
# NOTE(review): behind a reverse proxy every request may share the proxy's
# address, and thus one bucket — confirm the deployment topology.
limiter = Limiter(key_func=get_remote_address)

app = FastAPI(
    docs_url="/docs",
    title="Egregore Relay Service",
    version="1.0.0",
)

# Attach the limiter to the application and map RateLimitExceeded to slowapi's
# stock error response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||
|
|
|
||
|
|
# Security
# auto_error=False: a missing Authorization header reaches get_current_client
# as None, and that dependency raises its own 401 payload.
bearer_scheme = HTTPBearer(auto_error=False)

# Shared async HTTP client for all backend calls (timeout=120.0 s unless a call
# overrides it, as the /health probes do).
# NOTE(review): this client is never closed; consider closing it on shutdown.
http_client = httpx.AsyncClient(timeout=120.0)

# In-memory API key store (replace with DB in production); its contents are
# lost whenever the process restarts.
# Format: bcrypt hash of the API key -> {"client_id", "scopes", "rate_limit"}
API_CLIENTS: dict = {}
|
||
|
|
|
||
|
|
|
||
|
|
# Request/Response models
# (Documented with comments rather than docstrings: a model docstring would be
# copied into the generated OpenAPI schema.)
class TokenRequest(BaseModel):
    # Body of POST /auth/token: the plaintext API key to exchange for a JWT.
    api_key: str
|
||
|
|
|
||
|
|
|
||
|
|
class TokenResponse(BaseModel):
    # Response of POST /auth/token.
    # The signed JWT to send back as "Authorization: Bearer <token>".
    token: str
    # Token lifetime in seconds (the handler passes JWT_EXPIRY).
    expires_in: int
    token_type: str = "Bearer"
|
||
|
|
|
||
|
|
|
||
|
|
class ChatRequest(BaseModel):
    # Body of POST /v1/chat.
    # The user's message; appended to the recalled history as a "user" turn.
    message: str
    # Model identifier forwarded verbatim to the reason service.
    # NOTE(review): client-chosen and not validated against an allow-list.
    model: str = "claude-sonnet-4-20250514"
    # Forwarded verbatim to the reason service's /process call.
    # NOTE(review): no bounds are enforced here; consider Field(ge=1, le=...).
    max_iterations: int = 10
|
||
|
|
|
||
|
|
|
||
|
|
class ErrorResponse(BaseModel):
    # Shape of the error payloads this service produces.
    # NOTE(review): not referenced in the visible code; the handlers build an
    # equivalent dict inline as HTTPException.detail instead.
    # Machine-readable code, e.g. "invalid_token" or "backend_error".
    error: str
    # Human-readable explanation.
    message: str
    details: Optional[dict] = None
|
||
|
|
|
||
|
|
|
||
|
|
# Helper functions
def verify_api_key(api_key: str) -> Optional[dict]:
    """Verify an API key and return the matching client record.

    Stored keys are bcrypt hashes, so the candidate is compared against every
    registered client in turn (one bcrypt check per client).

    Args:
        api_key: Plaintext API key presented by the caller.

    Returns:
        The client dict stored in API_CLIENTS, or None if no stored hash matches.
    """
    if not api_key:
        return None
    candidate = api_key.encode()
    # bcrypt only considers the first 72 bytes of input, and newer bcrypt
    # releases may raise ValueError for longer input instead of truncating.
    # Keys issued by register_client are 43 bytes ("eg_" + 40 hex chars), so
    # anything longer cannot be valid: reject it rather than crash (HTTP 500)
    # or match on a truncated prefix.
    if len(candidate) > 72:
        return None
    # Iterate over a snapshot so a registration happening concurrently cannot
    # raise "dictionary changed size during iteration".
    for key_hash, client in list(API_CLIENTS.items()):
        if bcrypt.checkpw(candidate, key_hash.encode()):
            return client
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def create_jwt(client_id: str, scopes: list) -> str:
    """Mint an HS256-signed access token for *client_id*.

    Claims: subject, issuer "egregore", issue time, expiry JWT_EXPIRY seconds
    after issue, and the granted scopes under "scope".
    """
    issued_at = datetime.now(timezone.utc)
    expires_at = issued_at + timedelta(seconds=JWT_EXPIRY)
    claims = dict(
        sub=client_id,
        iss="egregore",
        iat=issued_at,
        exp=expires_at,
        scope=scopes,
    )
    return jwt.encode(claims, JWT_SECRET, algorithm="HS256")
|
||
|
|
|
||
|
|
|
||
|
|
def verify_jwt(token: str) -> Optional[dict]:
    """Verify a JWT minted by create_jwt and return its claims.

    Checks the HS256 signature and expiry. Also checks the issuer claim
    ("egregore"). The exp, iat and sub claims must be present; create_jwt
    always sets them. The issuer check matters because JWT_SECRET is loaded
    from an .env file that may be shared with other services. Without it, a
    token signed with the same secret by another service would be accepted
    here.

    Returns:
        The decoded claims, or None if the token is invalid for any reason
        (bad signature, expired, wrong issuer, missing claim, malformed).
    """
    try:
        return jwt.decode(
            token,
            JWT_SECRET,
            algorithms=["HS256"],
            issuer="egregore",
            options={"require": ["exp", "iat", "sub"]},
        )
    except jwt.InvalidTokenError:
        # Base class of ExpiredSignatureError, InvalidIssuerError,
        # MissingRequiredClaimError, DecodeError, ...
        return None
|
||
|
|
|
||
|
|
|
||
|
|
async def get_current_client(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
) -> dict:
    """Dependency to get current authenticated client from JWT.

    Returns:
        The verified token claims as produced by verify_jwt.

    Raises:
        HTTPException: 401 when the Authorization header is missing or the
            token is invalid or expired. Each 401 carries a
            ``WWW-Authenticate: Bearer`` challenge.
    """
    if not credentials:
        # RFC 6750 section 3: a 401 must include a WWW-Authenticate challenge.
        # No error code is given when the request simply had no credentials.
        raise HTTPException(
            status_code=401,
            detail={"error": "invalid_token", "message": "Missing authorization header"},
            headers={"WWW-Authenticate": "Bearer"},
        )

    payload = verify_jwt(credentials.credentials)
    if not payload:
        raise HTTPException(
            status_code=401,
            detail={"error": "invalid_token", "message": "Invalid or expired token"},
            headers={"WWW-Authenticate": 'Bearer error="invalid_token"'},
        )

    return payload
|
||
|
|
|
||
|
|
|
||
|
|
def require_scope(required: str):
    """Dependency factory to check for required scope.

    The returned dependency admits a client whose token's "scope" claim
    contains *required* or the "*" wildcard; otherwise it raises 403.

    The claim is normally the list written by create_jwt. A string claim (the
    space-delimited OAuth 2.0 form) is split into individual scopes, so
    membership is tested per scope and never as a substring. Any other type
    grants nothing instead of raising a TypeError.
    """
    async def check_scope(client: dict = Depends(get_current_client)):
        granted = client.get("scope") or []
        if isinstance(granted, str):
            granted = granted.split()
        elif not isinstance(granted, (list, tuple)):
            granted = []

        if "*" in granted or required in granted:
            return client
        raise HTTPException(
            status_code=403,
            detail={"error": "insufficient_scope", "message": f"Requires '{required}' scope"}
        )
    return check_scope
|
||
|
|
|
||
|
|
|
||
|
|
# Public endpoints
async def _backend_ok(base_url: str) -> bool:
    """Probe ``{base_url}/health``; True only for an HTTP 200 within 2 seconds."""
    try:
        resp = await http_client.get(f"{base_url}/health", timeout=2.0)
    except Exception:
        # Best-effort probe: any failure (connection refused, timeout, invalid
        # URL, ...) just marks the backend unavailable. Catching Exception
        # instead of using a bare except lets asyncio.CancelledError and
        # KeyboardInterrupt propagate.
        return False
    return resp.status_code == 200


@app.get("/health")
async def health():
    """Health check with backend status.

    Probes the reason and recall services concurrently. The overall "status"
    is "ok" only when both answer 200, otherwise "degraded". Per-backend
    results are reported under "backends".
    """
    reason_ok, recall_ok = await asyncio.gather(
        _backend_ok(REASON_URL),
        _backend_ok(RECALL_URL),
    )

    return {
        "status": "ok" if (reason_ok and recall_ok) else "degraded",
        "service": "relay",
        "backends": {
            "reason": "ok" if reason_ok else "unavailable",
            "recall": "ok" if recall_ok else "unavailable"
        }
    }
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/auth/token", response_model=TokenResponse)
async def get_token(req: TokenRequest):
    """Exchange an API key for a JWT valid for JWT_EXPIRY seconds.

    Raises:
        HTTPException: 401 when the key matches no registered client.
    """
    # NOTE(review): verify_api_key runs bcrypt synchronously on the event loop,
    # once per registered client, and this route has no rate limit. Consider a
    # thread pool (with verify_api_key iterating a snapshot of API_CLIENTS)
    # plus a limiter.
    matched = verify_api_key(req.api_key)
    if not matched:
        raise HTTPException(
            status_code=401,
            detail={"error": "invalid_api_key", "message": "API key not found or revoked"}
        )

    return TokenResponse(
        token=create_jwt(matched["client_id"], matched["scopes"]),
        expires_in=JWT_EXPIRY,
    )
|
||
|
|
|
||
|
|
|
||
|
|
# Protected endpoints
@app.post("/v1/chat")
@limiter.limit("10/minute")
async def chat(
    request: Request,
    req: ChatRequest,
    client: dict = Depends(require_scope("chat"))
):
    """Send a message and get AI response.

    Fetches the conversation history from the recall service and appends the
    caller's message. Posts the result to the reason service's /process
    endpoint and returns its JSON reply unchanged.

    Args:
        request: Unused in the body; slowapi's ``@limiter.limit`` requires the
            endpoint to accept the Request.
        req: The message plus model/iteration settings to forward.
        client: Verified token claims (the "chat" scope is required).

    Raises:
        HTTPException: 502 when a backend answers with a non-200 status or a
            body that is not the expected JSON; 503 when a backend cannot be
            reached.
    """
    try:
        # Get conversation history
        history_resp = await http_client.get(f"{RECALL_URL}/messages/history")
        try:
            history_body = history_resp.json()
        except ValueError:  # body is not JSON, e.g. a proxy's HTML error page
            history_body = None
        history = (
            history_body.get("history", [])
            if isinstance(history_body, dict)
            else None
        )
        # Fail explicitly instead of answering without context or crashing
        # with an unhandled 500 when recall's reply is unusable.
        if history_resp.status_code != 200 or not isinstance(history, list):
            raise HTTPException(
                status_code=502,
                detail={"error": "backend_error", "message": "Recall service error"}
            )

        # Add user message to history
        history.append({"role": "user", "content": req.message})

        # Process with reason service
        reason_resp = await http_client.post(
            f"{REASON_URL}/process",
            json={
                "model": req.model,
                "history": history,
                "max_iterations": req.max_iterations
            }
        )

        if reason_resp.status_code != 200:
            raise HTTPException(
                status_code=502,
                detail={"error": "backend_error", "message": "Reason service error"}
            )

        try:
            return reason_resp.json()
        except ValueError:
            raise HTTPException(
                status_code=502,
                detail={"error": "backend_error", "message": "Reason service error"}
            ) from None

    except httpx.RequestError as e:
        raise HTTPException(
            status_code=503,
            detail={"error": "backend_unavailable", "message": str(e)}
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/v1/history")
async def get_history(
    limit: int = 50,
    before: Optional[int] = None,
    client: dict = Depends(require_scope("history"))
):
    """Get message history with pagination.

    Proxies GET {RECALL_URL}/messages and returns its JSON body.

    Args:
        limit: Page size; clamped to the range 1..100 before forwarding.
        before: Optional pagination cursor forwarded as ``before`` (presumably
            a message id or timestamp; confirm against the recall service).
        client: Verified token claims (the "history" scope is required).

    Raises:
        HTTPException: 502 when recall answers with a non-200 status or a
            non-JSON body; 503 when recall cannot be reached.
    """
    # Clamp at both ends. A negative limit would otherwise slip past the cap;
    # for example, SQLite treats LIMIT -1 as "no limit".
    params = {"limit": max(1, min(limit, 100))}
    # Compare against None so before=0 is forwarded instead of being dropped
    # as falsy.
    if before is not None:
        params["before"] = before

    try:
        resp = await http_client.get(f"{RECALL_URL}/messages", params=params)
    except httpx.RequestError as e:
        raise HTTPException(
            status_code=503,
            detail={"error": "backend_unavailable", "message": str(e)}
        ) from e

    if resp.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={"error": "backend_error", "message": "Recall service error"}
        )
    try:
        return resp.json()
    except ValueError:
        raise HTTPException(
            status_code=502,
            detail={"error": "backend_error", "message": "Recall service error"}
        ) from None
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/v1/history/search")
async def search_history(
    q: str,
    limit: int = 20,
    client: dict = Depends(require_scope("history"))
):
    """Search message history.

    Proxies GET {RECALL_URL}/messages/search and returns its JSON body.

    Args:
        q: Search query, forwarded unchanged.
        limit: Maximum number of results. Clamped to 1..100, the same cap
            /v1/history applies, so callers cannot request unbounded result sets.
        client: Verified token claims (the "history" scope is required).

    Raises:
        HTTPException: 502 when recall answers with a non-200 status or a
            non-JSON body; 503 when recall cannot be reached.
    """
    try:
        resp = await http_client.get(
            f"{RECALL_URL}/messages/search",
            params={"q": q, "limit": max(1, min(limit, 100))}
        )
    except httpx.RequestError as e:
        raise HTTPException(
            status_code=503,
            detail={"error": "backend_unavailable", "message": str(e)}
        ) from e

    if resp.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={"error": "backend_error", "message": "Recall service error"}
        )
    try:
        return resp.json()
    except ValueError:
        raise HTTPException(
            status_code=502,
            detail={"error": "backend_error", "message": "Recall service error"}
        ) from None
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/v1/tools")
async def get_tools(client: dict = Depends(require_scope("tools"))):
    """Get available AI tools.

    Proxies GET {REASON_URL}/tools and returns its JSON body.

    Raises:
        HTTPException: 502 when reason answers with a non-200 status or a
            non-JSON body; 503 when reason cannot be reached.
    """
    try:
        resp = await http_client.get(f"{REASON_URL}/tools")
    except httpx.RequestError as e:
        raise HTTPException(
            status_code=503,
            detail={"error": "backend_unavailable", "message": str(e)}
        ) from e

    if resp.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={"error": "backend_error", "message": "Reason service error"}
        )
    try:
        return resp.json()
    except ValueError:
        raise HTTPException(
            status_code=502,
            detail={"error": "backend_error", "message": "Reason service error"}
        ) from None
|
||
|
|
|
||
|
|
|
||
|
|
# Admin endpoint to register API keys (temporary, move to proper admin)
@app.post("/admin/clients")
async def register_client(
    client_id: str,
    scopes: list = ["chat", "history"],
    x_admin_token: Optional[str] = Header(None),
):
    """Register a new API client (temporary admin endpoint).

    Requires an ``X-Admin-Token`` header equal to the ``ADMIN_TOKEN``
    environment variable. If ADMIN_TOKEN is not configured, every request is
    refused. Left open, anyone able to reach the gateway could mint an API key
    with arbitrary scopes, including the "*" wildcard.

    Args:
        client_id: Identifier stored with the key; it becomes the JWT subject.
        scopes: Scopes granted to the new client.
        x_admin_token: Value of the X-Admin-Token request header.

    Returns:
        The client id, the plaintext API key, and its scopes. This is the only
        time the key is revealed; API_CLIENTS keeps only its bcrypt hash.

    Raises:
        HTTPException: 403 when the admin token is missing, wrong, or not
            configured on the server.
    """
    expected = os.getenv("ADMIN_TOKEN")
    if not (
        expected
        and x_admin_token
        # Constant-time comparison; compared as bytes so non-ASCII input
        # cannot raise TypeError.
        and secrets.compare_digest(x_admin_token.encode(), expected.encode())
    ):
        raise HTTPException(
            status_code=403,
            detail={"error": "forbidden", "message": "Valid X-Admin-Token header required"}
        )

    # Copy so the stored record never aliases the shared default list.
    scopes = list(scopes)

    # Generate API key
    api_key = f"eg_{secrets.token_hex(20)}"
    key_hash = bcrypt.hashpw(api_key.encode(), bcrypt.gensalt()).decode()

    # NOTE(review): "rate_limit" is stored but not read anywhere in the visible code.
    API_CLIENTS[key_hash] = {
        "client_id": client_id,
        "scopes": scopes,
        "rate_limit": 100
    }

    return {
        "client_id": client_id,
        "api_key": api_key,  # Only shown once!
        "scopes": scopes
    }
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # uvicorn is only imported when this module is executed as a script.
    import uvicorn

    # Serve on the loopback interface only; the port comes from GATEWAY_PORT.
    uvicorn.run(app, port=GATEWAY_PORT, host="127.0.0.1")
|