Files
gh-dwsy-ai-runtime-ai-runti…/skills/memory/memory_discovery.py
2025-11-29 18:24:34 +08:00

464 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
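"""Memory discovery and query utilities for AI Runtime.

- Scans the episodic memory directory (e.g. `.ai-runtime/memory/episodic/`)
  recursively for Markdown event files (`*.md`) and extracts their metadata
  from YAML front matter and from the `# title`, `## 时间` (time) and
  `## 标签` (tags) sections of the body. The `index.yml` path is recorded on
  the instance, but events are discovered by scanning the Markdown files.
- Provides an SQL-style (WHERE / ORDER BY / LIMIT / OFFSET) query interface.
- Formats query results as a plain-text table or as JSON.

Dependency: PyYAML (already a core dependency of the project).
"""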
from __future__ import annotations
import datetime as dt
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
import yaml
@dataclass
class MemoryEvent:
    """Index entry describing a single episodic memory event.

    ``meta`` holds the front-matter keys that were not promoted to one of the
    explicit fields; its values are whatever YAML produced (strings, numbers,
    dates, nested lists/dicts, ...).
    """

    id: str  # event identifier
    type: str  # event type, e.g. "event"
    level: str  # granularity label such as "year" / "month" / "day" / "event"
    timestamp: dt.datetime  # when the event happened
    date_bucket: str  # grouping key, e.g. "2025/11/14"
    path: Path  # source Markdown file
    title: str = ""
    tags: List[str] = field(default_factory=list)
    related: List[str] = field(default_factory=list)
    meta: Dict[str, Any] = field(default_factory=dict)

    @property
    def date(self) -> str:
        """Return the event date as ``YYYY-MM-DD`` (usable as ``date`` in WHERE clauses)."""
        return self.timestamp.date().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict view of the event.

        ``meta`` is converted recursively: YAML front matter yields ``date`` /
        ``datetime`` objects for unquoted dates, which ``json.dumps`` cannot
        encode, so they are rendered as ISO-8601 strings. ``self.meta`` itself
        is not modified.
        """
        return {
            "id": self.id,
            "type": self.type,
            "level": self.level,
            "timestamp": self.timestamp.isoformat(),
            "date": self.date,
            "date_bucket": self.date_bucket,
            "path": str(self.path),
            "title": self.title,
            "tags": list(self.tags),
            "related": list(self.related),
            "meta": self._to_jsonable(self.meta),
        }

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """Recursively convert *value* into types that ``json.dumps`` accepts.

        - Primitives (str/int/float/bool/None) pass through unchanged.
        - Dates, datetimes and times become ISO-8601 strings.
        - Mappings keep JSON-compatible keys and ``str()`` any other key.
        - Lists, tuples and sets become lists.
        - Anything else falls back to ``str(value)``.
        """
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, (dt.date, dt.time)):
            return value.isoformat()
        if isinstance(value, dict):
            converted: Dict[Any, Any] = {}
            for key, item in value.items():
                if not (key is None or isinstance(key, (str, int, float, bool))):
                    key = str(key)
                converted[key] = MemoryEvent._to_jsonable(item)
            return converted
        if isinstance(value, (list, tuple, set, frozenset)):
            return [MemoryEvent._to_jsonable(item) for item in value]
        return str(value)
class MemoryDiscovery:
    """Discover episodic memory events on disk and query them SQL-style.

    Events are found by recursively scanning ``<memory_root>/episodic`` for
    Markdown (``*.md``) files; ``_parse_event_file`` documents the per-file
    format. Queries support a small WHERE / ORDER BY / LIMIT / OFFSET subset.
    """

    # Front-matter keys promoted to explicit MemoryEvent fields (excluded from ``meta``).
    _PROMOTED_KEYS = frozenset(
        {"id", "type", "level", "title", "tags", "timestamp", "time", "date_bucket", "related"}
    )
    # Event attributes addressable by name in WHERE / ORDER BY; other names are
    # looked up in ``event.meta``.
    _EVENT_FIELDS = ("id", "type", "level", "title", "date", "timestamp", "date_bucket")
    # Tag separators: ASCII comma, full-width comma (U+FF0C), ideographic comma
    # (U+3001) and whitespace.
    _TAG_SEPARATOR_RE = re.compile(r"[,\s\uff0c\u3001]+")
    # A Markdown ATX heading ("# x", "## x", ...), which ends a body section.
    _HEADING_RE = re.compile(r"#{1,6}\s")
    # Quoted literals are matched (and skipped) so that AND inside quotes is
    # not treated as a condition separator.
    _CONDITION_SPLIT_RE = re.compile(r"'[^']*'|\"[^\"]*\"|\s+AND\s+", re.IGNORECASE)
    _CONTAINS_RE = re.compile(r"^(\w+)\s+CONTAINS\b\s*(.+)$", re.IGNORECASE)
    # Two-character operators come first so ">=" is never read as ">" + "=...".
    _COMPARISON_RE = re.compile(r"^(\w+)\s*(==|!=|>=|<=|=|>|<)\s*(.+)$")
    _COMPARATORS = {
        "=": lambda a, b: a == b,
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        ">=": lambda a, b: a >= b,
        "<=": lambda a, b: a <= b,
        ">": lambda a, b: a > b,
        "<": lambda a, b: a < b,
    }

    def __init__(self, memory_root: Path) -> None:
        """Bind to *memory_root* (a ``Path`` or path string) and load all events.

        Args:
            memory_root: directory containing the ``episodic`` subdirectory.
        """
        self.memory_root = Path(memory_root)
        self.episodic_root = self.memory_root / "episodic"
        # Not read by this class: events come from scanning the Markdown files.
        self.index_path = self.episodic_root / "index.yml"
        self.events: List[MemoryEvent] = []
        self.refresh()

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def refresh(self) -> None:
        """Re-scan the episodic directory and replace ``self.events``."""
        self.events = self._load_events()

    def _load_events(self) -> List[MemoryEvent]:
        """Parse every ``*.md`` file below the episodic root.

        Files are visited in sorted path order so that unordered query results
        are deterministic across filesystems; unparsable files are skipped.
        """
        if not self.episodic_root.exists():
            return []
        events: List[MemoryEvent] = []
        for md_path in sorted(self.episodic_root.rglob("*.md")):
            event = self._parse_event_file(md_path)
            if event is not None:
                events.append(event)
        return events

    def _parse_event_file(self, path: Path) -> Optional[MemoryEvent]:
        """Parse one Markdown event file into a ``MemoryEvent``.

        Expected layout:
          - optional YAML front matter between two ``---`` lines at the top;
          - in the body, the first ``# `` heading is the title, the first
            non-empty line under ``## 时间`` ("time") is the timestamp and the
            first non-empty line under ``## 标签`` ("tags") is the tag list.

        Front-matter values win over body values. Returns ``None`` when the
        file cannot be read as UTF-8 or no timestamp can be determined.
        """
        try:
            text = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            return None

        front_matter, body_lines = self._parse_front_matter(text.splitlines())
        stem = path.stem

        id_value = str(front_matter.get("id") or stem)
        type_value = str(front_matter.get("type") or "event")
        # level: front matter first, otherwise inferred from the directory layout.
        level_value = str(front_matter.get("level") or self._infer_level_from_path(path))
        # title: front matter, then the first "# " heading, then the file stem.
        title = front_matter.get("title") or self._extract_title_from_body(body_lines) or stem

        tags = front_matter.get("tags")
        if isinstance(tags, str):
            tags = self._split_tags(tags)
        elif isinstance(tags, list):
            tags = [str(t) for t in tags]
        else:
            tags = self._extract_tags_from_body(body_lines)

        timestamp = self._resolve_timestamp(front_matter, body_lines, path)
        if timestamp is None:
            # An event without any usable time cannot take part in date queries.
            return None

        date_bucket = front_matter.get("date_bucket") or self._infer_date_bucket(path, timestamp)

        related = front_matter.get("related") or []
        if isinstance(related, str):
            related = [related]
        elif not isinstance(related, list):
            related = []

        # meta keeps every front-matter key that was not promoted to a field above.
        meta = {k: v for k, v in front_matter.items() if k not in self._PROMOTED_KEYS}

        return MemoryEvent(
            id=id_value,
            type=type_value,
            level=level_value,
            timestamp=timestamp,
            date_bucket=str(date_bucket),
            path=path,
            title=str(title),
            tags=list(tags or []),
            related=list(related),
            meta=meta,
        )

    def _resolve_timestamp(
        self, front_matter: Dict[str, Any], body_lines: List[str], path: Path
    ) -> Optional[dt.datetime]:
        """Pick the event time.

        The time is taken from the first available source, in this order:
        front matter ``timestamp``/``time``, the ``## 时间`` body section, then
        the file name or mtime.
        """
        timestamp = self._coerce_datetime(
            front_matter.get("timestamp") or front_matter.get("time")
        )
        if timestamp is None:
            body_time = self._extract_time_from_body(body_lines)
            if body_time:
                timestamp = self._parse_datetime(body_time)
        if timestamp is None:
            timestamp = self._infer_datetime_from_filename_or_mtime(path)
        return timestamp

    @classmethod
    def _coerce_datetime(cls, value: Any) -> Optional[dt.datetime]:
        """Convert a front-matter time value to ``datetime`` (``None`` if impossible).

        PyYAML turns unquoted ISO dates/timestamps into ``date``/``datetime``
        objects, so those are accepted directly; strings are parsed.
        """
        if isinstance(value, dt.datetime):
            return value
        if isinstance(value, dt.date):  # checked after datetime (its subclass)
            return dt.datetime.combine(value, dt.time())
        if isinstance(value, str):
            return cls._parse_datetime(value)
        return None

    def _parse_front_matter(self, lines: List[str]):
        """Split optional YAML front matter off the top of *lines*.

        Returns ``(front_matter, body_lines)``. ``front_matter`` is always a
        dict. When the opening ``---`` is missing or never closed, the result
        is ``({}, lines)``. When the YAML is invalid, or its top level is not a
        mapping, the result is ``{}`` plus the lines after the closing ``---``.
        """
        if not lines or lines[0].strip() != "---":
            return {}, lines
        for end in range(1, len(lines)):
            if lines[end].strip() == "---":
                try:
                    data = yaml.safe_load("\n".join(lines[1:end]))
                except Exception:
                    # Best effort: malformed front matter (PyYAML can raise
                    # YAMLError, or ValueError for impossible dates) must not
                    # abort loading every other event.
                    data = None
                return (data if isinstance(data, dict) else {}), lines[end + 1 :]
        return {}, lines

    @staticmethod
    def _extract_title_from_body(body_lines: List[str]) -> Optional[str]:
        """Return the text of the first level-1 ``# `` heading, or ``None``."""
        for line in body_lines:
            stripped = line.strip()
            if stripped.startswith("# "):
                return stripped[2:].strip()
        return None

    @classmethod
    def _section_first_line(cls, body_lines: List[str], heading: str) -> Optional[str]:
        """Return the first non-empty line under the first *heading* line.

        The search stops at the next Markdown heading, so an empty section
        yields ``None`` instead of the following section's heading.
        """
        for index, line in enumerate(body_lines):
            if not line.strip().startswith(heading):
                continue
            for following in body_lines[index + 1 :]:
                value = following.strip()
                if not value:
                    continue
                return None if cls._HEADING_RE.match(value) else value
            return None
        return None

    @classmethod
    def _extract_time_from_body(cls, body_lines: List[str]) -> Optional[str]:
        """Return the raw time text from the ``## 时间`` section, if any."""
        return cls._section_first_line(body_lines, "## 时间")

    @classmethod
    def _extract_tags_from_body(cls, body_lines: List[str]) -> List[str]:
        """Return the tags listed on the first line of the ``## 标签`` section."""
        raw = cls._section_first_line(body_lines, "## 标签")
        return cls._split_tags(raw) if raw else []

    @classmethod
    def _split_tags(cls, raw: str) -> List[str]:
        """Split a tag string on commas (ASCII / ``，`` / ``、``) and whitespace."""
        return [part for part in cls._TAG_SEPARATOR_RE.split(raw.strip()) if part]

    def _infer_level_from_path(self, path: Path) -> str:
        """Infer the level from leading numeric directories under the episodic root.

        ``YYYY/MM/DD/...`` maps to "day", ``YYYY/MM/...`` to "month" and
        ``YYYY/...`` to "year". Anything else, or a path outside the root,
        maps to "event".
        """
        try:
            parts = path.relative_to(self.episodic_root).parts
        except ValueError:
            return "event"
        numeric = 0
        for part in parts[:3]:
            if not part.isdigit():
                break
            numeric += 1
        return ("event", "year", "month", "day")[numeric]

    def _infer_date_bucket(self, path: Path, ts: dt.datetime) -> str:
        """Infer the date bucket, e.g. ``"2025/11/14"``.

        Uses the ``YYYY/MM/DD`` directories under the episodic root when they
        exist, otherwise formats *ts* in the same ``YYYY/MM/DD`` shape so all
        buckets share one format.
        """
        try:
            parts = path.relative_to(self.episodic_root).parts
        except ValueError:
            parts = ()
        if len(parts) >= 3 and all(p.isdigit() for p in parts[:3]):
            return f"{parts[0]}/{parts[1]}/{parts[2]}"
        return f"{ts.year:04d}/{ts.month:02d}/{ts.day:02d}"

    def _infer_datetime_from_filename_or_mtime(self, path: Path) -> Optional[dt.datetime]:
        """Infer a time from the file name, falling back to the mtime.

        A ``YYYYMMDD-HHMM`` or ``YYYYMMDD`` prefix of the file name is tried
        first; otherwise the file's mtime is used (as a naive local time).
        """
        stem = path.stem
        for pattern, fmt in ((r"(\d{8})-(\d{4})", "%Y%m%d%H%M"), (r"(\d{8})", "%Y%m%d")):
            match = re.match(pattern, stem)
            if match:
                try:
                    return dt.datetime.strptime("".join(match.groups()), fmt)
                except ValueError:
                    pass  # digits that are not a valid date: try the next source
        try:
            return dt.datetime.fromtimestamp(path.stat().st_mtime)
        except (OSError, OverflowError, ValueError):
            return None

    # ------------------------------------------------------------------
    # SQL-style query interface
    # ------------------------------------------------------------------
    def query(
        self,
        where: Optional[str] = None,
        order_by: Optional[str] = None,
        limit: Optional[int] = None,
        offset: int = 0,
    ) -> List[MemoryEvent]:
        """Return events filtered, sorted and paginated SQL-style.

        Args:
            where: ``AND``-joined conditions (see ``_apply_where``).
            order_by: ``"field [ASC|DESC][, field [ASC|DESC] ...]"``.
            limit: maximum number of events (``None`` means no limit;
                negative values are treated as 0).
            offset: number of leading events to skip (negative treated as 0).
        """
        events: List[MemoryEvent] = list(self.events)
        if where:
            events = list(self._apply_where(events, where))
        if order_by:
            events = self._apply_order_by(events, order_by)
        if offset and offset > 0:
            events = events[offset:]
        if limit is not None:
            events = events[: max(limit, 0)]
        return events

    def _apply_where(
        self, events: Iterable[MemoryEvent], where: str
    ) -> Iterable[MemoryEvent]:
        """Lazily filter *events* by a minimal WHERE clause.

        Supported:
          - ``field <op> value`` with op one of ``= == != >= <= > <``;
          - ``tags CONTAINS 'tag'``;
          - several conditions joined by ``AND`` (case-insensitive). An
            ``AND`` inside a quoted value is not a separator.
        ``OR`` and parentheses are not supported.
        """
        conditions = self._split_conditions(where)
        return (e for e in events if all(self._eval_condition(e, c) for c in conditions))

    @classmethod
    def _split_conditions(cls, where: str) -> List[str]:
        """Split *where* on ``AND`` separators that are outside quotes."""
        pieces: List[str] = []
        start = 0
        for match in cls._CONDITION_SPLIT_RE.finditer(where):
            if match.group(0)[0] in "'\"":
                continue  # a quoted literal: leave it intact
            pieces.append(where[start : match.start()])
            start = match.end()
        pieces.append(where[start:])
        return [p.strip() for p in pieces if p.strip()]

    def _apply_order_by(self, events: List[MemoryEvent], order_by: str) -> List[MemoryEvent]:
        """Sort *events* by ``"field [ASC|DESC][, ...]"`` (stable, multi-key).

        Fields resolve as in WHERE (event fields, then ``meta``). Events that
        lack a sort field are placed after the events that have it, whichever
        direction is requested. Unknown direction words are treated as ASC.
        """
        keys = []
        for clause in order_by.split(","):
            tokens = clause.split()
            if tokens:
                descending = len(tokens) > 1 and tokens[1].upper() == "DESC"
                keys.append((tokens[0].lower(), descending))

        ordered = list(events)
        # Apply keys from least to most significant; list.sort is stable
        # (also with reverse=True), which yields a correct multi-key order.
        for name, descending in reversed(keys):
            with_value = []
            without_value = []
            for event in ordered:
                found, value = self._field_value(event, name)
                if found and value is not None:
                    with_value.append((self._sort_key(value), event))
                else:
                    without_value.append(event)
            with_value.sort(key=lambda pair: pair[0], reverse=descending)
            ordered = [event for _, event in with_value] + without_value
        return ordered

    @staticmethod
    def _sort_key(value: Any) -> tuple:
        """Map a field value to a key comparable across events.

        Dates and datetimes sort first, then numbers, then everything else as
        text, so mixed-type fields never raise ``TypeError``. Naive datetimes
        are read as local time so they can be ordered against aware ones.
        """
        if isinstance(value, dt.datetime):
            return (0, value if value.tzinfo is not None else value.astimezone())
        if isinstance(value, dt.date):
            return (0, dt.datetime.combine(value, dt.time()).astimezone())
        if isinstance(value, (int, float)):
            return (1, float(value))
        return (2, str(value))

    @classmethod
    def _field_value(cls, event: MemoryEvent, name: str):
        """Return ``(found, value)`` for *name* from the event's fields, then its ``meta``."""
        if name in cls._EVENT_FIELDS:
            return True, getattr(event, name)
        if name in event.meta:
            return True, event.meta[name]
        return False, None

    def _eval_condition(self, event: MemoryEvent, cond: str) -> bool:
        """Evaluate one condition against *event*.

        Unparsable conditions, unknown fields and type mismatches evaluate to
        ``False`` rather than raising.
        """
        contains = self._CONTAINS_RE.match(cond)
        if contains:
            name, raw = contains.groups()
            if name.lower() != "tags":
                return False
            return self._strip_quotes(raw.strip()) in (event.tags or [])

        comparison = self._COMPARISON_RE.match(cond)
        if not comparison:
            return False
        name, op, raw = comparison.groups()
        found, lhs = self._field_value(event, name.strip().lower())
        if not found:
            return False  # unknown field: never match rather than guess
        try:
            rhs = self._coerce_operand(lhs, self._strip_quotes(raw.strip()))
            if rhs is None:
                return False  # date/time literal that could not be parsed
            return self._COMPARATORS[op](lhs, rhs)
        except (TypeError, ValueError, OverflowError, OSError):
            return False

    @classmethod
    def _coerce_operand(cls, lhs: Any, text: str) -> Any:
        """Convert the textual right-hand side towards the type of *lhs*.

        - datetime: parsed; ``None`` when unparsable. A naive value compared
          against an aware one is read as local time, so the two stay
          comparable.
        - date: parsed to a ``date``.
        - bool: "true"/"false" are converted.
        - int/float: converted to ``float``.
        - Otherwise, or when conversion fails, the text itself is returned.
        """
        if isinstance(lhs, dt.datetime):
            rhs = cls._parse_datetime(text)
            if rhs is not None and (lhs.tzinfo is None) != (rhs.tzinfo is None):
                if rhs.tzinfo is None:
                    rhs = rhs.astimezone()
                else:
                    rhs = rhs.astimezone().replace(tzinfo=None)
            return rhs
        if isinstance(lhs, dt.date):
            parsed = cls._parse_datetime(text)
            return parsed.date() if parsed is not None else text
        if isinstance(lhs, bool):  # before int: bool is an int subclass
            return {"true": True, "false": False}.get(text.lower(), text)
        if isinstance(lhs, (int, float)):
            try:
                return float(text)
            except ValueError:
                return text
        return text

    @staticmethod
    def _strip_quotes(text: str) -> str:
        """Remove one pair of matching surrounding single or double quotes."""
        if len(text) >= 2 and text[0] == text[-1] and text[0] in "'\"":
            return text[1:-1]
        return text

    @staticmethod
    def _parse_datetime(value: str) -> Optional[dt.datetime]:
        """Parse ``YYYY-MM-DD`` or an ISO-8601 datetime string; ``None`` if invalid.

        Surrounding whitespace is ignored and ``/`` is accepted as the date
        separator (``2025/11/14``), matching the ``date_bucket`` style.
        """
        if not isinstance(value, str):
            return None
        text = value.strip()
        for candidate in (text, text.replace("/", "-")):
            try:
                if len(candidate) == 10:
                    return dt.datetime.fromisoformat(candidate + "T00:00:00")
                return dt.datetime.fromisoformat(candidate)
            except ValueError:
                continue
        return None

    # ------------------------------------------------------------------
    # Output formatting
    # ------------------------------------------------------------------
    def format_events(
        self,
        events: List[MemoryEvent],
        select: Optional[List[str]] = None,
        format_type: str = "table",
    ) -> str:
        """Render *events* as a text table (default) or as JSON.

        Args:
            events: events to render, in order.
            select: keys of ``MemoryEvent.to_dict()`` to show; defaults to
                ``["id", "timestamp", "title"]``. Unknown keys render as ``None``.
            format_type: ``"json"`` for a JSON array; anything else gives a table.
        """
        columns = list(select) if select else ["id", "timestamp", "title"]
        rows = []
        for event in events:
            data = event.to_dict()
            rows.append({name: data.get(name) for name in columns})
        if format_type == "json":
            # default=str: front-matter meta may hold YAML dates or other
            # non-JSON values; render them as text instead of failing.
            return json.dumps(rows, ensure_ascii=False, indent=2, default=str)
        return self._format_table(rows, columns)

    @staticmethod
    def _display_width(text: str) -> int:
        """Return the terminal column width of *text* (East Asian wide/full-width = 2)."""
        import unicodedata

        return sum(2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1 for ch in text)

    @staticmethod
    def _format_table(rows: List[Dict[str, Any]], headers: List[str]) -> str:
        """Lay out *rows* as left-aligned columns under *headers*.

        Columns are padded by display width, so CJK text stays aligned.
        Returns ``"(no events)"`` when there are no rows.
        """
        if not rows:
            return "(no events)"
        width_of = MemoryDiscovery._display_width
        cells = [[str(row.get(h, "")) for h in headers] for row in rows]
        widths = [
            max([width_of(header)] + [width_of(line[i]) for line in cells])
            for i, header in enumerate(headers)
        ]

        def pad(text: str, width: int) -> str:
            return text + " " * (width - width_of(text))

        lines = [
            " ".join(pad(h, w) for h, w in zip(headers, widths)),
            " ".join("-" * w for w in widths),
        ]
        lines.extend(" ".join(pad(c, w) for c, w in zip(line, widths)) for line in cells)
        return "\n".join(lines)