# PostgreSQL message storage API with asyncpg connection pooling and full-text search.
"""
|
|
Egregore Database - Message storage operations (PostgreSQL)
|
|
"""
|
|
|
|
import json
import os
import uuid
from datetime import datetime, timezone
from typing import Optional

import asyncpg
|
# Database connection URL - can be overridden via environment
# (DATABASE_URL env var) or programmatically via set_db_url().
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://egregore:egregore_db_pass@localhost/egregore")

# Connection pool - created lazily by get_pool(), torn down by close_pool().
_pool: Optional[asyncpg.Pool] = None
|
|
|
|
def set_db_url(url: str):
    """Override the module-level DATABASE_URL.

    Must be called before init_db() / get_pool(); once the pool exists
    the URL change has no effect.
    """
    global DATABASE_URL
    DATABASE_URL = url
|
|
|
async def get_pool() -> asyncpg.Pool:
    """Return the shared connection pool, creating it lazily on first call."""
    global _pool
    if _pool is not None:
        return _pool
    _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
|
|
|
async def close_pool():
    """Close and discard the shared pool. No-op when no pool exists."""
    global _pool
    if _pool is None:
        return
    await _pool.close()
    _pool = None
|
|
|
# Message type priorities for notifications
MESSAGE_PRIORITIES = {
    "text": 2,         # Regular messages - notify
    "tool_use": 0,     # Tool invocation - no notify
    "tool_result": 0,  # Tool output - no notify
    "question": 3,     # Questions to user - urgent notify
    "mode_change": 1,  # State transitions - silent
    "thinking": 0,     # Reasoning process - no notify
    "error": 2,        # Error messages - notify
}


def get_priority_for_type(msg_type: str, content: str = "") -> int:
    """Map a message type to its notification priority.

    Plain text ending in a question mark is promoted to the "question"
    priority so the user receives an urgent notification. Unknown types
    default to priority 0 (silent).
    """
    if msg_type == "text" and content.strip().endswith("?"):
        return MESSAGE_PRIORITIES["question"]
    return MESSAGE_PRIORITIES.get(msg_type, 0)
|
|
|
|
|
async def init_db():
    """Create the messages table, its btree indexes, and the FTS index.

    Every statement is idempotent (IF NOT EXISTS), so this is safe to run
    on every startup.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS messages (
                id SERIAL PRIMARY KEY,
                role TEXT NOT NULL,
                type TEXT NOT NULL DEFAULT 'text',
                content TEXT NOT NULL,
                group_id TEXT,
                metadata JSONB,
                priority INTEGER DEFAULT 0,
                timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
            )
        """)

        # Plain btree indexes for the common filter columns
        for col in ("role", "type", "group_id", "priority", "timestamp"):
            await conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_messages_{col} ON messages({col})"
            )

        # GIN index backing full-text search over message content
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_content_search
            ON messages USING gin(to_tsvector('english', content))
        """)
|
|
|
|
async def save_message(
    role: str,
    content: str,
    msg_type: str = "text",
    group_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    priority: Optional[int] = None
) -> int:
    """Insert a single message row and return its id.

    When ``priority`` is None it is derived from the message type (with
    question detection for text). ``metadata`` is serialized to JSON;
    an empty/None metadata stores SQL NULL.
    """
    # Derive values before touching the pool - no DB work needed for this.
    if priority is None:
        priority = get_priority_for_type(msg_type, content)
    meta_json = json.dumps(metadata) if metadata else None

    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            """INSERT INTO messages (role, type, content, group_id, metadata, priority)
               VALUES ($1, $2, $3, $4, $5, $6)
               RETURNING id""",
            role, msg_type, content, group_id, meta_json, priority
        )
    return row['id']
|
|
|
|
async def save_response_blocks(blocks: list, group_id: str) -> list:
    """
    Save each response block as a separate row.

    All rows of one response share ``group_id`` and a single UTC timestamp,
    and are inserted inside one transaction so a partially-saved response
    is never persisted.

    Args:
        blocks: Claude-style content blocks (text / tool_use / tool_result / ...).
        group_id: Identifier tying the blocks of one response together.

    Returns:
        List of saved message dicts with IDs for frontend.
    """
    saved_messages = []
    pool = await get_pool()
    async with pool.acquire() as conn:
        # Timezone-aware UTC: the column is TIMESTAMPTZ and datetime.utcnow()
        # is naive (and deprecated since Python 3.12).
        timestamp = datetime.now(timezone.utc)

        async with conn.transaction():
            for block in blocks:
                block_type = block.get("type", "text")
                metadata = None

                if block_type == "tool_use":
                    # The tool input dict is stored as JSON in the content column.
                    content = json.dumps(block.get("input", {}))
                    metadata = {"tool_name": block.get("name"), "tool_id": block.get("id")}
                elif block_type == "tool_result":
                    content = block.get("content", "")
                    metadata = {"tool_name": block.get("tool_name"), "tool_use_id": block.get("tool_use_id")}
                else:
                    # text / thinking / anything else: raw content
                    content = block.get("content", "")

                # Shared helper keeps question-detection consistent with save_message().
                priority = get_priority_for_type(block_type, content)

                row = await conn.fetchrow(
                    """INSERT INTO messages (role, type, content, group_id, metadata, priority, timestamp)
                       VALUES ($1, $2, $3, $4, $5, $6, $7)
                       RETURNING id""",
                    "assistant", block_type, content, group_id,
                    json.dumps(metadata) if metadata else None, priority, timestamp
                )

                saved_messages.append({
                    "id": row['id'],
                    "role": "assistant",
                    "type": block_type,
                    "content": content,
                    "group_id": group_id,
                    "metadata": metadata,
                    "priority": priority,
                    "timestamp": timestamp.isoformat()
                })

    return saved_messages
|
|
|
|
async def get_messages(
    limit: int = 50,
    before_id: Optional[int] = None,
    msg_type: Optional[str] = None
) -> tuple[list[dict], bool]:
    """Fetch messages with keyset pagination.

    Args:
        limit: Maximum number of messages to return.
        before_id: When set, only return rows with id strictly below this
            value (pagination cursor from a previous page).
        msg_type: Optional exact-match filter on the ``type`` column.

    Returns:
        Tuple of (messages oldest-first, has_more). ``has_more`` is True
        when older messages exist beyond this page.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        params = []
        where_clauses = []
        param_idx = 1

        if before_id:
            where_clauses.append(f"id < ${param_idx}")
            params.append(before_id)
            param_idx += 1
        if msg_type:
            where_clauses.append(f"type = ${param_idx}")
            params.append(msg_type)
            param_idx += 1

        where_sql = " AND ".join(where_clauses) if where_clauses else "TRUE"
        # Fetch one extra row purely to detect whether another page exists.
        params.append(limit + 1)

        rows = await conn.fetch(
            f"""SELECT id, role, type, content, group_id, metadata, priority, timestamp
                FROM messages
                WHERE {where_sql}
                ORDER BY id DESC
                LIMIT ${param_idx}""",
            *params
        )

        has_more = len(rows) > limit
        rows = rows[:limit]

        messages = []
        for row in rows:
            metadata = None
            if row['metadata']:
                try:
                    # JSONB may arrive as str or as an already-decoded object
                    # depending on the connection's codec configuration.
                    metadata = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
                except (json.JSONDecodeError, TypeError):
                    # Tolerate a corrupt metadata cell instead of failing the page.
                    metadata = None
            messages.append({
                "id": row['id'],
                "role": row['role'],
                "type": row['type'],
                "content": row['content'],
                "group_id": row['group_id'],
                "metadata": metadata,
                "priority": row['priority'],
                "timestamp": row['timestamp'].isoformat() if row['timestamp'] else None
            })

        # The query runs newest-first for LIMIT; callers get oldest-first.
        return list(reversed(messages)), has_more
|
|
|
|
async def get_conversation_history(limit: int = 100) -> list[dict]:
    """
    Reconstruct Claude API message format from individual rows.
    Groups assistant messages by group_id to build proper content arrays.

    Returns a list of {"role": ..., "content": ...} dicts: user rows become
    plain-string user messages, consecutive assistant rows sharing a
    group_id are merged into one assistant message with a content-block
    list, and tool_result rows are emitted as user messages carrying a
    tool_result block (as the Claude API requires).
    """
    # Rows come back oldest-first from get_messages().
    messages, _ = await get_messages(limit)

    api_messages = []
    current_group = None            # group_id of the assistant turn being accumulated
    current_assistant_content = []  # content blocks collected for that turn

    for msg in messages:
        if msg["role"] == "user":
            # A user message terminates any in-progress assistant turn.
            if current_assistant_content:
                api_messages.append({
                    "role": "assistant",
                    "content": current_assistant_content
                })
                current_assistant_content = []
                current_group = None

            api_messages.append({
                "role": "user",
                "content": msg["content"]
            })
        elif msg["role"] == "assistant":
            # A new group_id means a new assistant turn: flush the previous one.
            if msg["group_id"] != current_group:
                if current_assistant_content:
                    api_messages.append({
                        "role": "assistant",
                        "content": current_assistant_content
                    })
                    current_assistant_content = []
                current_group = msg["group_id"]

            if msg["type"] == "text":
                current_assistant_content.append({
                    "type": "text",
                    "text": msg["content"]
                })
            elif msg["type"] == "tool_use":
                # save_response_blocks() stored the tool input as JSON text.
                tool_input = {}
                try:
                    tool_input = json.loads(msg["content"])
                except:
                    pass
                metadata = msg.get("metadata") or {}
                # Missing tool_id gets a fresh UUID so the block stays well-formed.
                current_assistant_content.append({
                    "type": "tool_use",
                    "id": metadata.get("tool_id", str(uuid.uuid4())),
                    "name": metadata.get("tool_name", "unknown"),
                    "input": tool_input
                })
            elif msg["type"] == "tool_result":
                # Tool results belong in a *user* message per the Claude API,
                # so the assistant turn in progress must be flushed first.
                if current_assistant_content:
                    api_messages.append({
                        "role": "assistant",
                        "content": current_assistant_content
                    })
                    current_assistant_content = []
                metadata = msg.get("metadata") or {}
                api_messages.append({
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": metadata.get("tool_use_id", ""),
                        "content": msg["content"]
                    }]
                })

    # Flush an assistant turn still open at the end of history.
    if current_assistant_content:
        api_messages.append({
            "role": "assistant",
            "content": current_assistant_content
        })

    return api_messages
|
|
|
|
async def search_messages(
    query: str,
    limit: int = 20,
    msg_type: Optional[str] = None
) -> list[dict]:
    """Search messages using PostgreSQL full-text search.

    Combines tsvector matching with an ILIKE substring fallback so short
    or stemmer-hostile tokens still match. Results are newest-first and
    hard-capped at 50 rows regardless of ``limit``.

    Args:
        query: Search string; queries shorter than 2 characters return [].
        limit: Maximum number of results (capped at 50).
        msg_type: Optional exact-match filter on the ``type`` column.

    Returns:
        List of message dicts, each with a ts_headline ``snippet`` where
        matches are wrapped in ``**``.
    """
    if len(query) < 2:
        return []

    pool = await get_pool()
    async with pool.acquire() as conn:
        params = [query]
        param_idx = 2
        type_filter = ""

        if msg_type:
            type_filter = f"AND type = ${param_idx}"
            params.append(msg_type)
            param_idx += 1

        params.append(min(limit, 50))

        # Use PostgreSQL full-text search with fallback to ILIKE
        rows = await conn.fetch(
            f"""SELECT id, role, type, content, group_id, metadata, priority, timestamp,
                       ts_headline('english', content, plainto_tsquery('english', $1),
                                   'StartSel=**, StopSel=**, MaxWords=50, MinWords=20') as snippet
                FROM messages
                WHERE (to_tsvector('english', content) @@ plainto_tsquery('english', $1)
                       OR content ILIKE '%' || $1 || '%')
                {type_filter}
                ORDER BY id DESC
                LIMIT ${param_idx}""",
            *params
        )

        results = []
        for row in rows:
            metadata = None
            if row['metadata']:
                try:
                    # JSONB may arrive as str or decoded depending on codecs.
                    metadata = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
                except (json.JSONDecodeError, TypeError):
                    # A corrupt metadata cell shouldn't break search results.
                    metadata = None

            # Use ts_headline snippet, fallback to a raw content prefix
            snippet = row['snippet'] if row['snippet'] else row['content'][:100]

            results.append({
                "id": row['id'],
                "role": row['role'],
                "type": row['type'],
                "content": row['content'],
                "group_id": row['group_id'],
                "metadata": metadata,
                "priority": row['priority'],
                "timestamp": row['timestamp'].isoformat() if row['timestamp'] else None,
                "snippet": snippet
            })

        return results
|
|
|
|
# Legacy compatibility - for migration
def set_db_path(path: str):
    """Legacy no-op kept for SQLite-era callers; PostgreSQL ignores the path."""
    return None