"""Tests for ``ArtifactIndex`` querying and slicing of persisted trajectories."""

from pathlib import Path

from memabra.persistence import PersistenceStore
from memabra.artifact_index import ArtifactIndex


def _make_trajectory(
    trajectory_id: str,
    *,
    status: str = "success",
    decision_type: str = "direct_answer",
    channel: str = "local",
    reward_total: float = 1.0,
    latency_ms: int = 100,
    tool_errors: int = 0,
    user_corrections: int = 0,
    input_text: str = "Hello",
    created_at: str = "2026-01-15T10:00:00Z",
) -> dict:
    """Build a minimal trajectory record for index tests.

    Only the fields controlled by the keyword arguments vary between calls;
    the context snapshot, candidate sets and events are left empty.

    Args:
        trajectory_id: Identifier of the trajectory; the task id is derived
            from it as ``"task-<trajectory_id>"``.
        status: Outcome status. ``reward.components.task_success`` is 1.0
            only when this is exactly ``"success"``, otherwise 0.0.
        decision_type: Type recorded on the single step-1 decision.
        channel: Channel recorded on the task.
        reward_total: Value stored as ``reward.total``.
        latency_ms: Value stored as ``outcome.latency_ms``.
        tool_errors: Stored on the outcome; also scales the ``tool_error``
            reward component by 0.1 per error.
        user_corrections: Stored on the outcome; also scales the
            ``user_correction`` reward component by 0.1 per correction.
        input_text: Task input text (used by ``input_contains`` queries).
        created_at: Task creation timestamp string.

    Returns:
        A fresh nested ``dict`` on every call, so callers may mutate it freely.
    """
    return {
        "trajectory_id": trajectory_id,
        "task": {
            "task_id": f"task-{trajectory_id}",
            "input": input_text,
            "channel": channel,
            "created_at": created_at,
            "user_id": None,
        },
        "context_snapshot": {"conversation_summary": "", "environment_summary": "", "recent_failures": []},
        "candidate_sets": {"memory": [], "skill": [], "tool": []},
        "decisions": [
            {
                "step": 1,
                "decision_type": decision_type,
                "selected_ids": [],
                "selected_payloads": [],
                "rejected_ids": [],
                "rationale": "",
                "estimated_cost": 0.0,
            }
        ],
        "events": [],
        "outcome": {
            "status": status,
            "steps": 1,
            "latency_ms": latency_ms,
            "user_corrections": user_corrections,
            "tool_errors": tool_errors,
            "notes": None,
        },
        "reward": {
            "total": reward_total,
            "components": {
                "task_success": 1.0 if status == "success" else 0.0,
                "retrieval_hit": 0.0,
                "tool_error": 0.1 * tool_errors,
                "user_correction": 0.1 * user_corrections,
                "latency": 0.0,
                "context_cost": 0.0,
                "useful_reuse": 0.0,
            },
        },
    }


def _store_with(tmp_path: Path, *trajectories: dict) -> PersistenceStore:
    """Create a store under ``tmp_path / "artifacts"`` and save each trajectory, in order."""
    store = PersistenceStore(base_dir=tmp_path / "artifacts")
    for trajectory in trajectories:
        store.save_trajectory(trajectory)
    return store


def test_artifact_index_lists_all_trajectories(tmp_path: Path):
    store = _store_with(
        tmp_path,
        _make_trajectory("traj-1", status="success"),
        _make_trajectory("traj-2", status="failure"),
    )
    index = ArtifactIndex(persistence_store=store)

    results = index.query()

    assert len(results) == 2
    assert {r["trajectory_id"] for r in results} == {"traj-1", "traj-2"}


def test_artifact_index_filters_by_status(tmp_path: Path):
    store = _store_with(
        tmp_path,
        _make_trajectory("traj-1", status="success"),
        _make_trajectory("traj-2", status="failure"),
        _make_trajectory("traj-3", status="partial_success"),
    )
    index = ArtifactIndex(persistence_store=store)

    successes = index.query(status="success")
    failures = index.query(status="failure")
    partials = index.query(status="partial_success")

    # "partial_success" contains "success"; only exact matches may be returned.
    assert len(successes) == 1
    assert successes[0]["trajectory_id"] == "traj-1"
    assert len(failures) == 1
    assert failures[0]["trajectory_id"] == "traj-2"
    # traj-3 was previously saved but never checked.
    assert len(partials) == 1
    assert partials[0]["trajectory_id"] == "traj-3"


def test_artifact_index_filters_by_reward_range(tmp_path: Path):
    store = _store_with(
        tmp_path,
        _make_trajectory("traj-1", reward_total=0.9),
        _make_trajectory("traj-2", reward_total=0.5),
        _make_trajectory("traj-3", reward_total=-0.2),
    )
    index = ArtifactIndex(persistence_store=store)

    high = index.query(min_reward=0.6)
    low = index.query(max_reward=0.0)

    assert len(high) == 1 and high[0]["trajectory_id"] == "traj-1"
    assert len(low) == 1 and low[0]["trajectory_id"] == "traj-3"


def test_artifact_index_filters_by_decision_type_and_channel(tmp_path: Path):
    store = _store_with(
        tmp_path,
        _make_trajectory("traj-1", decision_type="direct_answer", channel="local"),
        _make_trajectory("traj-2", decision_type="call_tool", channel="telegram"),
    )
    index = ArtifactIndex(persistence_store=store)

    tools = index.query(decision_type="call_tool")
    telegram = index.query(channel="telegram")

    assert len(tools) == 1 and tools[0]["trajectory_id"] == "traj-2"
    assert len(telegram) == 1 and telegram[0]["trajectory_id"] == "traj-2"


def test_artifact_index_filters_by_tool_errors_and_user_corrections(tmp_path: Path):
    store = _store_with(
        tmp_path,
        _make_trajectory("traj-1", tool_errors=0, user_corrections=0),
        _make_trajectory("traj-2", tool_errors=2, user_corrections=1),
    )
    index = ArtifactIndex(persistence_store=store)

    with_errors = index.query(min_tool_errors=1)
    with_corrections = index.query(min_user_corrections=1)

    assert len(with_errors) == 1 and with_errors[0]["trajectory_id"] == "traj-2"
    assert len(with_corrections) == 1 and with_corrections[0]["trajectory_id"] == "traj-2"


def test_artifact_index_filters_by_input_text(tmp_path: Path):
    store = _store_with(
        tmp_path,
        _make_trajectory("traj-1", input_text="Deploy the service"),
        _make_trajectory("traj-2", input_text="Check status"),
    )
    index = ArtifactIndex(persistence_store=store)

    # Needles differ in case from the stored text: matching is case-insensitive.
    deploy_matches = index.query(input_contains="deploy")
    status_matches = index.query(input_contains="STATUS")

    assert len(deploy_matches) == 1 and deploy_matches[0]["trajectory_id"] == "traj-1"
    assert len(status_matches) == 1 and status_matches[0]["trajectory_id"] == "traj-2"


def test_artifact_index_slice_dataset_returns_ids(tmp_path: Path):
    store = _store_with(
        tmp_path,
        _make_trajectory("traj-1", status="success", reward_total=0.9),
        _make_trajectory("traj-2", status="failure", reward_total=-0.1),
        _make_trajectory("traj-3", status="success", reward_total=0.95),
    )
    index = ArtifactIndex(persistence_store=store)

    slice_ids = index.slice_dataset(status="success", min_reward=0.8)

    assert slice_ids == ["traj-1", "traj-3"]


def test_artifact_index_refresh_picks_up_new_files(tmp_path: Path):
    store = _store_with(tmp_path, _make_trajectory("traj-1"))
    index = ArtifactIndex(persistence_store=store)
    assert len(index.query()) == 1

    # Saved after the index was built: only visible once refresh() runs.
    store.save_trajectory(_make_trajectory("traj-2"))
    index.refresh()

    assert len(index.query()) == 2