349 lines
11 KiB
Python
349 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
from memabra.app import build_demo_app
|
|
from memabra.benchmarks import BenchmarkTask
|
|
from memabra.dataset import DatasetBuilder
|
|
from memabra.evaluator import Evaluator
|
|
from memabra.online_learning import OnlineLearningCoordinator
|
|
from memabra.promotion import PromotionPolicy
|
|
from memabra.router_versioning import RouterVersionStore
|
|
|
|
|
|
def _seed_trajectories(app, count: int):
|
|
for i in range(count):
|
|
app.run_task(f"Test task {i}", channel="local")
|
|
|
|
|
|
def test_coordinator_skips_when_too_few_new_trajectories(tmp_path):
    """A cycle is a no-op when fewer new trajectories exist than the minimum."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    # Only 2 trajectories are seeded, but 5 are required below.
    _seed_trajectories(demo, 2)

    lenient_policy = PromotionPolicy(
        min_reward_delta=0.01,
        max_error_rate_increase=0.05,
        max_latency_increase_ms=100.0,
        required_task_count=1,
    )
    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=lenient_policy,
        benchmark_tasks=[BenchmarkTask(user_input="test")],
        min_new_trajectories=5,
    )

    outcome = coordinator.run_cycle()

    assert outcome["skipped"] is True
    assert "too few new trajectories" in outcome["reason"].lower()
|
|
|
|
|
|
def test_coordinator_rejects_when_policy_fails(tmp_path):
    """A challenger that cannot satisfy the promotion policy is not promoted."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    # Seed enough trajectories for training and benchmarking.
    _seed_trajectories(demo, 10)

    # A policy no challenger can ever satisfy.
    impossible_policy = PromotionPolicy(
        min_reward_delta=1.0,  # impossible to meet
        max_error_rate_increase=0.0,
        max_latency_increase_ms=0.0,
        required_task_count=1,
    )

    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=impossible_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=tmp_path / "versions",
    )

    outcome = coordinator.run_cycle()

    assert outcome["skipped"] is False
    assert outcome["promoted"] is False
    assert "decision" in outcome
    assert outcome["decision"].accepted is False
|
|
|
|
|
|
def test_coordinator_accepts_and_saves_version_when_policy_passes(tmp_path):
    """An accepted challenger is promoted and persisted with a report."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(demo, 10)

    # Thresholds loose enough that any challenger is accepted.
    permissive_policy = PromotionPolicy(
        min_reward_delta=-1.0,  # always passes
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    versions_dir = tmp_path / "versions"
    reports_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=permissive_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=versions_dir,
        report_store_base_dir=reports_dir,
    )

    outcome = coordinator.run_cycle()

    assert outcome["skipped"] is False
    assert outcome["promoted"] is True
    assert "version_id" in outcome
    assert outcome["decision"].accepted is True

    # The promoted router must be retrievable from the version store.
    saved_versions = RouterVersionStore(base_dir=versions_dir).list_versions()
    assert len(saved_versions) == 1
    assert saved_versions[0]["version_id"] == outcome["version_id"]

    # A matching training report must have been written.
    from memabra.training_reports import TrainingReportStore

    saved_reports = TrainingReportStore(base_dir=reports_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0]["promoted_version_id"] == outcome["version_id"]
|
|
|
|
|
|
def test_coordinator_saves_report_on_rejection(tmp_path):
    """Even a rejected cycle leaves an audit report behind."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(demo, 10)

    # Impossible thresholds: the challenger is guaranteed to be rejected.
    rejecting_policy = PromotionPolicy(
        min_reward_delta=1.0,
        max_error_rate_increase=0.0,
        max_latency_increase_ms=0.0,
        required_task_count=1,
    )

    reports_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=rejecting_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        report_store_base_dir=reports_dir,
    )

    outcome = coordinator.run_cycle()

    assert outcome["promoted"] is False

    from memabra.training_reports import TrainingReportStore

    saved_reports = TrainingReportStore(base_dir=reports_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0]["promotion_decision"]["accepted"] is False
|
|
|
|
|
|
def test_coordinator_catches_training_exception_and_returns_error_report(tmp_path):
    """A training failure is caught, surfaced as an error, and audited."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(demo, 10)

    permissive_policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    reports_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=permissive_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        report_store_base_dir=reports_dir,
    )

    # Force a training failure by monkeypatching DatasetBuilder.build to raise.
    def _failing_build(self, trajectories):
        raise RuntimeError("simulated training failure")

    original_build = DatasetBuilder.build
    DatasetBuilder.build = _failing_build
    try:
        outcome = coordinator.run_cycle()
    finally:
        # Always restore the real implementation for other tests.
        DatasetBuilder.build = original_build

    assert outcome["skipped"] is False
    assert outcome["promoted"] is False
    assert "error" in outcome
    assert "simulated training failure" in outcome["error"]

    # Verify an error report was still written for audit.
    from memabra.training_reports import TrainingReportStore

    saved_reports = TrainingReportStore(base_dir=reports_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0]["promotion_decision"]["accepted"] is False
    assert "simulated training failure" in saved_reports[0]["promotion_decision"]["reasons"][0]
|
|
|
|
|
|
def test_coordinator_persists_seen_trajectory_ids_across_restarts(tmp_path):
    """Seen-trajectory bookkeeping must survive a coordinator restart."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(demo, 5)

    permissive_policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )
    shared_tasks = [BenchmarkTask(user_input="Test task 0")]
    seen_path = tmp_path / "seen_trajectories.json"
    versions_dir = tmp_path / "versions"
    reports_dir = tmp_path / "reports"

    def make_coordinator():
        # Both instances share identical configuration and on-disk stores.
        return OnlineLearningCoordinator(
            app=demo,
            policy=permissive_policy,
            benchmark_tasks=shared_tasks,
            min_new_trajectories=1,
            version_store_base_dir=versions_dir,
            report_store_base_dir=reports_dir,
            seen_trajectory_store=seen_path,
        )

    first_outcome = make_coordinator().run_cycle()
    assert first_outcome["skipped"] is False

    # A fresh coordinator reading the same store sees no new trajectories.
    second_outcome = make_coordinator().run_cycle()
    assert second_outcome["skipped"] is True
    assert "too few new trajectories" in second_outcome["reason"].lower()
|
|
|
|
|
|
def test_coordinator_dry_run_does_not_promote_or_save_version(tmp_path):
    """dry_run evaluates the challenger but never promotes or saves a version."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(demo, 10)

    permissive_policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    versions_dir = tmp_path / "versions"
    reports_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=permissive_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=versions_dir,
        report_store_base_dir=reports_dir,
    )

    outcome = coordinator.run_cycle(dry_run=True)

    assert outcome["skipped"] is False
    assert outcome["promoted"] is False
    assert "decision" in outcome
    # The policy would accept, but dry_run blocks the actual promotion.
    assert outcome["decision"].accepted is True

    # No version may be persisted during a dry run.
    assert len(RouterVersionStore(base_dir=versions_dir).list_versions()) == 0

    # The audit report is still written, flagged as a dry run.
    from memabra.training_reports import TrainingReportStore

    saved_reports = TrainingReportStore(base_dir=reports_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0].get("dry_run") is True
|
|
|
|
|
|
def test_coordinator_rebuilds_case_index_when_path_provided(tmp_path):
    """Supplying a case-index path makes the cycle rebuild and save the index."""
    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(demo, 10)

    permissive_policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    index_path = tmp_path / "case-index.json"
    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=permissive_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        case_index_path=index_path,
    )

    outcome = coordinator.run_cycle()

    assert outcome["skipped"] is False
    assert index_path.exists()

    # The rebuilt index must load and contain an entry for a seeded task.
    from memabra.case_index import CaseIndex

    rebuilt_index = CaseIndex.load(index_path)
    assert rebuilt_index.best("Test task 0") is not None
|
|
|
|
|
|
def test_coordinator_uses_specified_baseline_version(tmp_path):
    """run_cycle(baseline_version_id=...) benchmarks against the stored baseline."""
    from memabra.router import SimpleLearningRouter

    demo = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(demo, 10)

    # Persist a baseline router with known weights under a fixed version id.
    baseline = SimpleLearningRouter()
    baseline._weights = {"call_tool": {"input_length": 0.99}}
    baseline._feature_keys = ["input_length"]
    versions_dir = tmp_path / "versions"
    RouterVersionStore(base_dir=versions_dir).save(
        baseline, version_id="v-baseline", metadata={"note": "baseline"}
    )

    # Swap the app's live router for something clearly different.
    replacement = SimpleLearningRouter()
    replacement._weights = {"clarify": {"input_length": 0.01}}
    replacement._feature_keys = ["input_length"]
    demo.set_router(replacement)

    permissive_policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    reports_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=demo,
        policy=permissive_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=versions_dir,
        report_store_base_dir=reports_dir,
    )

    outcome = coordinator.run_cycle(baseline_version_id="v-baseline")

    assert outcome["skipped"] is False
    assert "baseline_metrics" in outcome
    assert "challenger_metrics" in outcome

    # The report must record which baseline version was used.
    from memabra.training_reports import TrainingReportStore

    saved_reports = TrainingReportStore(base_dir=reports_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0].get("baseline_version_id") == "v-baseline"
|