Initial commit

This commit is contained in:
Zhongwei Li
2025-11-29 18:24:34 +08:00
commit 3a6a27d00c
16 changed files with 2959 additions and 0 deletions

View File

@@ -0,0 +1,463 @@
#!/usr/bin/env python3
"""Memory Discovery and Query Utilities for AI Runtime
- 加载 `.ai-runtime/memory/episodic/index.yml`
- 提供 SQL 风格 (WHERE / ORDER BY / LIMIT) 的事件查询接口
- 提供 table/json 两种格式化输出
依赖PyYAML项目中已作为核心依赖使用
"""
from __future__ import annotations
import datetime as dt
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
import yaml
@dataclass
class MemoryEvent:
"""单条情景记忆事件的索引信息"""
id: str
type: str
level: str
timestamp: dt.datetime
date_bucket: str
path: Path
title: str = ""
tags: List[str] = field(default_factory=list)
related: List[str] = field(default_factory=list)
meta: Dict[str, Any] = field(default_factory=dict)
@property
def date(self) -> str:
"""YYYY-MM-DD 字符串,便于 WHERE 子句使用 date 字段。"""
return self.timestamp.date().isoformat()
def to_dict(self) -> Dict[str, Any]:
"""转换为可 JSON 序列化的字典。"""
return {
"id": self.id,
"type": self.type,
"level": self.level,
"timestamp": self.timestamp.isoformat(),
"date": self.date,
"date_bucket": self.date_bucket,
"path": str(self.path),
"title": self.title,
"tags": list(self.tags),
"related": list(self.related),
"meta": dict(self.meta),
}
class MemoryDiscovery:
"""Episodic 记忆索引加载与查询"""
def __init__(self, memory_root: Path) -> None:
self.memory_root = memory_root
self.episodic_root = memory_root / "episodic"
self.index_path = self.episodic_root / "index.yml"
self.events: List[MemoryEvent] = []
self.refresh()
# ------------------------------------------------------------------
# 加载索引
# ------------------------------------------------------------------
def refresh(self) -> None:
"""重新加载索引文件。"""
self.events = self._load_events()
def _load_events(self) -> List[MemoryEvent]:
"""从 episodic 目录扫描 Markdown 事件文件并解析元信息。"""
if not self.episodic_root.exists():
return []
events: List[MemoryEvent] = []
for md_path in self.episodic_root.rglob("*.md"):
event = self._parse_event_file(md_path)
if event is not None:
events.append(event)
return events
def _parse_event_file(self, path: Path) -> Optional[MemoryEvent]:
"""解析单个事件 Markdown 文件。
协议:
- 可选顶部 YAML front matter: `--- ... ---`
- 正文中可使用:
- `# 标题` 作为事件标题
- `## 时间` 下第一行非空文本作为时间
- `## 标签` 下第一行非空文本作为标签列表
"""
try:
text = path.read_text(encoding="utf-8")
except Exception:
return None
lines = text.splitlines()
front_matter, body_lines = self._parse_front_matter(lines)
stem = path.stem
# 基础字段
id_value = str(front_matter.get("id") or stem)
type_value = str(front_matter.get("type") or "event")
# level: 优先 front matter其次目录结构推断
level_value = str(front_matter.get("level") or self._infer_level_from_path(path))
# 标题:优先 front matter.title其次正文第一个 '# ' 标题
title = front_matter.get("title") or self._extract_title_from_body(body_lines) or stem
# 标签:支持 front matter.tags 或 '## 标签' 段
tags = front_matter.get("tags")
if isinstance(tags, str):
tags = [t.strip() for t in re.split(r"[,\s]+", tags) if t.strip()]
elif isinstance(tags, list):
tags = [str(t) for t in tags]
else:
tags = self._extract_tags_from_body(body_lines)
# 时间front matter.timestamp/time → 正文 '## 时间' → 文件名/mtime 兜底
ts_str = front_matter.get("timestamp") or front_matter.get("time")
timestamp: Optional[dt.datetime] = None
if isinstance(ts_str, str):
timestamp = self._parse_datetime(ts_str)
if timestamp is None:
body_time = self._extract_time_from_body(body_lines)
if body_time:
timestamp = self._parse_datetime(body_time)
if timestamp is None:
timestamp = self._infer_datetime_from_filename_or_mtime(path)
if timestamp is None:
# 无法推断时间的事件对查询意义有限,忽略该文件
return None
# date_bucket: 优先 front matter.date_bucket其次目录结构 / 时间推断
date_bucket = front_matter.get("date_bucket") or self._infer_date_bucket(path, timestamp)
related = front_matter.get("related") or []
if isinstance(related, str):
related = [related]
elif not isinstance(related, list):
related = []
# meta: 保留所有未被提升为显式字段的 front matter 信息
meta: Dict[str, Any] = dict(front_matter)
for k in [
"id",
"type",
"level",
"title",
"tags",
"timestamp",
"time",
"date_bucket",
"related",
]:
meta.pop(k, None)
return MemoryEvent(
id=id_value,
type=type_value,
level=level_value,
timestamp=timestamp,
date_bucket=str(date_bucket),
path=path,
title=str(title),
tags=list(tags or []),
related=list(related),
meta=meta,
)
def _parse_front_matter(self, lines: List[str]):
"""解析 YAML front matter如果不存在则返回空字典和原始行。"""
if not lines:
return {}, []
if lines[0].strip() != "---":
return {}, lines
for i in range(1, len(lines)):
if lines[i].strip() == "---":
fm_text = "\n".join(lines[1:i])
try:
data = yaml.safe_load(fm_text) or {}
except Exception:
data = {}
return data, lines[i + 1 :]
# 未找到结束分隔符,视为无 front matter
return {}, lines
@staticmethod
def _extract_title_from_body(body_lines: List[str]) -> Optional[str]:
for line in body_lines:
s = line.strip()
if s.startswith("# "):
return s[2:].strip()
return None
@staticmethod
def _extract_time_from_body(body_lines: List[str]) -> Optional[str]:
for i, line in enumerate(body_lines):
if line.strip().startswith("## 时间"):
for j in range(i + 1, len(body_lines)):
value = body_lines[j].strip()
if value:
return value
break
return None
@staticmethod
def _extract_tags_from_body(body_lines: List[str]) -> List[str]:
for i, line in enumerate(body_lines):
if line.strip().startswith("## 标签"):
for j in range(i + 1, len(body_lines)):
raw = body_lines[j].strip()
if not raw:
continue
parts = [p.strip() for p in re.split(r"[,\s]+", raw) if p.strip()]
return parts
break
return []
def _infer_level_from_path(self, path: Path) -> str:
"""根据相对路径推断级别: year/month/day/event。"""
try:
rel = path.relative_to(self.episodic_root)
except ValueError:
return "event"
parts = rel.parts
if len(parts) >= 3 and parts[0].isdigit() and parts[1].isdigit() and parts[2].isdigit():
return "day"
if len(parts) >= 2 and parts[0].isdigit() and parts[1].isdigit():
return "month"
if len(parts) >= 1 and parts[0].isdigit():
return "year"
return "event"
def _infer_date_bucket(self, path: Path, ts: dt.datetime) -> str:
"""推断 date_bucket例如 "2025/11/14""""
try:
rel = path.relative_to(self.episodic_root)
parts = rel.parts
if len(parts) >= 3 and parts[0].isdigit() and parts[1].isdigit() and parts[2].isdigit():
return f"{parts[0]}/{parts[1]}/{parts[2]}"
except ValueError:
pass
return ts.date().isoformat()
def _infer_datetime_from_filename_or_mtime(self, path: Path) -> Optional[dt.datetime]:
"""从文件名 (YYYYMMDD-HHMM) 或 mtime 推断时间。"""
m = re.match(r"(\d{8})-(\d{4})", path.stem)
if m:
date_str, hm = m.groups()
try:
return dt.datetime.strptime(date_str + hm, "%Y%m%d%H%M")
except Exception:
pass
m2 = re.match(r"(\d{4})(\d{2})(\d{2})", path.stem)
if m2:
y, mth, d = m2.groups()
try:
return dt.datetime.strptime(f"{y}{mth}{d}", "%Y%m%d")
except Exception:
pass
try:
return dt.datetime.fromtimestamp(path.stat().st_mtime)
except Exception:
return None
# ------------------------------------------------------------------
# SQL 风格查询接口
# ------------------------------------------------------------------
def query(
self,
where: Optional[str] = None,
order_by: Optional[str] = None,
limit: Optional[int] = None,
offset: int = 0,
) -> List[MemoryEvent]:
"""基于 SQL 风格参数查询事件列表。"""
events: List[MemoryEvent] = list(self.events)
if where:
events = list(self._apply_where(events, where))
if order_by:
events = self._apply_order_by(events, order_by)
if offset:
events = events[offset:]
if limit is not None:
events = events[:limit]
return events
def _apply_where(
self, events: Iterable[MemoryEvent], where: str
) -> Iterable[MemoryEvent]:
"""简易 WHERE 解析,仅支持 AND运算符子集。
支持的形式:
- field = 'value' / != / >= / <=
- tags CONTAINS 'tag'
- 通过 AND 连接多个条件(不支持 OR / 括号)
"""
conditions = [part.strip() for part in re.split(r"\s+AND\s+", where, flags=re.I) if part.strip()]
def match(event: MemoryEvent) -> bool:
for cond in conditions:
if not self._eval_condition(event, cond):
return False
return True
return (e for e in events if match(e))
def _eval_condition(self, event: MemoryEvent, cond: str) -> bool:
# tags CONTAINS 'tag'
if re.search(r"\bCONTAINS\b", cond, flags=re.I):
left, right = re.split(r"\bCONTAINS\b", cond, maxsplit=1, flags=re.I)
field = left.strip()
value = self._strip_quotes(right.strip())
if field.lower() != "tags":
return False
return value in (event.tags or [])
# field op value
m = re.match(r"^(\w+)\s*(=|!=|>=|<=)\s*(.+)$", cond)
if not m:
return False
field, op, raw_value = m.groups()
field = field.strip().lower()
value = self._strip_quotes(raw_value.strip())
# 取事件属性或 meta 字段
lhs: Any
if field == "id":
lhs = event.id
elif field == "type":
lhs = event.type
elif field == "level":
lhs = event.level
elif field == "title":
lhs = event.title
elif field == "date":
lhs = event.date
elif field == "timestamp":
lhs = event.timestamp
else:
if field in event.meta:
lhs = event.meta[field]
else:
# 未知字段直接返回 False避免误匹配
return False
# 时间 / 日期字段支持 >= <=
if isinstance(lhs, dt.datetime):
rhs = self._parse_datetime(value)
if rhs is None:
return False
else:
rhs = value
try:
if op == "=":
return lhs == rhs
if op == "!=":
return lhs != rhs
if op == ">=":
return lhs >= rhs
if op == "<=":
return lhs <= rhs
except TypeError:
return False
return False
@staticmethod
def _strip_quotes(text: str) -> str:
if (text.startswith("'") and text.endswith("'")) or (
text.startswith('"') and text.endswith('"')
):
return text[1:-1]
return text
@staticmethod
def _parse_datetime(value: str) -> Optional[dt.datetime]:
# 支持 "YYYY-MM-DD" 或 ISO8601 字符串
try:
if len(value) == 10:
return dt.datetime.fromisoformat(value + "T00:00:00")
return dt.datetime.fromisoformat(value)
except Exception:
return None
# ------------------------------------------------------------------
# 格式化输出
# ------------------------------------------------------------------
def format_events(
self,
events: List[MemoryEvent],
select: Optional[List[str]] = None,
format_type: str = "table",
) -> str:
if select is None or not select:
select = ["id", "timestamp", "title"]
rows = []
for ev in events:
d = ev.to_dict()
rows.append({field: d.get(field) for field in select})
if format_type == "json":
return json.dumps(rows, ensure_ascii=False, indent=2)
# table 格式
return self._format_table(rows, select)
@staticmethod
def _format_table(rows: List[Dict[str, Any]], headers: List[str]) -> str:
if not rows:
return "(no events)"
# 计算列宽
widths: Dict[str, int] = {}
for h in headers:
widths[h] = max(len(h), *(len(str(row.get(h, ""))) for row in rows))
def fmt_row(row: Dict[str, Any]) -> str:
return " ".join(str(row.get(h, "")).ljust(widths[h]) for h in headers)
header_line = " ".join(h.ljust(widths[h]) for h in headers)
sep_line = " ".join("-" * widths[h] for h in headers)
data_lines = [fmt_row(r) for r in rows]
return "\n".join([header_line, sep_line, *data_lines])