Files
memabra/tests/test_artifact_index.py
2026-04-15 11:06:05 +08:00

170 lines
6.5 KiB
Python

from pathlib import Path
from memabra.persistence import PersistenceStore
from memabra.artifact_index import ArtifactIndex
def _make_trajectory(
trajectory_id: str,
*,
status: str = "success",
decision_type: str = "direct_answer",
channel: str = "local",
reward_total: float = 1.0,
latency_ms: int = 100,
tool_errors: int = 0,
user_corrections: int = 0,
input_text: str = "Hello",
created_at: str = "2026-01-15T10:00:00Z",
):
return {
"trajectory_id": trajectory_id,
"task": {
"task_id": f"task-{trajectory_id}",
"input": input_text,
"channel": channel,
"created_at": created_at,
"user_id": None,
},
"context_snapshot": {"conversation_summary": "", "environment_summary": "", "recent_failures": []},
"candidate_sets": {"memory": [], "skill": [], "tool": []},
"decisions": [
{
"step": 1,
"decision_type": decision_type,
"selected_ids": [],
"selected_payloads": [],
"rejected_ids": [],
"rationale": "",
"estimated_cost": 0.0,
}
],
"events": [],
"outcome": {
"status": status,
"steps": 1,
"latency_ms": latency_ms,
"user_corrections": user_corrections,
"tool_errors": tool_errors,
"notes": None,
},
"reward": {
"total": reward_total,
"components": {
"task_success": 1.0 if status == "success" else 0.0,
"retrieval_hit": 0.0,
"tool_error": 0.1 * tool_errors,
"user_correction": 0.1 * user_corrections,
"latency": 0.0,
"context_cost": 0.0,
"useful_reuse": 0.0,
},
},
}
def test_artifact_index_lists_all_trajectories(tmp_path: Path):
    """An unfiltered query returns every saved trajectory, regardless of status."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    for traj_id, outcome in (("traj-1", "success"), ("traj-2", "failure")):
        store.save_trajectory(_make_trajectory(traj_id, status=outcome))

    everything = ArtifactIndex(persistence_store=store).query()

    assert len(everything) == 2
    assert {row["trajectory_id"] for row in everything} == {"traj-1", "traj-2"}
def test_artifact_index_filters_by_status(tmp_path: Path):
    """Querying by ``status`` returns only trajectories with that exact status.

    ``partial_success`` contains ``success`` as a substring, so the
    single-hit assertion for ``status="success"`` also checks that the
    filter does not do substring matching. The ``partial_success``
    trajectory is queried directly as well, so every seeded status is
    exercised.
    """
    persistence = PersistenceStore(base_dir=tmp_path / "artifacts")
    persistence.save_trajectory(_make_trajectory("traj-1", status="success"))
    persistence.save_trajectory(_make_trajectory("traj-2", status="failure"))
    persistence.save_trajectory(_make_trajectory("traj-3", status="partial_success"))
    index = ArtifactIndex(persistence_store=persistence)
    successes = index.query(status="success")
    failures = index.query(status="failure")
    partials = index.query(status="partial_success")
    assert len(successes) == 1
    assert successes[0]["trajectory_id"] == "traj-1"
    assert len(failures) == 1
    assert failures[0]["trajectory_id"] == "traj-2"
    # traj-3 must be reachable through its own status value too.
    assert len(partials) == 1
    assert partials[0]["trajectory_id"] == "traj-3"
def test_artifact_index_filters_by_reward_range(tmp_path: Path):
    """``min_reward`` / ``max_reward`` select by reward, alone or as a range.

    Rewards are seeded through ``_make_trajectory(reward_total=...)``.
    Besides each bound on its own, the test now queries with both bounds
    at once, which is the "range" the test name refers to.
    """
    persistence = PersistenceStore(base_dir=tmp_path / "artifacts")
    persistence.save_trajectory(_make_trajectory("traj-1", reward_total=0.9))
    persistence.save_trajectory(_make_trajectory("traj-2", reward_total=0.5))
    persistence.save_trajectory(_make_trajectory("traj-3", reward_total=-0.2))
    index = ArtifactIndex(persistence_store=persistence)
    high = index.query(min_reward=0.6)
    low = index.query(max_reward=0.0)
    # No seeded reward sits exactly on 0.0 or 0.6, so this holds whether the
    # bounds are inclusive or exclusive.
    mid = index.query(min_reward=0.0, max_reward=0.6)
    assert len(high) == 1 and high[0]["trajectory_id"] == "traj-1"
    assert len(low) == 1 and low[0]["trajectory_id"] == "traj-3"
    assert len(mid) == 1 and mid[0]["trajectory_id"] == "traj-2"
def test_artifact_index_filters_by_decision_type_and_channel(tmp_path: Path):
    """``decision_type`` and ``channel`` filters each isolate the matching record."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    seeds = [
        ("traj-1", "direct_answer", "local"),
        ("traj-2", "call_tool", "telegram"),
    ]
    for traj_id, kind, chan in seeds:
        store.save_trajectory(
            _make_trajectory(traj_id, decision_type=kind, channel=chan)
        )
    index = ArtifactIndex(persistence_store=store)

    by_decision = index.query(decision_type="call_tool")
    by_channel = index.query(channel="telegram")

    # traj-2 is the only record that both called a tool and came via telegram.
    for hits in (by_decision, by_channel):
        assert len(hits) == 1
        assert hits[0]["trajectory_id"] == "traj-2"
def test_artifact_index_filters_by_tool_errors_and_user_corrections(tmp_path: Path):
    """``min_tool_errors`` / ``min_user_corrections`` keep only records meeting the minimum."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    for traj_id, errs, fixes in (("traj-1", 0, 0), ("traj-2", 2, 1)):
        store.save_trajectory(
            _make_trajectory(traj_id, tool_errors=errs, user_corrections=fixes)
        )
    index = ArtifactIndex(persistence_store=store)

    erroring = index.query(min_tool_errors=1)
    corrected = index.query(min_user_corrections=1)

    # Only traj-2 has a non-zero count for either metric.
    for hits in (erroring, corrected):
        assert len(hits) == 1
        assert hits[0]["trajectory_id"] == "traj-2"
def test_artifact_index_filters_by_input_text(tmp_path: Path):
    """``input_contains`` matches a substring of the task input, ignoring case."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    store.save_trajectory(_make_trajectory("traj-1", input_text="Deploy the service"))
    store.save_trajectory(_make_trajectory("traj-2", input_text="Check status"))
    index = ArtifactIndex(persistence_store=store)

    # Needles deliberately differ in case from the stored text.
    lowercase_hits = index.query(input_contains="deploy")
    uppercase_hits = index.query(input_contains="STATUS")

    assert len(lowercase_hits) == 1
    assert lowercase_hits[0]["trajectory_id"] == "traj-1"
    assert len(uppercase_hits) == 1
    assert uppercase_hits[0]["trajectory_id"] == "traj-2"
def test_artifact_index_slice_dataset_returns_ids(tmp_path: Path):
    """``slice_dataset`` applies combined filters and returns a list of ids.

    The expected ids are compared as an exact list, so the order shown is
    part of what this test checks.
    """
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    seeds = (
        ("traj-1", "success", 0.9),
        ("traj-2", "failure", -0.1),
        ("traj-3", "success", 0.95),
    )
    for traj_id, outcome, score in seeds:
        store.save_trajectory(
            _make_trajectory(traj_id, status=outcome, reward_total=score)
        )

    picked = ArtifactIndex(persistence_store=store).slice_dataset(
        status="success", min_reward=0.8
    )

    assert picked == ["traj-1", "traj-3"]
def test_artifact_index_refresh_picks_up_new_files(tmp_path: Path):
    """A trajectory saved after the index is built is visible once ``refresh()`` runs."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    store.save_trajectory(_make_trajectory("traj-1"))
    index = ArtifactIndex(persistence_store=store)

    def indexed_count():
        return len(index.query())

    assert indexed_count() == 1

    # Persist a second record through the store, then ask the index to refresh.
    store.save_trajectory(_make_trajectory("traj-2"))
    index.refresh()
    assert indexed_count() == 2