recall/messages.py
egregore 291d664051 Initial commit: Egregore db service
PostgreSQL message storage API with asyncpg connection pooling and full-text search.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 11:37:55 +00:00

371 lines
12 KiB
Python

"""
Egregore Database - Message storage operations (PostgreSQL)
"""
import json
import os
import uuid
from datetime import datetime, timezone
from typing import Optional

import asyncpg
# Database connection URL - can be overridden via environment
# (or programmatically via set_db_url() before the pool is created)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://egregore:egregore_db_pass@localhost/egregore")
# Shared asyncpg connection pool, created lazily by get_pool()
_pool: Optional[asyncpg.Pool] = None
def set_db_url(url: str):
    """Point the module at a different PostgreSQL DSN.

    Has no effect on an already-created pool, so this must be called
    before the first get_pool()/init_db().
    """
    global DATABASE_URL
    DATABASE_URL = url
async def get_pool() -> asyncpg.Pool:
    """Return the shared connection pool, creating it lazily on first use."""
    global _pool
    if _pool is not None:
        return _pool
    _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
async def close_pool():
    """Dispose of the shared pool so a later get_pool() builds a fresh one."""
    global _pool
    if _pool is None:
        return
    await _pool.close()
    _pool = None
# Notification priority per message type: 0 = silent, higher = more urgent.
MESSAGE_PRIORITIES = {
    "text": 2,         # Regular messages - notify
    "tool_use": 0,     # Tool invocation - no notify
    "tool_result": 0,  # Tool output - no notify
    "question": 3,     # Questions to user - urgent notify
    "mode_change": 1,  # State transitions - silent
    "thinking": 0,     # Reasoning process - no notify
    "error": 2,        # Error messages - notify
}


def get_priority_for_type(msg_type: str, content: str = "") -> int:
    """Map a message type to its notification priority.

    Text messages whose trimmed content ends in "?" are promoted to the
    "question" priority so they trigger an urgent notification.
    Unknown types fall back to 0 (silent).
    """
    looks_like_question = msg_type == "text" and content.strip().endswith("?")
    if looks_like_question:
        return MESSAGE_PRIORITIES["question"]
    return MESSAGE_PRIORITIES.get(msg_type, 0)
async def init_db():
    """Create the messages table, its btree indexes, and the FTS index.

    Idempotent: every statement uses IF NOT EXISTS, so this is safe to
    run on every service start.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS messages (
                id SERIAL PRIMARY KEY,
                role TEXT NOT NULL,
                type TEXT NOT NULL DEFAULT 'text',
                content TEXT NOT NULL,
                group_id TEXT,
                metadata JSONB,
                priority INTEGER DEFAULT 0,
                timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
            )
        """)
        # Plain btree indexes on the columns used for filtering/sorting.
        for column in ("role", "type", "group_id", "priority", "timestamp"):
            await conn.execute(
                f"CREATE INDEX IF NOT EXISTS idx_messages_{column} ON messages({column})"
            )
        # GIN index backing full-text search over message content.
        await conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_messages_content_search
            ON messages USING gin(to_tsvector('english', content))
        """)
async def save_message(
    role: str,
    content: str,
    msg_type: str = "text",
    group_id: Optional[str] = None,
    metadata: Optional[dict] = None,
    priority: Optional[int] = None
) -> int:
    """Insert one message row and return its generated id.

    When priority is not supplied it is derived from the message type
    (including question detection for text content).
    """
    if priority is None:
        priority = get_priority_for_type(msg_type, content)
    meta_json = json.dumps(metadata) if metadata else None
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            """INSERT INTO messages (role, type, content, group_id, metadata, priority)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING id""",
            role, msg_type, content, group_id, meta_json, priority
        )
    return row['id']
async def save_response_blocks(blocks: list, group_id: str) -> list:
    """
    Save each response block as a separate messages row.

    All rows written in one call share a single timestamp and the given
    group_id so the response can be reassembled later (see
    get_conversation_history). Returns a list of saved message dicts
    (including generated ids) for the frontend.
    """
    saved_messages = []
    pool = await get_pool()
    async with pool.acquire() as conn:
        # Use an aware UTC timestamp: the column is TIMESTAMPTZ, and a naive
        # datetime.utcnow() value would be reinterpreted in the server's
        # session timezone on insert (utcnow() is also deprecated since 3.12).
        timestamp = datetime.now(timezone.utc)
        for block in blocks:
            block_type = block.get("type", "text")
            metadata = None
            if block_type == "tool_use":
                # Store the tool input payload as row content; identify the
                # tool itself in metadata.
                content = json.dumps(block.get("input", {}))
                metadata = {"tool_name": block.get("name"), "tool_id": block.get("id")}
            elif block_type == "tool_result":
                content = block.get("content", "")
                metadata = {"tool_name": block.get("tool_name"), "tool_use_id": block.get("tool_use_id")}
            else:
                # "text", "thinking", and any unknown types store raw content.
                content = block.get("content", "")
            # Shared priority logic (includes question promotion for text).
            priority = get_priority_for_type(block_type, content)
            row = await conn.fetchrow(
                """INSERT INTO messages (role, type, content, group_id, metadata, priority, timestamp)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                RETURNING id""",
                "assistant", block_type, content, group_id,
                json.dumps(metadata) if metadata else None, priority, timestamp
            )
            saved_messages.append({
                "id": row['id'],
                "role": "assistant",
                "type": block_type,
                "content": content,
                "group_id": group_id,
                "metadata": metadata,
                "priority": priority,
                "timestamp": timestamp.isoformat()
            })
    return saved_messages
async def get_messages(
    limit: int = 50,
    before_id: Optional[int] = None,
    msg_type: Optional[str] = None
) -> tuple[list[dict], bool]:
    """Fetch messages with keyset pagination.

    Args:
        limit: Maximum number of messages to return.
        before_id: Only return rows with id < before_id (page backwards).
        msg_type: Optional filter on the message type column.

    Returns:
        (messages, has_more): messages oldest-first; has_more is True when
        older rows exist beyond this page.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        # Build the WHERE clause dynamically. Only the $N placeholder
        # indexes are interpolated into the SQL text — user values always
        # travel as bind parameters, so this stays injection-safe.
        params = []
        where_clauses = []
        param_idx = 1
        if before_id:
            where_clauses.append(f"id < ${param_idx}")
            params.append(before_id)
            param_idx += 1
        if msg_type:
            where_clauses.append(f"type = ${param_idx}")
            params.append(msg_type)
            param_idx += 1
        where_sql = " AND ".join(where_clauses) if where_clauses else "TRUE"
        # Fetch one extra row purely to detect whether another page exists.
        params.append(limit + 1)
        rows = await conn.fetch(
            f"""SELECT id, role, type, content, group_id, metadata, priority, timestamp
            FROM messages
            WHERE {where_sql}
            ORDER BY id DESC
            LIMIT ${param_idx}""",
            *params
        )
        has_more = len(rows) > limit
        rows = rows[:limit]
        messages = []
        for row in rows:
            metadata = None
            if row['metadata']:
                try:
                    metadata = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
                except (json.JSONDecodeError, TypeError):
                    # Corrupt metadata is non-fatal; surface it as None.
                    metadata = None
            messages.append({
                "id": row['id'],
                "role": row['role'],
                "type": row['type'],
                "content": row['content'],
                "group_id": row['group_id'],
                "metadata": metadata,
                "priority": row['priority'],
                "timestamp": row['timestamp'].isoformat() if row['timestamp'] else None
            })
        # Rows were fetched newest-first; present them oldest-first.
        return list(reversed(messages)), has_more
async def get_conversation_history(limit: int = 100) -> list[dict]:
    """
    Reconstruct Claude API message format from individual rows.

    Consecutive assistant rows sharing a group_id are merged into a single
    assistant message whose content is a list of blocks. tool_result rows
    are emitted as user messages, since the API requires tool results to
    arrive in the user role.
    """
    messages, _ = await get_messages(limit)
    api_messages = []
    current_group = None            # group_id of the assistant turn being built
    current_assistant_content = []  # accumulated content blocks for that turn
    for msg in messages:
        if msg["role"] == "user":
            # A user message closes any open assistant turn.
            if current_assistant_content:
                api_messages.append({
                    "role": "assistant",
                    "content": current_assistant_content
                })
                current_assistant_content = []
                current_group = None
            api_messages.append({
                "role": "user",
                "content": msg["content"]
            })
        elif msg["role"] == "assistant":
            if msg["group_id"] != current_group:
                # New response group: flush the previous turn first.
                if current_assistant_content:
                    api_messages.append({
                        "role": "assistant",
                        "content": current_assistant_content
                    })
                    current_assistant_content = []
                current_group = msg["group_id"]
            if msg["type"] == "text":
                current_assistant_content.append({
                    "type": "text",
                    "text": msg["content"]
                })
            elif msg["type"] == "tool_use":
                tool_input = {}
                try:
                    tool_input = json.loads(msg["content"])
                except (json.JSONDecodeError, TypeError):
                    # Unparseable stored input: fall back to an empty dict.
                    pass
                metadata = msg.get("metadata") or {}
                current_assistant_content.append({
                    "type": "tool_use",
                    # Synthesize an id when the original was not recorded.
                    "id": metadata.get("tool_id", str(uuid.uuid4())),
                    "name": metadata.get("tool_name", "unknown"),
                    "input": tool_input
                })
            elif msg["type"] == "tool_result":
                # Tool results must appear under the user role, so close the
                # assistant turn before emitting the result.
                if current_assistant_content:
                    api_messages.append({
                        "role": "assistant",
                        "content": current_assistant_content
                    })
                    current_assistant_content = []
                metadata = msg.get("metadata") or {}
                api_messages.append({
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": metadata.get("tool_use_id", ""),
                        "content": msg["content"]
                    }]
                })
    # Flush a trailing assistant turn.
    if current_assistant_content:
        api_messages.append({
            "role": "assistant",
            "content": current_assistant_content
        })
    return api_messages
async def search_messages(
    query: str,
    limit: int = 20,
    msg_type: Optional[str] = None
) -> list[dict]:
    """Search messages with PostgreSQL full-text search, falling back to ILIKE.

    Args:
        query: Search terms; queries shorter than 2 chars return nothing.
        limit: Maximum results, hard-capped at 50.
        msg_type: Optional filter on the message type column.

    Returns:
        Matching message dicts, newest first, each with a highlighted
        "snippet" field.
    """
    if len(query) < 2:
        return []
    pool = await get_pool()
    async with pool.acquire() as conn:
        params = [query]
        param_idx = 2
        type_filter = ""
        if msg_type:
            # Only the $N index is interpolated; the value is bound.
            type_filter = f"AND type = ${param_idx}"
            params.append(msg_type)
            param_idx += 1
        params.append(min(limit, 50))
        # FTS for stemmed word matches, OR'd with ILIKE so substrings of
        # non-English / unstemmed content still hit.
        # NOTE(review): '%' and '_' in the query act as ILIKE wildcards —
        # escape them if literal matching is ever required.
        rows = await conn.fetch(
            f"""SELECT id, role, type, content, group_id, metadata, priority, timestamp,
            ts_headline('english', content, plainto_tsquery('english', $1),
            'StartSel=**, StopSel=**, MaxWords=50, MinWords=20') as snippet
            FROM messages
            WHERE (to_tsvector('english', content) @@ plainto_tsquery('english', $1)
            OR content ILIKE '%' || $1 || '%')
            {type_filter}
            ORDER BY id DESC
            LIMIT ${param_idx}""",
            *params
        )
        results = []
        for row in rows:
            metadata = None
            if row['metadata']:
                try:
                    metadata = json.loads(row['metadata']) if isinstance(row['metadata'], str) else row['metadata']
                except (json.JSONDecodeError, TypeError):
                    # Corrupt metadata is non-fatal; surface it as None.
                    metadata = None
            # ts_headline yields NULL when only the ILIKE branch matched;
            # fall back to a plain content prefix.
            snippet = row['snippet'] if row['snippet'] else row['content'][:100]
            results.append({
                "id": row['id'],
                "role": row['role'],
                "type": row['type'],
                "content": row['content'],
                "group_id": row['group_id'],
                "metadata": metadata,
                "priority": row['priority'],
                "timestamp": row['timestamp'].isoformat() if row['timestamp'] else None,
                "snippet": snippet
            })
        return results
# Legacy compatibility - for migration
def set_db_path(path: str):
    """No-op retained so SQLite-era callers keep working against PostgreSQL."""
    return None