# memabra/tests/test_cli_workflow.py
# Last modified: 2026-04-15 11:06:05 +08:00 (575 lines, 20 KiB, Python)
from pathlib import Path
from memabra.cli import format_output, run_online_learning_workflow, run_wrapup_workflow
def test_run_wrapup_workflow_trains_evaluates_and_versions_router(tmp_path: Path):
    """Wrap-up workflow seeds trajectories, compares routers, and saves a version.

    Verifies the end-to-end artifact layout: at least 3 seeded trajectories,
    a baseline/challenger comparison, a saved version id, and a persisted
    ``router-versions/current.json`` pointer under the base directory.
    """
    result = run_wrapup_workflow(base_dir=tmp_path / "demo-artifacts")

    assert result["seed_summary"]["trajectories"] >= 3
    assert "baseline" in result["comparison"]
    assert "challenger" in result["comparison"]
    assert result["saved_version"]["version_id"]
    assert (tmp_path / "demo-artifacts" / "router-versions" / "current.json").exists()
def test_run_online_learning_workflow_runs_cycle_and_returns_report(tmp_path: Path):
    """A fresh online-learning run completes a cycle instead of skipping.

    Because the workflow seeds its own tasks, the first run should not skip,
    should promote, and should create the ``training-reports`` directory.
    """
    result = run_online_learning_workflow(base_dir=tmp_path / "demo-artifacts")

    assert "skipped" in result
    assert "report_id" in result
    # Since it seeds tasks, it should not skip.
    assert result["skipped"] is False
    assert result["promoted"] is True
    assert (tmp_path / "demo-artifacts" / "training-reports").exists()
def test_format_output_workflow_text_includes_decision_reason_and_dry_run():
    """Text rendering of a workflow result shows summary, metrics, deltas, and reasons."""
    payload = {
        "report_id": "report-123",
        "skipped": False,
        "promoted": False,
        "dry_run": True,
        "decision": {
            "accepted": False,
            "reasons": ["Reward delta too small", "Latency increased"],
            "metrics": {
                "reward_delta": -0.12,
                "error_rate_delta": 0.02,
                "latency_delta_ms": 12.5,
            },
        },
        "baseline_metrics": {
            "avg_reward": 1.0,
            "error_rate": 0.1,
            "avg_latency_ms": 120.0,
        },
        "challenger_metrics": {
            "avg_reward": 0.88,
            "error_rate": 0.12,
            "avg_latency_ms": 132.5,
        },
    }

    rendered = format_output(payload, output_format="text", mode="workflow")

    assert "Memabra online learning result" in rendered
    assert "Summary" in rendered
    assert "Report ID: report-123" in rendered
    assert "Skipped: no" in rendered
    assert "Promoted: no" in rendered
    assert "Dry run: yes" in rendered
    assert "Baseline" in rendered
    assert "Reward: 1.0000" in rendered
    assert "Error rate: 0.1000" in rendered
    assert "Latency (ms): 120.0000" in rendered
    assert "Challenger" in rendered
    assert "Reward: 0.8800" in rendered
    assert "Deltas" in rendered
    assert "Reward delta: -0.1200" in rendered
    assert "Error rate delta: 0.0200" in rendered
    assert "Latency delta (ms): 12.5000" in rendered
    assert "Decision" in rendered
    assert "Reason: Reward delta too small; Latency increased" in rendered
def test_format_output_workflow_text_includes_error_details():
    """Text rendering surfaces the ``error`` field of a failed workflow result."""
    payload = {
        "report_id": "report-err",
        "skipped": False,
        "promoted": False,
        "error": "benchmark crashed",
    }

    rendered = format_output(payload, output_format="text", mode="workflow")

    assert "Error: benchmark crashed" in rendered
def test_format_output_status_text_includes_latest_report_details():
    """Status text rendering includes the latest report id, timestamp, and promotion."""
    payload = {
        "base_dir": "/tmp/demo-artifacts",
        "current_version_id": "v2",
        "version_count": 2,
        "trajectory_count": 8,
        "report_count": 3,
        "latest_report": {
            "report_id": "report-9",
            "timestamp": "2026-04-15T06:00:00+00:00",
            "promoted": True,
        },
    }

    rendered = format_output(payload, output_format="text", mode="status")

    assert "Memabra status" in rendered
    assert "Current version: v2" in rendered
    assert "Latest report: report-9" in rendered
    assert "Latest report time: 2026-04-15T06:00:00+00:00" in rendered
    assert "Latest promotion accepted: yes" in rendered
def test_format_output_list_versions_text_marks_current_version():
    """Version listing text numbers entries and tags the current version."""
    payload = {
        "current_version_id": "v2",
        "versions": [
            {"version_id": "v1", "metadata": {"source": "seed", "avg_reward": 1.2}},
            {"version_id": "v2", "metadata": {"source": "online_learning", "avg_reward": 1.4}},
        ],
    }

    rendered = format_output(payload, output_format="text", mode="list_versions")

    assert "Saved router versions (2 total)" in rendered
    assert "Current version: v2" in rendered
    assert "1. v1 (source=seed, avg_reward=1.2)" in rendered
    assert "2. v2 (current, source=online_learning, avg_reward=1.4)" in rendered
def test_main_entrypoint_uses_online_learning_workflow(monkeypatch):
    """``cli.main()`` with no args invokes the workflow once with default settings."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, **kwargs):
        calls.append({
            "base_dir": str(base_dir),
            "min_new_trajectories": min_new_trajectories,
            "seen_trajectory_store": seen_trajectory_store,
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main()

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["min_new_trajectories"] == 3
def test_main_entrypoint_parses_base_dir_argument(monkeypatch):
    """``--base-dir`` is parsed and forwarded to the workflow as ``base_dir``."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, **kwargs):
        calls.append({
            "base_dir": str(base_dir) if base_dir else None,
            "min_new_trajectories": min_new_trajectories,
            "seen_trajectory_store": seen_trajectory_store,
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main(["--base-dir", "/custom/path"])

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["base_dir"] == "/custom/path"
def test_main_entrypoint_parses_min_new_trajectories_argument(monkeypatch):
    """``--min-new-trajectories`` is parsed as an int and forwarded to the workflow."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, **kwargs):
        calls.append({
            "base_dir": str(base_dir) if base_dir else None,
            "min_new_trajectories": min_new_trajectories,
            "seen_trajectory_store": seen_trajectory_store,
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main(["--min-new-trajectories", "10"])

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["min_new_trajectories"] == 10
def test_run_online_learning_workflow_skips_on_second_run_when_seen_store_provided(tmp_path: Path):
    """With a seen-trajectory store, a repeat run finds no new data and skips."""
    base_dir = tmp_path / "demo-artifacts"
    seen_store = tmp_path / "seen.json"

    result1 = run_online_learning_workflow(
        base_dir=base_dir,
        min_new_trajectories=1,
        seen_trajectory_store=seen_store,
    )
    assert result1["skipped"] is False

    # Second run against the same store: all trajectories already seen.
    result2 = run_online_learning_workflow(
        base_dir=base_dir,
        min_new_trajectories=1,
        seen_trajectory_store=seen_store,
    )
    assert result2["skipped"] is True
    assert "too few new trajectories" in result2["reason"].lower()
def test_main_entrypoint_passes_default_seen_trajectory_store(monkeypatch):
    """Without flags, ``main`` supplies a default ``seen-trajectories.json`` store
    and ``dry_run=False`` to the workflow."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, dry_run=False, **kwargs):
        calls.append({
            "base_dir": str(base_dir) if base_dir else None,
            "min_new_trajectories": min_new_trajectories,
            "seen_trajectory_store": str(seen_trajectory_store) if seen_trajectory_store else None,
            "dry_run": dry_run,
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main()

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["seen_trajectory_store"] is not None
    assert "seen-trajectories.json" in calls[0]["seen_trajectory_store"]
    assert calls[0]["dry_run"] is False
def test_main_entrypoint_passes_dry_run_flag(monkeypatch):
    """``--dry-run`` is forwarded to the workflow as ``dry_run=True``."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, dry_run=False, **kwargs):
        calls.append({
            "base_dir": str(base_dir) if base_dir else None,
            "min_new_trajectories": min_new_trajectories,
            "seen_trajectory_store": str(seen_trajectory_store) if seen_trajectory_store else None,
            "dry_run": dry_run,
            "baseline_version": kwargs.get("baseline_version"),
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main(["--dry-run"])

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["dry_run"] is True
def test_main_entrypoint_passes_baseline_version_flag(monkeypatch):
    """``--baseline-version`` is forwarded to the workflow as ``baseline_version``."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, dry_run=False, baseline_version=None, **kwargs):
        calls.append({
            "base_dir": str(base_dir) if base_dir else None,
            "min_new_trajectories": min_new_trajectories,
            "seen_trajectory_store": str(seen_trajectory_store) if seen_trajectory_store else None,
            "dry_run": dry_run,
            "baseline_version": baseline_version,
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main(["--baseline-version", "v1"])

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["baseline_version"] == "v1"
def test_main_entrypoint_supports_text_format_for_workflow(monkeypatch, capsys):
    """``--format text`` renders the workflow result as the human-readable report."""
    from memabra import cli

    def mock_online_learning_workflow(**kwargs):
        return {
            "skipped": False,
            "promoted": False,
            "report_id": "report-text",
            "dry_run": True,
            "decision": {
                "accepted": False,
                "reasons": ["Dry run requested"],
                "metrics": {
                    "reward_delta": 0.05,
                    "error_rate_delta": 0.0,
                    "latency_delta_ms": 4.0,
                },
            },
            "baseline_metrics": {
                "avg_reward": 0.8,
                "error_rate": 0.1,
                "avg_latency_ms": 90.0,
            },
            "challenger_metrics": {
                "avg_reward": 0.85,
                "error_rate": 0.1,
                "avg_latency_ms": 94.0,
            },
        }

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main(["--format", "text", "--dry-run"])
    captured = capsys.readouterr()

    assert rc == 0
    assert "Memabra online learning result" in captured.out
    assert "Summary" in captured.out
    assert "Dry run: yes" in captured.out
    assert "Baseline" in captured.out
    assert "Reward: 0.8000" in captured.out
    assert "Challenger" in captured.out
    assert "Reward: 0.8500" in captured.out
    assert "Deltas" in captured.out
    assert "Reward delta: 0.0500" in captured.out
    assert "Reason: Dry run requested" in captured.out
def test_main_entrypoint_passes_case_index_flags(monkeypatch):
    """``--case-index`` and ``--rebuild-case-index`` are forwarded to the workflow."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, dry_run=False, baseline_version=None, case_index_path=None, rebuild_case_index=False, **kwargs):
        calls.append({
            "base_dir": str(base_dir) if base_dir else None,
            "case_index_path": str(case_index_path) if case_index_path else None,
            "rebuild_case_index": rebuild_case_index,
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main(["--case-index", "/tmp/cases.json", "--rebuild-case-index"])

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["case_index_path"] == "/tmp/cases.json"
    assert calls[0]["rebuild_case_index"] is True
def test_run_online_learning_workflow_loads_existing_case_index(tmp_path: Path):
    """A second run can load the case index created by a prior rebuild run."""
    base_dir = tmp_path / "demo-artifacts"
    case_index_path = tmp_path / "case-index.json"

    # Run once to create trajectories and rebuild the case index.
    result1 = run_online_learning_workflow(
        base_dir=base_dir,
        min_new_trajectories=1,
        rebuild_case_index=True,
        case_index_path=case_index_path,
    )
    assert result1["skipped"] is False
    assert case_index_path.exists()

    # Second run should load the existing case index rather than rebuild it.
    result2 = run_online_learning_workflow(
        base_dir=base_dir,
        min_new_trajectories=1,
        rebuild_case_index=False,
        case_index_path=case_index_path,
    )
    assert result2["skipped"] is False
def test_run_online_learning_workflow_rebuilds_case_index_after_cycle(tmp_path: Path):
    """After a learning cycle the case index is written and contains new cases."""
    base_dir = tmp_path / "demo-artifacts"
    case_index_path = tmp_path / "case-index.json"

    result = run_online_learning_workflow(
        base_dir=base_dir,
        min_new_trajectories=1,
        case_index_path=case_index_path,
    )

    assert result["skipped"] is False
    assert case_index_path.exists()

    from memabra.case_index import CaseIndex

    index = CaseIndex.load(case_index_path)
    # The benchmark task during the cycle should produce a trajectory that gets indexed.
    assert index.best("Use my telegram preference for this answer.") is not None
def test_main_entrypoint_defaults_case_index_path_when_rebuild_flag_set(monkeypatch):
    """``--rebuild-case-index`` alone makes ``main`` supply a default index path."""
    from memabra import cli

    calls = []

    def mock_online_learning_workflow(*, base_dir=None, min_new_trajectories=3, seen_trajectory_store=None, dry_run=False, baseline_version=None, case_index_path=None, rebuild_case_index=False, **kwargs):
        calls.append({
            "base_dir": str(base_dir) if base_dir else None,
            "case_index_path": str(case_index_path) if case_index_path else None,
            "rebuild_case_index": rebuild_case_index,
        })
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    rc = cli.main(["--rebuild-case-index"])

    assert rc == 0
    assert len(calls) == 1
    assert calls[0]["rebuild_case_index"] is True
    assert calls[0]["case_index_path"] is not None
    assert "case-index.json" in calls[0]["case_index_path"]
def test_main_status_flag_prints_status_and_skips_workflow(tmp_path: Path, monkeypatch, capsys):
    """The ``status`` subcommand prints status JSON and never runs the workflow."""
    from memabra import cli

    workflow_calls = []

    def mock_online_learning_workflow(**kwargs):
        workflow_calls.append(kwargs)
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    base_dir = tmp_path / "demo-artifacts"
    base_dir.mkdir(parents=True, exist_ok=True)

    rc = cli.main(["status", "--base-dir", str(base_dir)])
    captured = capsys.readouterr()

    assert rc == 0
    assert len(workflow_calls) == 0
    assert "current_version_id" in captured.out
def test_main_status_flag_supports_text_format(tmp_path: Path, monkeypatch, capsys):
    """``status --format text`` prints a human-readable report, workflow untouched."""
    from memabra import cli

    workflow_calls = []

    def mock_online_learning_workflow(**kwargs):
        workflow_calls.append(kwargs)
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)

    base_dir = tmp_path / "demo-artifacts"
    base_dir.mkdir(parents=True, exist_ok=True)

    rc = cli.main(["status", "--format", "text", "--base-dir", str(base_dir)])
    captured = capsys.readouterr()

    assert rc == 0
    assert len(workflow_calls) == 0
    assert "Memabra status" in captured.out
    assert "Current version:" in captured.out
    assert "Trajectory count:" in captured.out
def test_main_rollback_flag_rolls_back_and_skips_workflow(tmp_path: Path, monkeypatch, capsys):
    """``version rollback <id>`` calls the store's rollback and skips the workflow."""
    from memabra import cli
    from memabra.router_versioning import RouterVersionStore

    workflow_calls = []
    rollback_calls = []

    def mock_online_learning_workflow(**kwargs):
        workflow_calls.append(kwargs)
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    def mock_rollback(self, version_id: str):
        rollback_calls.append(version_id)
        return {"current_version_id": version_id}

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)
    monkeypatch.setattr(RouterVersionStore, "rollback", mock_rollback)

    base_dir = tmp_path / "demo-artifacts"
    base_dir.mkdir(parents=True, exist_ok=True)

    rc = cli.main(["version", "rollback", "v1", "--base-dir", str(base_dir)])
    captured = capsys.readouterr()

    assert rc == 0
    assert len(workflow_calls) == 0
    assert len(rollback_calls) == 1
    assert rollback_calls[0] == "v1"
    assert "current_version_id" in captured.out
def test_main_rollback_flag_supports_text_format(tmp_path: Path, monkeypatch, capsys):
    """``version rollback --format text`` prints a human-readable confirmation."""
    from memabra import cli
    from memabra.router_versioning import RouterVersionStore

    def mock_rollback(self, version_id: str):
        return {"current_version_id": version_id}

    monkeypatch.setattr(RouterVersionStore, "rollback", mock_rollback)

    base_dir = tmp_path / "demo-artifacts"
    base_dir.mkdir(parents=True, exist_ok=True)

    rc = cli.main(["version", "rollback", "v1", "--format", "text", "--base-dir", str(base_dir)])
    captured = capsys.readouterr()

    assert rc == 0
    assert "Rolled back current version to: v1" in captured.out
def test_main_rollback_missing_version_prints_error_and_exits_nonzero(tmp_path: Path, monkeypatch, capsys):
    """A rollback to an unknown version reports the error on stderr and returns 1."""
    from memabra import cli
    from memabra.router_versioning import RouterVersionStore

    def mock_rollback(self, version_id: str):
        raise ValueError(f"Version '{version_id}' not found.")

    monkeypatch.setattr(RouterVersionStore, "rollback", mock_rollback)

    base_dir = tmp_path / "demo-artifacts"
    base_dir.mkdir(parents=True, exist_ok=True)

    rc = cli.main(["version", "rollback", "v99", "--base-dir", str(base_dir)])
    captured = capsys.readouterr()

    assert rc == 1
    assert "not found" in captured.err.lower()
def test_main_list_versions_flag_prints_versions_and_skips_workflow(tmp_path: Path, monkeypatch, capsys):
    """``version list`` prints every saved version id and skips the workflow."""
    from memabra import cli
    from memabra.router_versioning import RouterVersionStore

    workflow_calls = []

    def mock_online_learning_workflow(**kwargs):
        workflow_calls.append(kwargs)
        return {"skipped": False, "promoted": True, "report_id": "report-test"}

    def mock_list_versions(self):
        return [
            {"version_id": "v1", "metadata": {"source": "test"}},
            {"version_id": "v2", "metadata": {"source": "test"}},
        ]

    monkeypatch.setattr(cli, "run_online_learning_workflow", mock_online_learning_workflow)
    monkeypatch.setattr(RouterVersionStore, "list_versions", mock_list_versions)

    base_dir = tmp_path / "demo-artifacts"
    base_dir.mkdir(parents=True, exist_ok=True)

    rc = cli.main(["version", "list", "--base-dir", str(base_dir)])
    captured = capsys.readouterr()

    assert rc == 0
    assert len(workflow_calls) == 0
    assert "v1" in captured.out
    assert "v2" in captured.out
def test_main_list_versions_flag_supports_text_format(tmp_path: Path, monkeypatch, capsys):
    """``version list --format text`` renders the numbered listing with the current tag."""
    from memabra import cli
    from memabra.router_versioning import RouterVersionStore

    def mock_list_versions(self):
        return [
            {"version_id": "v1", "metadata": {"source": "seed", "avg_reward": 1.2}},
            {"version_id": "v2", "metadata": {"source": "online_learning", "avg_reward": 1.4}},
        ]

    def mock_get_current(self):
        return {"current_version_id": "v2"}

    monkeypatch.setattr(RouterVersionStore, "list_versions", mock_list_versions)
    monkeypatch.setattr(RouterVersionStore, "get_current", mock_get_current)

    base_dir = tmp_path / "demo-artifacts"
    base_dir.mkdir(parents=True, exist_ok=True)

    rc = cli.main(["version", "list", "--format", "text", "--base-dir", str(base_dir)])
    captured = capsys.readouterr()

    assert rc == 0
    assert "Saved router versions (2 total)" in captured.out
    assert "Current version: v2" in captured.out
    assert "2. v2 (current, source=online_learning, avg_reward=1.4)" in captured.out