Port relay service from Python to Rust + Rouille
- PostgreSQL with chrono datetime support - JWT authentication (jsonwebtoken crate) - bcrypt API key hashing - In-memory rate limiter (10 req/min for chat) - HTTP proxy to reason/recall services - ~3.8MB binary vs ~50MB Python Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
8ad87ddaaa
commit
fd3802d749
3 changed files with 2927 additions and 0 deletions
2040
Cargo.lock
generated
Normal file
2040
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
21
Cargo.toml
Normal file
21
Cargo.toml
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
[package]
|
||||||
|
name = "relay"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
rouille = "3.6"
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
postgres = { version = "0.19", features = ["with-chrono-0_4"] }
|
||||||
|
jsonwebtoken = "9.2"
|
||||||
|
bcrypt = "0.15"
|
||||||
|
ureq = { version = "2.9", features = ["json"] }
|
||||||
|
uuid = { version = "1.0", features = ["v4"] }
|
||||||
|
rand = "0.8"
|
||||||
|
hex = "0.4"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
strip = true
|
||||||
|
lto = true
|
||||||
866
src/main.rs
Normal file
866
src/main.rs
Normal file
|
|
@ -0,0 +1,866 @@
|
||||||
|
//! Egregore Relay Service - API Gateway
|
||||||
|
//!
|
||||||
|
//! Relays authenticated API requests to backend services.
|
||||||
|
//! Runs on port 8085.
|
||||||
|
|
||||||
|
use chrono::{Duration, Utc};
|
||||||
|
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||||
|
use postgres::{Client, NoTls};
|
||||||
|
use rand::Rng;
|
||||||
|
use rouille::{router, Request, Response};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::env;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::time::{Duration as StdDuration, Instant};
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Configuration
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn get_env(key: &str, default: &str) -> String {
|
||||||
|
env::var(key).unwrap_or_else(|_| default.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
static ref DATABASE_URL: String = get_env("DATABASE_URL", "postgresql://egregore:egregore_db_pass@localhost/egregore");
|
||||||
|
static ref JWT_SECRET: String = get_env("JWT_SECRET", "egregore-jwt-secret-change-in-production");
|
||||||
|
static ref JWT_EXPIRY: i64 = get_env("JWT_EXPIRY", "3600").parse().unwrap_or(3600);
|
||||||
|
static ref REASON_URL: String = get_env("REASON_URL", "http://127.0.0.1:8081");
|
||||||
|
static ref RECALL_URL: String = get_env("RECALL_URL", "http://127.0.0.1:8082");
|
||||||
|
}
|
||||||
|
|
||||||
|
const BIND_ADDR: &str = "127.0.0.1:8085";
|
||||||
|
const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514";
|
||||||
|
|
||||||
|
// Simple lazy_static replacement using std
|
||||||
|
mod lazy_static {
|
||||||
|
macro_rules! lazy_static {
|
||||||
|
($(static ref $name:ident: $ty:ty = $init:expr;)*) => {
|
||||||
|
$(
|
||||||
|
static $name: std::sync::LazyLock<$ty> = std::sync::LazyLock::new(|| $init);
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
pub(crate) use lazy_static;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Data Structures
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct Claims {
|
||||||
|
sub: String, // client_id
|
||||||
|
iss: String, // "egregore"
|
||||||
|
iat: i64, // issued at
|
||||||
|
exp: i64, // expiration
|
||||||
|
scope: Vec<String>, // scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct ApiClient {
|
||||||
|
id: i32,
|
||||||
|
client_id: String,
|
||||||
|
name: Option<String>,
|
||||||
|
scopes: Vec<String>,
|
||||||
|
rate_limit: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct TokenRequest {
|
||||||
|
api_key: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct TokenResponse {
|
||||||
|
token: String,
|
||||||
|
expires_in: i64,
|
||||||
|
token_type: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ChatRequest {
|
||||||
|
message: String,
|
||||||
|
#[serde(default = "default_model")]
|
||||||
|
model: String,
|
||||||
|
#[serde(default = "default_max_iterations")]
|
||||||
|
max_iterations: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_model() -> String {
|
||||||
|
DEFAULT_MODEL.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_iterations() -> i32 {
|
||||||
|
10
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct CreateClientRequest {
|
||||||
|
name: String,
|
||||||
|
#[serde(default = "default_scopes")]
|
||||||
|
scopes: Vec<String>,
|
||||||
|
#[serde(default = "default_rate_limit")]
|
||||||
|
rate_limit: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_scopes() -> Vec<String> {
|
||||||
|
vec!["chat".to_string(), "history".to_string()]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_rate_limit() -> i32 {
|
||||||
|
100
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Default)]
|
||||||
|
struct UpdateClientRequest {
|
||||||
|
name: Option<String>,
|
||||||
|
scopes: Option<Vec<String>>,
|
||||||
|
rate_limit: Option<i32>,
|
||||||
|
enabled: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Application State
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
struct AppState {
|
||||||
|
db: Mutex<Client>,
|
||||||
|
rate_limiter: RateLimiter,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Rate Limiter
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
struct RateLimiter {
|
||||||
|
requests: Mutex<HashMap<String, Vec<Instant>>>,
|
||||||
|
window: StdDuration,
|
||||||
|
max_requests: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
fn new(max_requests: usize, window_secs: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
requests: Mutex::new(HashMap::new()),
|
||||||
|
window: StdDuration::from_secs(window_secs),
|
||||||
|
max_requests,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check(&self, key: &str) -> bool {
|
||||||
|
let mut requests = self.requests.lock().unwrap();
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let entry = requests.entry(key.to_string()).or_default();
|
||||||
|
|
||||||
|
// Remove expired entries
|
||||||
|
entry.retain(|t| now.duration_since(*t) < self.window);
|
||||||
|
|
||||||
|
if entry.len() >= self.max_requests {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.push(now);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Database Operations
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn init_db(client: &mut Client) -> Result<(), postgres::Error> {
|
||||||
|
client.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS api_clients (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
client_id VARCHAR(64) UNIQUE NOT NULL,
|
||||||
|
api_key_hash VARCHAR(128) NOT NULL,
|
||||||
|
name VARCHAR(255),
|
||||||
|
scopes TEXT[] DEFAULT ARRAY['chat', 'history'],
|
||||||
|
rate_limit INT DEFAULT 100,
|
||||||
|
enabled BOOLEAN DEFAULT TRUE,
|
||||||
|
expires_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
last_used_at TIMESTAMPTZ
|
||||||
|
)",
|
||||||
|
&[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
client.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_api_clients_client_id ON api_clients(client_id)",
|
||||||
|
&[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
client.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_api_clients_enabled ON api_clients(enabled)",
|
||||||
|
&[],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// API Key Management
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn generate_api_key() -> String {
|
||||||
|
let bytes: [u8; 20] = rand::thread_rng().gen();
|
||||||
|
format!("eg_{}", hex::encode(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hash_api_key(api_key: &str) -> String {
|
||||||
|
bcrypt::hash(api_key, bcrypt::DEFAULT_COST).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_api_key(api_key: &str, hash: &str) -> bool {
|
||||||
|
bcrypt::verify(api_key, hash).unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_client_by_api_key(client: &mut Client, api_key: &str) -> Option<ApiClient> {
|
||||||
|
if !api_key.starts_with("eg_") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must fetch all enabled clients and check each hash (bcrypt salts are unique)
|
||||||
|
let rows = client
|
||||||
|
.query(
|
||||||
|
"SELECT id, client_id, api_key_hash, name, scopes, rate_limit, expires_at
|
||||||
|
FROM api_clients WHERE enabled = TRUE",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
for row in rows {
|
||||||
|
let hash: String = row.get("api_key_hash");
|
||||||
|
if verify_api_key(api_key, &hash) {
|
||||||
|
// Check expiry
|
||||||
|
let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
|
||||||
|
if let Some(exp) = expires_at {
|
||||||
|
if exp < Utc::now() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update last_used_at
|
||||||
|
let id: i32 = row.get("id");
|
||||||
|
let _ = client.execute(
|
||||||
|
"UPDATE api_clients SET last_used_at = NOW() WHERE id = $1",
|
||||||
|
&[&id],
|
||||||
|
);
|
||||||
|
|
||||||
|
let scopes: Vec<String> = row.get("scopes");
|
||||||
|
return Some(ApiClient {
|
||||||
|
id,
|
||||||
|
client_id: row.get("client_id"),
|
||||||
|
name: row.get("name"),
|
||||||
|
scopes,
|
||||||
|
rate_limit: row.get("rate_limit"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_client(
|
||||||
|
client: &mut Client,
|
||||||
|
name: &str,
|
||||||
|
scopes: &[String],
|
||||||
|
rate_limit: i32,
|
||||||
|
) -> Result<(String, String), postgres::Error> {
|
||||||
|
let client_id = format!("client_{}", hex::encode(rand::thread_rng().gen::<[u8; 8]>()));
|
||||||
|
let api_key = generate_api_key();
|
||||||
|
let key_hash = hash_api_key(&api_key);
|
||||||
|
|
||||||
|
client.execute(
|
||||||
|
"INSERT INTO api_clients (client_id, api_key_hash, name, scopes, rate_limit)
|
||||||
|
VALUES ($1, $2, $3, $4, $5)",
|
||||||
|
&[&client_id, &key_hash, &name, &scopes, &rate_limit],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok((client_id, api_key))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_client_by_id(client: &mut Client, client_id: &str) -> Option<Value> {
|
||||||
|
let row = client
|
||||||
|
.query_one(
|
||||||
|
"SELECT id, client_id, name, scopes, rate_limit, enabled, expires_at, created_at, last_used_at
|
||||||
|
FROM api_clients WHERE client_id = $1",
|
||||||
|
&[&client_id],
|
||||||
|
)
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
let scopes: Vec<String> = row.get("scopes");
|
||||||
|
let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
|
||||||
|
let created_at: Option<chrono::DateTime<Utc>> = row.get("created_at");
|
||||||
|
let last_used_at: Option<chrono::DateTime<Utc>> = row.get("last_used_at");
|
||||||
|
|
||||||
|
Some(json!({
|
||||||
|
"id": row.get::<_, i32>("id"),
|
||||||
|
"client_id": row.get::<_, String>("client_id"),
|
||||||
|
"name": row.get::<_, Option<String>>("name"),
|
||||||
|
"scopes": scopes,
|
||||||
|
"rate_limit": row.get::<_, i32>("rate_limit"),
|
||||||
|
"enabled": row.get::<_, bool>("enabled"),
|
||||||
|
"expires_at": expires_at.map(|t| t.to_rfc3339()),
|
||||||
|
"created_at": created_at.map(|t| t.to_rfc3339()),
|
||||||
|
"last_used_at": last_used_at.map(|t| t.to_rfc3339())
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_clients(client: &mut Client, include_disabled: bool) -> Vec<Value> {
|
||||||
|
let query = if include_disabled {
|
||||||
|
"SELECT id, client_id, name, scopes, rate_limit, enabled, expires_at, created_at, last_used_at
|
||||||
|
FROM api_clients ORDER BY created_at DESC"
|
||||||
|
} else {
|
||||||
|
"SELECT id, client_id, name, scopes, rate_limit, enabled, expires_at, created_at, last_used_at
|
||||||
|
FROM api_clients WHERE enabled = TRUE ORDER BY created_at DESC"
|
||||||
|
};
|
||||||
|
|
||||||
|
let rows = match client.query(query, &[]) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(_) => return vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
rows.iter()
|
||||||
|
.map(|row| {
|
||||||
|
let scopes: Vec<String> = row.get("scopes");
|
||||||
|
let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
|
||||||
|
let created_at: Option<chrono::DateTime<Utc>> = row.get("created_at");
|
||||||
|
let last_used_at: Option<chrono::DateTime<Utc>> = row.get("last_used_at");
|
||||||
|
|
||||||
|
json!({
|
||||||
|
"id": row.get::<_, i32>("id"),
|
||||||
|
"client_id": row.get::<_, String>("client_id"),
|
||||||
|
"name": row.get::<_, Option<String>>("name"),
|
||||||
|
"scopes": scopes,
|
||||||
|
"rate_limit": row.get::<_, i32>("rate_limit"),
|
||||||
|
"enabled": row.get::<_, bool>("enabled"),
|
||||||
|
"expires_at": expires_at.map(|t| t.to_rfc3339()),
|
||||||
|
"created_at": created_at.map(|t| t.to_rfc3339()),
|
||||||
|
"last_used_at": last_used_at.map(|t| t.to_rfc3339())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_client_db(
|
||||||
|
client: &mut Client,
|
||||||
|
client_id: &str,
|
||||||
|
name: Option<&str>,
|
||||||
|
scopes: Option<&[String]>,
|
||||||
|
rate_limit: Option<i32>,
|
||||||
|
enabled: Option<bool>,
|
||||||
|
) -> bool {
|
||||||
|
let mut updates = vec![];
|
||||||
|
let mut params: Vec<&(dyn postgres::types::ToSql + Sync)> = vec![];
|
||||||
|
let mut idx = 1;
|
||||||
|
|
||||||
|
// We need to store owned values
|
||||||
|
let name_owned: Option<String> = name.map(|s| s.to_string());
|
||||||
|
let scopes_owned: Option<Vec<String>> = scopes.map(|s| s.to_vec());
|
||||||
|
|
||||||
|
if let Some(ref n) = name_owned {
|
||||||
|
updates.push(format!("name = ${}", idx));
|
||||||
|
params.push(n);
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
if let Some(ref s) = scopes_owned {
|
||||||
|
updates.push(format!("scopes = ${}", idx));
|
||||||
|
params.push(s);
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
if let Some(ref r) = rate_limit {
|
||||||
|
updates.push(format!("rate_limit = ${}", idx));
|
||||||
|
params.push(r);
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
if let Some(ref e) = enabled {
|
||||||
|
updates.push(format!("enabled = ${}", idx));
|
||||||
|
params.push(e);
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if updates.is_empty() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
"UPDATE api_clients SET {} WHERE client_id = ${}",
|
||||||
|
updates.join(", "),
|
||||||
|
idx
|
||||||
|
);
|
||||||
|
params.push(&client_id);
|
||||||
|
|
||||||
|
match client.execute(&sql, ¶ms) {
|
||||||
|
Ok(n) => n > 0,
|
||||||
|
Err(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn disable_client_db(client: &mut Client, client_id: &str) -> bool {
|
||||||
|
match client.execute(
|
||||||
|
"UPDATE api_clients SET enabled = FALSE WHERE client_id = $1",
|
||||||
|
&[&client_id],
|
||||||
|
) {
|
||||||
|
Ok(n) => n > 0,
|
||||||
|
Err(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn regenerate_api_key_db(client: &mut Client, client_id: &str) -> Option<String> {
|
||||||
|
let new_api_key = generate_api_key();
|
||||||
|
let key_hash = hash_api_key(&new_api_key);
|
||||||
|
|
||||||
|
match client.execute(
|
||||||
|
"UPDATE api_clients SET api_key_hash = $1 WHERE client_id = $2",
|
||||||
|
&[&key_hash, &client_id],
|
||||||
|
) {
|
||||||
|
Ok(n) if n > 0 => Some(new_api_key),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// JWT Operations
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn create_jwt(client_id: &str, scopes: &[String]) -> String {
|
||||||
|
let now = Utc::now().timestamp();
|
||||||
|
let claims = Claims {
|
||||||
|
sub: client_id.to_string(),
|
||||||
|
iss: "egregore".to_string(),
|
||||||
|
iat: now,
|
||||||
|
exp: now + *JWT_EXPIRY,
|
||||||
|
scope: scopes.to_vec(),
|
||||||
|
};
|
||||||
|
|
||||||
|
encode(
|
||||||
|
&Header::default(),
|
||||||
|
&claims,
|
||||||
|
&EncodingKey::from_secret(JWT_SECRET.as_bytes()),
|
||||||
|
)
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_jwt(token: &str) -> Option<Claims> {
|
||||||
|
decode::<Claims>(
|
||||||
|
token,
|
||||||
|
&DecodingKey::from_secret(JWT_SECRET.as_bytes()),
|
||||||
|
&Validation::default(),
|
||||||
|
)
|
||||||
|
.ok()
|
||||||
|
.map(|data| data.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_bearer_token(request: &Request) -> Option<String> {
|
||||||
|
request
|
||||||
|
.header("Authorization")
|
||||||
|
.and_then(|h| h.strip_prefix("Bearer "))
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// HTTP Proxy
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn proxy_get(url: &str) -> Response {
|
||||||
|
match ureq::get(url).timeout(StdDuration::from_secs(120)).call() {
|
||||||
|
Ok(resp) => {
|
||||||
|
let body: Value = resp.into_json().unwrap_or(json!({}));
|
||||||
|
Response::json(&body)
|
||||||
|
}
|
||||||
|
Err(ureq::Error::Status(code, resp)) => {
|
||||||
|
let body = resp.into_string().unwrap_or_default();
|
||||||
|
Response::text(body).with_status_code(code)
|
||||||
|
}
|
||||||
|
Err(e) => error_response(503, &format!("Backend unavailable: {}", e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn proxy_post(url: &str, body: &Value) -> Response {
|
||||||
|
match ureq::post(url)
|
||||||
|
.timeout(StdDuration::from_secs(120))
|
||||||
|
.send_json(body)
|
||||||
|
{
|
||||||
|
Ok(resp) => {
|
||||||
|
let body: Value = resp.into_json().unwrap_or(json!({}));
|
||||||
|
Response::json(&body)
|
||||||
|
}
|
||||||
|
Err(ureq::Error::Status(code, resp)) => {
|
||||||
|
let body = resp.into_string().unwrap_or_default();
|
||||||
|
Response::text(body).with_status_code(code)
|
||||||
|
}
|
||||||
|
Err(e) => error_response(503, &format!("Backend unavailable: {}", e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// HTTP Handlers
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn json_response<T: Serialize>(data: T) -> Response {
|
||||||
|
Response::json(&data)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error_response(status: u16, message: &str) -> Response {
|
||||||
|
Response::json(&json!({ "error": message })).with_status_code(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn health(state: &AppState) -> Response {
|
||||||
|
let mut reason_ok = false;
|
||||||
|
let mut recall_ok = false;
|
||||||
|
|
||||||
|
if let Ok(resp) = ureq::get(&format!("{}/health", *REASON_URL))
|
||||||
|
.timeout(StdDuration::from_secs(2))
|
||||||
|
.call()
|
||||||
|
{
|
||||||
|
reason_ok = resp.status() == 200;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(resp) = ureq::get(&format!("{}/health", *RECALL_URL))
|
||||||
|
.timeout(StdDuration::from_secs(2))
|
||||||
|
.call()
|
||||||
|
{
|
||||||
|
recall_ok = resp.status() == 200;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check DB connection
|
||||||
|
let db_ok = state.db.lock().unwrap().execute("SELECT 1", &[]).is_ok();
|
||||||
|
|
||||||
|
json_response(json!({
|
||||||
|
"status": if reason_ok && recall_ok && db_ok { "ok" } else { "degraded" },
|
||||||
|
"service": "relay",
|
||||||
|
"backends": {
|
||||||
|
"reason": if reason_ok { "ok" } else { "unavailable" },
|
||||||
|
"recall": if recall_ok { "ok" } else { "unavailable" },
|
||||||
|
"database": if db_ok { "ok" } else { "unavailable" }
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_token(request: &Request, state: &AppState) -> Response {
|
||||||
|
let body: TokenRequest = match rouille::input::json_input(request) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => return error_response(400, &format!("Invalid JSON: {}", e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut db = state.db.lock().unwrap();
|
||||||
|
match get_client_by_api_key(&mut db, &body.api_key) {
|
||||||
|
Some(client) => {
|
||||||
|
let token = create_jwt(&client.client_id, &client.scopes);
|
||||||
|
json_response(TokenResponse {
|
||||||
|
token,
|
||||||
|
expires_in: *JWT_EXPIRY,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
None => error_response(
|
||||||
|
401,
|
||||||
|
"API key not found or revoked",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth middleware helper
|
||||||
|
fn require_auth(request: &Request, required_scope: &str) -> Result<Claims, Response> {
|
||||||
|
let token = match extract_bearer_token(request) {
|
||||||
|
Some(t) => t,
|
||||||
|
None => return Err(error_response(401, "Missing authorization header")),
|
||||||
|
};
|
||||||
|
|
||||||
|
let claims = match verify_jwt(&token) {
|
||||||
|
Some(c) => c,
|
||||||
|
None => return Err(error_response(401, "Invalid or expired token")),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check scope
|
||||||
|
if !claims.scope.contains(&"*".to_string()) && !claims.scope.contains(&required_scope.to_string())
|
||||||
|
{
|
||||||
|
return Err(error_response(
|
||||||
|
403,
|
||||||
|
&format!("Requires '{}' scope", required_scope),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chat(request: &Request, state: &AppState) -> Response {
|
||||||
|
// Auth check
|
||||||
|
let claims = match require_auth(request, "chat") {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(r) => return r,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Rate limit check
|
||||||
|
if !state.rate_limiter.check(&claims.sub) {
|
||||||
|
return error_response(429, "Rate limit exceeded: 10 requests per minute");
|
||||||
|
}
|
||||||
|
|
||||||
|
let body: ChatRequest = match rouille::input::json_input(request) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => return error_response(400, &format!("Invalid JSON: {}", e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get conversation history
|
||||||
|
let history_resp = match ureq::get(&format!("{}/messages/history", *RECALL_URL))
|
||||||
|
.timeout(StdDuration::from_secs(10))
|
||||||
|
.call()
|
||||||
|
{
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => return error_response(503, &format!("Recall service unavailable: {}", e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let history_data: Value = history_resp.into_json().unwrap_or(json!({}));
|
||||||
|
let mut history: Vec<Value> = history_data
|
||||||
|
.get("history")
|
||||||
|
.and_then(|h| h.as_array())
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
// Add user message
|
||||||
|
history.push(json!({
|
||||||
|
"role": "user",
|
||||||
|
"content": body.message
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Process with reason service
|
||||||
|
let payload = json!({
|
||||||
|
"model": body.model,
|
||||||
|
"history": history,
|
||||||
|
"max_iterations": body.max_iterations
|
||||||
|
});
|
||||||
|
|
||||||
|
proxy_post(&format!("{}/process", *REASON_URL), &payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_history(request: &Request) -> Response {
|
||||||
|
if let Err(r) = require_auth(request, "history") {
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
let limit = request
|
||||||
|
.get_param("limit")
|
||||||
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
|
.unwrap_or(50)
|
||||||
|
.min(100);
|
||||||
|
|
||||||
|
let before = request.get_param("before");
|
||||||
|
|
||||||
|
let mut url = format!("{}/messages?limit={}", *RECALL_URL, limit);
|
||||||
|
if let Some(b) = before {
|
||||||
|
url.push_str(&format!("&before={}", b));
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy_get(&url)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search_history(request: &Request) -> Response {
|
||||||
|
if let Err(r) = require_auth(request, "history") {
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
let q = match request.get_param("q") {
|
||||||
|
Some(q) => q,
|
||||||
|
None => return error_response(400, "Missing 'q' parameter"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let limit = request
|
||||||
|
.get_param("limit")
|
||||||
|
.and_then(|s| s.parse::<i32>().ok())
|
||||||
|
.unwrap_or(20);
|
||||||
|
|
||||||
|
let url = format!(
|
||||||
|
"{}/messages/search?q={}&limit={}",
|
||||||
|
*RECALL_URL,
|
||||||
|
urlencoding::encode(&q),
|
||||||
|
limit
|
||||||
|
);
|
||||||
|
|
||||||
|
proxy_get(&url)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_tools(request: &Request) -> Response {
|
||||||
|
if let Err(r) = require_auth(request, "tools") {
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy_get(&format!("{}/tools", *REASON_URL))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin endpoints
|
||||||
|
fn admin_create_client(request: &Request, state: &AppState) -> Response {
|
||||||
|
let body: CreateClientRequest = match rouille::input::json_input(request) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => return error_response(400, &format!("Invalid JSON: {}", e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut db = state.db.lock().unwrap();
|
||||||
|
match create_client(&mut db, &body.name, &body.scopes, body.rate_limit) {
|
||||||
|
Ok((client_id, api_key)) => Response::json(&json!({
|
||||||
|
"client_id": client_id,
|
||||||
|
"api_key": api_key,
|
||||||
|
"name": body.name,
|
||||||
|
"scopes": body.scopes
|
||||||
|
}))
|
||||||
|
.with_status_code(201),
|
||||||
|
Err(e) => error_response(500, &format!("Failed to create client: {}", e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn admin_list_clients(request: &Request, state: &AppState) -> Response {
|
||||||
|
let include_disabled = request
|
||||||
|
.get_param("include_disabled")
|
||||||
|
.map(|s| s == "true")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let mut db = state.db.lock().unwrap();
|
||||||
|
let clients = list_clients(&mut db, include_disabled);
|
||||||
|
json_response(json!({ "clients": clients }))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn admin_get_client(client_id: &str, state: &AppState) -> Response {
|
||||||
|
let mut db = state.db.lock().unwrap();
|
||||||
|
match get_client_by_id(&mut db, client_id) {
|
||||||
|
Some(client) => json_response(client),
|
||||||
|
None => error_response(404, "Client not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn admin_update_client(client_id: &str, request: &Request, state: &AppState) -> Response {
|
||||||
|
let body: UpdateClientRequest = match rouille::input::json_input(request) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => return error_response(400, &format!("Invalid JSON: {}", e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut db = state.db.lock().unwrap();
|
||||||
|
let success = update_client_db(
|
||||||
|
&mut db,
|
||||||
|
client_id,
|
||||||
|
body.name.as_deref(),
|
||||||
|
body.scopes.as_deref(),
|
||||||
|
body.rate_limit,
|
||||||
|
body.enabled,
|
||||||
|
);
|
||||||
|
|
||||||
|
if success {
|
||||||
|
json_response(json!({ "status": "updated" }))
|
||||||
|
} else {
|
||||||
|
error_response(404, "Client not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn admin_disable_client(client_id: &str, state: &AppState) -> Response {
|
||||||
|
let mut db = state.db.lock().unwrap();
|
||||||
|
if disable_client_db(&mut db, client_id) {
|
||||||
|
json_response(json!({ "status": "disabled" }))
|
||||||
|
} else {
|
||||||
|
error_response(404, "Client not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn admin_regenerate_key(client_id: &str, state: &AppState) -> Response {
|
||||||
|
let mut db = state.db.lock().unwrap();
|
||||||
|
match regenerate_api_key_db(&mut db, client_id) {
|
||||||
|
Some(new_key) => json_response(json!({ "api_key": new_key })),
|
||||||
|
None => error_response(404, "Client not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Router
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn handle_request(request: &Request, state: &AppState) -> Response {
|
||||||
|
// Handle CORS preflight
|
||||||
|
if request.method() == "OPTIONS" {
|
||||||
|
return Response::empty_204();
|
||||||
|
}
|
||||||
|
|
||||||
|
router!(request,
|
||||||
|
// Public
|
||||||
|
(GET) ["/health"] => { health(state) },
|
||||||
|
(POST) ["/auth/token"] => { get_token(request, state) },
|
||||||
|
|
||||||
|
// Protected v1 endpoints
|
||||||
|
(POST) ["/v1/chat"] => { chat(request, state) },
|
||||||
|
(GET) ["/v1/history"] => { get_history(request) },
|
||||||
|
(GET) ["/v1/history/search"] => { search_history(request) },
|
||||||
|
(GET) ["/v1/tools"] => { get_tools(request) },
|
||||||
|
|
||||||
|
// Admin endpoints
|
||||||
|
(POST) ["/admin/clients"] => { admin_create_client(request, state) },
|
||||||
|
(GET) ["/admin/clients"] => { admin_list_clients(request, state) },
|
||||||
|
(GET) ["/admin/clients/{id}", id: String] => { admin_get_client(&id, state) },
|
||||||
|
(PATCH) ["/admin/clients/{id}", id: String] => { admin_update_client(&id, request, state) },
|
||||||
|
(DELETE) ["/admin/clients/{id}", id: String] => { admin_disable_client(&id, state) },
|
||||||
|
(POST) ["/admin/clients/{id}/regenerate-key", id: String] => { admin_regenerate_key(&id, state) },
|
||||||
|
|
||||||
|
_ => Response::empty_404()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
println!("Egregore Relay Service starting on {}", BIND_ADDR);
|
||||||
|
|
||||||
|
// Connect to database
|
||||||
|
println!("[relay] Connecting to database...");
|
||||||
|
let mut db_client = match Client::connect(&DATABASE_URL, NoTls) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[relay] Failed to connect to database: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize database schema
|
||||||
|
if let Err(e) = init_db(&mut db_client) {
|
||||||
|
eprintln!("[relay] Failed to initialize database: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
println!("[relay] Database initialized");
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
db: Mutex::new(db_client),
|
||||||
|
rate_limiter: RateLimiter::new(10, 60), // 10 requests per 60 seconds
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start HTTP server
|
||||||
|
rouille::start_server(BIND_ADDR, move |request| {
|
||||||
|
let response = handle_request(request, &state);
|
||||||
|
|
||||||
|
// Add CORS headers
|
||||||
|
response
|
||||||
|
.with_additional_header("Access-Control-Allow-Origin", "*")
|
||||||
|
.with_additional_header(
|
||||||
|
"Access-Control-Allow-Methods",
|
||||||
|
"GET, POST, PATCH, DELETE, OPTIONS",
|
||||||
|
)
|
||||||
|
.with_additional_header(
|
||||||
|
"Access-Control-Allow-Headers",
|
||||||
|
"Content-Type, Authorization",
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// URL encoding helper
|
||||||
|
mod urlencoding {
|
||||||
|
pub fn encode(s: &str) -> String {
|
||||||
|
let mut result = String::new();
|
||||||
|
for c in s.chars() {
|
||||||
|
match c {
|
||||||
|
'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.' | '~' => result.push(c),
|
||||||
|
_ => {
|
||||||
|
for b in c.to_string().as_bytes() {
|
||||||
|
result.push_str(&format!("%{:02X}", b));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue