170 lines
6.5 KiB
Python
170 lines
6.5 KiB
Python
from pathlib import Path
|
|
|
|
from memabra.persistence import PersistenceStore
|
|
from memabra.artifact_index import ArtifactIndex
|
|
|
|
|
|
def _make_trajectory(
|
|
trajectory_id: str,
|
|
*,
|
|
status: str = "success",
|
|
decision_type: str = "direct_answer",
|
|
channel: str = "local",
|
|
reward_total: float = 1.0,
|
|
latency_ms: int = 100,
|
|
tool_errors: int = 0,
|
|
user_corrections: int = 0,
|
|
input_text: str = "Hello",
|
|
created_at: str = "2026-01-15T10:00:00Z",
|
|
):
|
|
return {
|
|
"trajectory_id": trajectory_id,
|
|
"task": {
|
|
"task_id": f"task-{trajectory_id}",
|
|
"input": input_text,
|
|
"channel": channel,
|
|
"created_at": created_at,
|
|
"user_id": None,
|
|
},
|
|
"context_snapshot": {"conversation_summary": "", "environment_summary": "", "recent_failures": []},
|
|
"candidate_sets": {"memory": [], "skill": [], "tool": []},
|
|
"decisions": [
|
|
{
|
|
"step": 1,
|
|
"decision_type": decision_type,
|
|
"selected_ids": [],
|
|
"selected_payloads": [],
|
|
"rejected_ids": [],
|
|
"rationale": "",
|
|
"estimated_cost": 0.0,
|
|
}
|
|
],
|
|
"events": [],
|
|
"outcome": {
|
|
"status": status,
|
|
"steps": 1,
|
|
"latency_ms": latency_ms,
|
|
"user_corrections": user_corrections,
|
|
"tool_errors": tool_errors,
|
|
"notes": None,
|
|
},
|
|
"reward": {
|
|
"total": reward_total,
|
|
"components": {
|
|
"task_success": 1.0 if status == "success" else 0.0,
|
|
"retrieval_hit": 0.0,
|
|
"tool_error": 0.1 * tool_errors,
|
|
"user_correction": 0.1 * user_corrections,
|
|
"latency": 0.0,
|
|
"context_cost": 0.0,
|
|
"useful_reuse": 0.0,
|
|
},
|
|
},
|
|
}
|
|
|
|
|
|
def test_artifact_index_lists_all_trajectories(tmp_path: Path):
    """An unfiltered query returns every saved trajectory, whatever its status."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    for traj_id, outcome_status in (("traj-1", "success"), ("traj-2", "failure")):
        store.save_trajectory(_make_trajectory(traj_id, status=outcome_status))

    everything = ArtifactIndex(persistence_store=store).query()

    assert len(everything) == 2
    assert {row["trajectory_id"] for row in everything} == {"traj-1", "traj-2"}
|
|
|
|
|
|
def test_artifact_index_filters_by_status(tmp_path: Path):
    """The ``status`` filter selects only trajectories with that exact status.

    A "partial_success" trajectory is seeded so the "success" query must not
    pick it up.
    """
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    seeded = [
        ("traj-1", "success"),
        ("traj-2", "failure"),
        ("traj-3", "partial_success"),
    ]
    for traj_id, outcome_status in seeded:
        store.save_trajectory(_make_trajectory(traj_id, status=outcome_status))

    index = ArtifactIndex(persistence_store=store)
    succeeded = index.query(status="success")
    failed = index.query(status="failure")

    assert [row["trajectory_id"] for row in succeeded] == ["traj-1"]
    assert [row["trajectory_id"] for row in failed] == ["traj-2"]
|
|
|
|
|
|
def test_artifact_index_filters_by_reward_range(tmp_path: Path):
    """``min_reward`` and ``max_reward`` each bound ``reward.total`` on one side."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    rewards = {"traj-1": 0.9, "traj-2": 0.5, "traj-3": -0.2}
    for traj_id, reward in rewards.items():
        store.save_trajectory(_make_trajectory(traj_id, reward_total=reward))

    index = ArtifactIndex(persistence_store=store)
    above_floor = index.query(min_reward=0.6)
    under_cap = index.query(max_reward=0.0)

    assert [row["trajectory_id"] for row in above_floor] == ["traj-1"]
    assert [row["trajectory_id"] for row in under_cap] == ["traj-3"]
|
|
|
|
|
|
def test_artifact_index_filters_by_decision_type_and_channel(tmp_path: Path):
    """``decision_type`` and ``channel`` filters each isolate the matching trajectory."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    store.save_trajectory(
        _make_trajectory("traj-1", decision_type="direct_answer", channel="local")
    )
    store.save_trajectory(
        _make_trajectory("traj-2", decision_type="call_tool", channel="telegram")
    )

    index = ArtifactIndex(persistence_store=store)
    tool_calls = index.query(decision_type="call_tool")
    via_telegram = index.query(channel="telegram")

    # traj-2 is the only trajectory that satisfies either filter.
    for matches in (tool_calls, via_telegram):
        assert [row["trajectory_id"] for row in matches] == ["traj-2"]
|
|
|
|
|
|
def test_artifact_index_filters_by_tool_errors_and_user_corrections(tmp_path: Path):
    """``min_tool_errors`` / ``min_user_corrections`` keep only trajectories meeting the floor."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    for traj_id, errors, corrections in (("traj-1", 0, 0), ("traj-2", 2, 1)):
        store.save_trajectory(
            _make_trajectory(traj_id, tool_errors=errors, user_corrections=corrections)
        )

    index = ArtifactIndex(persistence_store=store)
    errored = index.query(min_tool_errors=1)
    corrected = index.query(min_user_corrections=1)

    assert [row["trajectory_id"] for row in errored] == ["traj-2"]
    assert [row["trajectory_id"] for row in corrected] == ["traj-2"]
|
|
|
|
|
|
def test_artifact_index_filters_by_input_text(tmp_path: Path):
    """``input_contains`` matches a substring of the task input regardless of case."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    inputs = {"traj-1": "Deploy the service", "traj-2": "Check status"}
    for traj_id, text in inputs.items():
        store.save_trajectory(_make_trajectory(traj_id, input_text=text))

    index = ArtifactIndex(persistence_store=store)
    # Lower-case needle vs. capitalised input, and upper-case needle vs. lower-case input.
    deploy_hits = index.query(input_contains="deploy")
    status_hits = index.query(input_contains="STATUS")

    assert [row["trajectory_id"] for row in deploy_hits] == ["traj-1"]
    assert [row["trajectory_id"] for row in status_hits] == ["traj-2"]
|
|
|
|
|
|
def test_artifact_index_slice_dataset_returns_ids(tmp_path: Path):
    """``slice_dataset`` combines filters and returns the matching trajectory ids."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    seeded = (
        ("traj-1", "success", 0.9),
        ("traj-2", "failure", -0.1),
        ("traj-3", "success", 0.95),
    )
    for traj_id, outcome_status, reward in seeded:
        store.save_trajectory(
            _make_trajectory(traj_id, status=outcome_status, reward_total=reward)
        )

    index = ArtifactIndex(persistence_store=store)
    selected = index.slice_dataset(status="success", min_reward=0.8)

    # NOTE(review): this is an ordered-list comparison, so it relies on
    # slice_dataset returning ids in a deterministic order — confirm that
    # ArtifactIndex sorts rather than following filesystem listing order.
    assert selected == ["traj-1", "traj-3"]
|
|
|
|
|
|
def test_artifact_index_refresh_picks_up_new_files(tmp_path: Path):
    """Trajectories saved after the index is built are visible once ``refresh()`` runs."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    store.save_trajectory(_make_trajectory("traj-1"))

    index = ArtifactIndex(persistence_store=store)
    initial = index.query()
    assert len(initial) == 1

    # Written after construction; the index must see it after refresh().
    store.save_trajectory(_make_trajectory("traj-2"))
    index.refresh()

    refreshed = index.query()
    assert len(refreshed) == 2
|