464 lines
15 KiB
Python
464 lines
15 KiB
Python
#!/usr/bin/env python3
|
||
"""Memory discovery and query utilities for the AI runtime.

- Loads `.ai-runtime/memory/episodic/index.yml`
- Provides a SQL-style (WHERE / ORDER BY / LIMIT) event query interface
- Provides table and JSON formatted output

Dependency: PyYAML (already used as a core dependency of this project)
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import datetime as dt
|
||
import json
|
||
import re
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Iterable, List, Optional
|
||
|
||
import yaml
|
||
|
||
|
||
@dataclass
class MemoryEvent:
    """Index record describing one episodic memory event."""

    # Required identity / classification fields.
    id: str
    type: str
    level: str
    timestamp: dt.datetime
    date_bucket: str
    path: Path
    # Optional descriptive fields; mutable defaults go through factories.
    title: str = ""
    tags: List[str] = field(default_factory=list)
    related: List[str] = field(default_factory=list)
    meta: Dict[str, Any] = field(default_factory=dict)

    @property
    def date(self) -> str:
        """Return the calendar day of ``timestamp`` as a ``YYYY-MM-DD`` string."""
        day = self.timestamp.date()
        return day.isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the event.

        The timestamp is rendered in ISO format and the path as a string;
        ``tags``, ``related`` and ``meta`` are shallow copies, so mutating
        the result does not touch the event itself.
        """
        return dict(
            id=self.id,
            type=self.type,
            level=self.level,
            timestamp=self.timestamp.isoformat(),
            date=self.date,
            date_bucket=self.date_bucket,
            path=str(self.path),
            title=self.title,
            tags=[*self.tags],
            related=[*self.related],
            meta={**self.meta},
        )
|
||
|
||
|
||
class MemoryDiscovery:
|
||
"""Episodic 记忆索引加载与查询"""
|
||
|
||
def __init__(self, memory_root: Path) -> None:
    """Bind the instance to a memory root and load its episodic events.

    Args:
        memory_root: Root directory of the memory store. A ``Path`` is
            expected; a plain string or other ``os.PathLike`` is accepted
            too and converted.
    """
    # Coerce up front so the "/" joins below also work for str input.
    root = Path(memory_root)
    self.memory_root = root
    self.episodic_root = root / "episodic"
    self.index_path = self.episodic_root / "index.yml"
    self.events: List[MemoryEvent] = []
    # Populate ``events`` immediately so the instance is usable right away.
    self.refresh()
|
||
|
||
# ------------------------------------------------------------------
|
||
# 加载索引
|
||
# ------------------------------------------------------------------
|
||
def refresh(self) -> None:
    """Rescan the episodic store and replace the cached event list."""
    reloaded = self._load_events()
    self.events = reloaded
|
||
|
||
def _load_events(self) -> List[MemoryEvent]:
|
||
"""从 episodic 目录扫描 Markdown 事件文件并解析元信息。"""
|
||
|
||
if not self.episodic_root.exists():
|
||
return []
|
||
|
||
events: List[MemoryEvent] = []
|
||
for md_path in self.episodic_root.rglob("*.md"):
|
||
event = self._parse_event_file(md_path)
|
||
if event is not None:
|
||
events.append(event)
|
||
|
||
return events
|
||
|
||
def _parse_event_file(self, path: Path) -> Optional[MemoryEvent]:
|
||
"""解析单个事件 Markdown 文件。
|
||
|
||
协议:
|
||
- 可选顶部 YAML front matter: `--- ... ---`
|
||
- 正文中可使用:
|
||
- `# 标题` 作为事件标题
|
||
- `## 时间` 下第一行非空文本作为时间
|
||
- `## 标签` 下第一行非空文本作为标签列表
|
||
"""
|
||
|
||
try:
|
||
text = path.read_text(encoding="utf-8")
|
||
except Exception:
|
||
return None
|
||
|
||
lines = text.splitlines()
|
||
front_matter, body_lines = self._parse_front_matter(lines)
|
||
|
||
stem = path.stem
|
||
|
||
# 基础字段
|
||
id_value = str(front_matter.get("id") or stem)
|
||
type_value = str(front_matter.get("type") or "event")
|
||
|
||
# level: 优先 front matter,其次目录结构推断
|
||
level_value = str(front_matter.get("level") or self._infer_level_from_path(path))
|
||
|
||
# 标题:优先 front matter.title,其次正文第一个 '# ' 标题
|
||
title = front_matter.get("title") or self._extract_title_from_body(body_lines) or stem
|
||
|
||
# 标签:支持 front matter.tags 或 '## 标签' 段
|
||
tags = front_matter.get("tags")
|
||
if isinstance(tags, str):
|
||
tags = [t.strip() for t in re.split(r"[,\s]+", tags) if t.strip()]
|
||
elif isinstance(tags, list):
|
||
tags = [str(t) for t in tags]
|
||
else:
|
||
tags = self._extract_tags_from_body(body_lines)
|
||
|
||
# 时间:front matter.timestamp/time → 正文 '## 时间' → 文件名/mtime 兜底
|
||
ts_str = front_matter.get("timestamp") or front_matter.get("time")
|
||
timestamp: Optional[dt.datetime] = None
|
||
if isinstance(ts_str, str):
|
||
timestamp = self._parse_datetime(ts_str)
|
||
if timestamp is None:
|
||
body_time = self._extract_time_from_body(body_lines)
|
||
if body_time:
|
||
timestamp = self._parse_datetime(body_time)
|
||
if timestamp is None:
|
||
timestamp = self._infer_datetime_from_filename_or_mtime(path)
|
||
|
||
if timestamp is None:
|
||
# 无法推断时间的事件对查询意义有限,忽略该文件
|
||
return None
|
||
|
||
# date_bucket: 优先 front matter.date_bucket,其次目录结构 / 时间推断
|
||
date_bucket = front_matter.get("date_bucket") or self._infer_date_bucket(path, timestamp)
|
||
|
||
related = front_matter.get("related") or []
|
||
if isinstance(related, str):
|
||
related = [related]
|
||
elif not isinstance(related, list):
|
||
related = []
|
||
|
||
# meta: 保留所有未被提升为显式字段的 front matter 信息
|
||
meta: Dict[str, Any] = dict(front_matter)
|
||
for k in [
|
||
"id",
|
||
"type",
|
||
"level",
|
||
"title",
|
||
"tags",
|
||
"timestamp",
|
||
"time",
|
||
"date_bucket",
|
||
"related",
|
||
]:
|
||
meta.pop(k, None)
|
||
|
||
return MemoryEvent(
|
||
id=id_value,
|
||
type=type_value,
|
||
level=level_value,
|
||
timestamp=timestamp,
|
||
date_bucket=str(date_bucket),
|
||
path=path,
|
||
title=str(title),
|
||
tags=list(tags or []),
|
||
related=list(related),
|
||
meta=meta,
|
||
)
|
||
|
||
def _parse_front_matter(self, lines: List[str]):
|
||
"""解析 YAML front matter,如果不存在则返回空字典和原始行。"""
|
||
|
||
if not lines:
|
||
return {}, []
|
||
|
||
if lines[0].strip() != "---":
|
||
return {}, lines
|
||
|
||
for i in range(1, len(lines)):
|
||
if lines[i].strip() == "---":
|
||
fm_text = "\n".join(lines[1:i])
|
||
try:
|
||
data = yaml.safe_load(fm_text) or {}
|
||
except Exception:
|
||
data = {}
|
||
return data, lines[i + 1 :]
|
||
|
||
# 未找到结束分隔符,视为无 front matter
|
||
return {}, lines
|
||
|
||
@staticmethod
|
||
def _extract_title_from_body(body_lines: List[str]) -> Optional[str]:
|
||
for line in body_lines:
|
||
s = line.strip()
|
||
if s.startswith("# "):
|
||
return s[2:].strip()
|
||
return None
|
||
|
||
@staticmethod
|
||
def _extract_time_from_body(body_lines: List[str]) -> Optional[str]:
|
||
for i, line in enumerate(body_lines):
|
||
if line.strip().startswith("## 时间"):
|
||
for j in range(i + 1, len(body_lines)):
|
||
value = body_lines[j].strip()
|
||
if value:
|
||
return value
|
||
break
|
||
return None
|
||
|
||
@staticmethod
|
||
def _extract_tags_from_body(body_lines: List[str]) -> List[str]:
|
||
for i, line in enumerate(body_lines):
|
||
if line.strip().startswith("## 标签"):
|
||
for j in range(i + 1, len(body_lines)):
|
||
raw = body_lines[j].strip()
|
||
if not raw:
|
||
continue
|
||
parts = [p.strip() for p in re.split(r"[,\s]+", raw) if p.strip()]
|
||
return parts
|
||
break
|
||
return []
|
||
|
||
def _infer_level_from_path(self, path: Path) -> str:
|
||
"""根据相对路径推断级别: year/month/day/event。"""
|
||
|
||
try:
|
||
rel = path.relative_to(self.episodic_root)
|
||
except ValueError:
|
||
return "event"
|
||
|
||
parts = rel.parts
|
||
if len(parts) >= 3 and parts[0].isdigit() and parts[1].isdigit() and parts[2].isdigit():
|
||
return "day"
|
||
if len(parts) >= 2 and parts[0].isdigit() and parts[1].isdigit():
|
||
return "month"
|
||
if len(parts) >= 1 and parts[0].isdigit():
|
||
return "year"
|
||
return "event"
|
||
|
||
def _infer_date_bucket(self, path: Path, ts: dt.datetime) -> str:
|
||
"""推断 date_bucket,例如 "2025/11/14"。"""
|
||
|
||
try:
|
||
rel = path.relative_to(self.episodic_root)
|
||
parts = rel.parts
|
||
if len(parts) >= 3 and parts[0].isdigit() and parts[1].isdigit() and parts[2].isdigit():
|
||
return f"{parts[0]}/{parts[1]}/{parts[2]}"
|
||
except ValueError:
|
||
pass
|
||
|
||
return ts.date().isoformat()
|
||
|
||
def _infer_datetime_from_filename_or_mtime(self, path: Path) -> Optional[dt.datetime]:
|
||
"""从文件名 (YYYYMMDD-HHMM) 或 mtime 推断时间。"""
|
||
|
||
m = re.match(r"(\d{8})-(\d{4})", path.stem)
|
||
if m:
|
||
date_str, hm = m.groups()
|
||
try:
|
||
return dt.datetime.strptime(date_str + hm, "%Y%m%d%H%M")
|
||
except Exception:
|
||
pass
|
||
|
||
m2 = re.match(r"(\d{4})(\d{2})(\d{2})", path.stem)
|
||
if m2:
|
||
y, mth, d = m2.groups()
|
||
try:
|
||
return dt.datetime.strptime(f"{y}{mth}{d}", "%Y%m%d")
|
||
except Exception:
|
||
pass
|
||
|
||
try:
|
||
return dt.datetime.fromtimestamp(path.stat().st_mtime)
|
||
except Exception:
|
||
return None
|
||
|
||
# ------------------------------------------------------------------
|
||
# SQL 风格查询接口
|
||
# ------------------------------------------------------------------
|
||
def query(
    self,
    where: Optional[str] = None,
    order_by: Optional[str] = None,
    limit: Optional[int] = None,
    offset: int = 0,
) -> List[MemoryEvent]:
    """Return events selected with SQL-style parameters.

    The steps run in a fixed order: WHERE filter, ORDER BY, OFFSET, then
    LIMIT. Each step is skipped when its argument is empty (or, for
    ``limit``, ``None``). The cached event list itself is never modified.
    """
    selected: List[MemoryEvent] = list(self.events)

    if where:
        selected = [*self._apply_where(selected, where)]
    if order_by:
        selected = self._apply_order_by(selected, order_by)
    if offset:
        selected = selected[offset:]

    return selected if limit is None else selected[:limit]
|
||
|
||
def _apply_where(
|
||
self, events: Iterable[MemoryEvent], where: str
|
||
) -> Iterable[MemoryEvent]:
|
||
"""简易 WHERE 解析,仅支持 AND,运算符子集。
|
||
|
||
支持的形式:
|
||
- field = 'value' / != / >= / <=
|
||
- tags CONTAINS 'tag'
|
||
- 通过 AND 连接多个条件(不支持 OR / 括号)
|
||
"""
|
||
|
||
conditions = [part.strip() for part in re.split(r"\s+AND\s+", where, flags=re.I) if part.strip()]
|
||
|
||
def match(event: MemoryEvent) -> bool:
|
||
for cond in conditions:
|
||
if not self._eval_condition(event, cond):
|
||
return False
|
||
return True
|
||
|
||
return (e for e in events if match(e))
|
||
|
||
def _eval_condition(self, event: MemoryEvent, cond: str) -> bool:
|
||
# tags CONTAINS 'tag'
|
||
if re.search(r"\bCONTAINS\b", cond, flags=re.I):
|
||
left, right = re.split(r"\bCONTAINS\b", cond, maxsplit=1, flags=re.I)
|
||
field = left.strip()
|
||
value = self._strip_quotes(right.strip())
|
||
if field.lower() != "tags":
|
||
return False
|
||
return value in (event.tags or [])
|
||
|
||
# field op value
|
||
m = re.match(r"^(\w+)\s*(=|!=|>=|<=)\s*(.+)$", cond)
|
||
if not m:
|
||
return False
|
||
|
||
field, op, raw_value = m.groups()
|
||
field = field.strip().lower()
|
||
value = self._strip_quotes(raw_value.strip())
|
||
|
||
# 取事件属性或 meta 字段
|
||
lhs: Any
|
||
if field == "id":
|
||
lhs = event.id
|
||
elif field == "type":
|
||
lhs = event.type
|
||
elif field == "level":
|
||
lhs = event.level
|
||
elif field == "title":
|
||
lhs = event.title
|
||
elif field == "date":
|
||
lhs = event.date
|
||
elif field == "timestamp":
|
||
lhs = event.timestamp
|
||
else:
|
||
if field in event.meta:
|
||
lhs = event.meta[field]
|
||
else:
|
||
# 未知字段直接返回 False,避免误匹配
|
||
return False
|
||
|
||
# 时间 / 日期字段支持 >= <=
|
||
if isinstance(lhs, dt.datetime):
|
||
rhs = self._parse_datetime(value)
|
||
if rhs is None:
|
||
return False
|
||
else:
|
||
rhs = value
|
||
|
||
try:
|
||
if op == "=":
|
||
return lhs == rhs
|
||
if op == "!=":
|
||
return lhs != rhs
|
||
if op == ">=":
|
||
return lhs >= rhs
|
||
if op == "<=":
|
||
return lhs <= rhs
|
||
except TypeError:
|
||
return False
|
||
|
||
return False
|
||
|
||
@staticmethod
|
||
def _strip_quotes(text: str) -> str:
|
||
if (text.startswith("'") and text.endswith("'")) or (
|
||
text.startswith('"') and text.endswith('"')
|
||
):
|
||
return text[1:-1]
|
||
return text
|
||
|
||
@staticmethod
|
||
def _parse_datetime(value: str) -> Optional[dt.datetime]:
|
||
# 支持 "YYYY-MM-DD" 或 ISO8601 字符串
|
||
try:
|
||
if len(value) == 10:
|
||
return dt.datetime.fromisoformat(value + "T00:00:00")
|
||
return dt.datetime.fromisoformat(value)
|
||
except Exception:
|
||
return None
|
||
|
||
# ------------------------------------------------------------------
|
||
# 格式化输出
|
||
# ------------------------------------------------------------------
|
||
def format_events(
    self,
    events: List[MemoryEvent],
    select: Optional[List[str]] = None,
    format_type: str = "table",
) -> str:
    """Render events as a JSON document or a plain-text table.

    Args:
        events: Events to render.
        select: Names of the ``to_dict()`` fields to include, in output
            order. Defaults to ``["id", "timestamp", "title"]`` when
            ``None`` or empty. Unknown names yield ``None`` values.
        format_type: ``"json"`` for JSON output; any other value produces
            the table format.

    Returns:
        The formatted text.
    """
    columns = select if select else ["id", "timestamp", "title"]

    rows = []
    for ev in events:
        record = ev.to_dict()
        rows.append({name: record.get(name) for name in columns})

    if format_type == "json":
        # meta may hold YAML-native values such as dates, which json cannot
        # encode; render those via str() instead of raising TypeError.
        return json.dumps(rows, ensure_ascii=False, indent=2, default=str)

    # table format
    return self._format_table(rows, columns)
|
||
|
||
@staticmethod
|
||
def _format_table(rows: List[Dict[str, Any]], headers: List[str]) -> str:
|
||
if not rows:
|
||
return "(no events)"
|
||
|
||
# 计算列宽
|
||
widths: Dict[str, int] = {}
|
||
for h in headers:
|
||
widths[h] = max(len(h), *(len(str(row.get(h, ""))) for row in rows))
|
||
|
||
def fmt_row(row: Dict[str, Any]) -> str:
|
||
return " ".join(str(row.get(h, "")).ljust(widths[h]) for h in headers)
|
||
|
||
header_line = " ".join(h.ljust(widths[h]) for h in headers)
|
||
sep_line = " ".join("-" * widths[h] for h in headers)
|
||
data_lines = [fmt_row(r) for r in rows]
|
||
|
||
return "\n".join([header_line, sep_line, *data_lines])
|