# memabra/tests/test_online_learning.py
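"""Tests for the OnlineLearningCoordinator training/promotion cycle.

Covers skip thresholds, promotion-policy accept/reject paths, report and
version persistence, error handling, dry runs, case-index rebuilds, and
baseline version selection.
"""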
from __future__ import annotations

from memabra.app import build_demo_app
from memabra.benchmarks import BenchmarkTask
from memabra.case_index import CaseIndex
from memabra.dataset import DatasetBuilder
from memabra.online_learning import OnlineLearningCoordinator
from memabra.promotion import PromotionPolicy
from memabra.router import SimpleLearningRouter
from memabra.router_versioning import RouterVersionStore
from memabra.training_reports import TrainingReportStore

def _seed_trajectories(app, count: int):
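    """Run ``count`` tasks through the app so it records fresh trajectories."""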
    for i in range(count):
        app.run_task(f"Test task {i}", channel="local")

def test_coordinator_skips_when_too_few_new_trajectories(tmp_path):
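    """The cycle is skipped outright when fewer than min_new_trajectories are available."""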
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 2)
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=PromotionPolicy(
            min_reward_delta=0.01,
            max_error_rate_increase=0.05,
            max_latency_increase_ms=100.0,
            required_task_count=1,
        ),
        benchmark_tasks=[BenchmarkTask(user_input="test")],
        min_new_trajectories=5,
    )
    result = coordinator.run_cycle()
    assert result["skipped"] is True
    assert "too few new trajectories" in result["reason"].lower()

def test_coordinator_rejects_when_policy_fails(tmp_path):
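    """The cycle runs to completion but an impossibly strict policy rejects the challenger."""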
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    # Seed enough trajectories for training and benchmarking
    _seed_trajectories(app, 10)
    # Use a very strict policy that will reject any challenger
    policy = PromotionPolicy(
        min_reward_delta=1.0,  # impossible to meet
        max_error_rate_increase=0.0,
        max_latency_increase_ms=0.0,
        required_task_count=1,
    )
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=tmp_path / "versions",
    )
    result = coordinator.run_cycle()
    assert result["skipped"] is False
    assert result["promoted"] is False
    assert "decision" in result
    assert result["decision"].accepted is False

def test_coordinator_accepts_and_saves_version_when_policy_passes(tmp_path):
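    """A lenient policy promotes the challenger and persists both a router version and a training report."""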
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)
    # Lenient policy that should pass
    policy = PromotionPolicy(
        min_reward_delta=-1.0,  # always passes
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )
    version_dir = tmp_path / "versions"
    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
    )
    result = coordinator.run_cycle()
    assert result["skipped"] is False
    assert result["promoted"] is True
    assert "version_id" in result
    assert result["decision"].accepted is True
    # Verify version was saved
    store = RouterVersionStore(base_dir=version_dir)
    versions = store.list_versions()
    assert len(versions) == 1
    assert versions[0]["version_id"] == result["version_id"]
    # Verify report was saved
    report_store = TrainingReportStore(base_dir=report_dir)
    reports = report_store.list_reports()
    assert len(reports) == 1
    assert reports[0]["promoted_version_id"] == result["version_id"]

def test_coordinator_saves_report_on_rejection(tmp_path):
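    """A rejection still produces a persisted training report recording the failed decision."""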
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)
    policy = PromotionPolicy(
        min_reward_delta=1.0,
        max_error_rate_increase=0.0,
        max_latency_increase_ms=0.0,
        required_task_count=1,
    )
    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        report_store_base_dir=report_dir,
    )
    result = coordinator.run_cycle()
    assert result["promoted"] is False
    report_store = TrainingReportStore(base_dir=report_dir)
    reports = report_store.list_reports()
    assert len(reports) == 1
    assert reports[0]["promotion_decision"]["accepted"] is False

def test_coordinator_catches_training_exception_and_returns_error_report(tmp_path):
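    """A training-time exception is caught, surfaced in the result, and recorded as a rejected report."""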
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)
    policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )
    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        report_store_base_dir=report_dir,
    )

    # Force a training failure by replacing DatasetBuilder.build with a stub
    # that raises, restoring the original in all cases.
    def _fail(self, trajectories):
        raise RuntimeError("simulated training failure")

    original_build = DatasetBuilder.build
    DatasetBuilder.build = _fail
    try:
        result = coordinator.run_cycle()
    finally:
        DatasetBuilder.build = original_build
    assert result["skipped"] is False
    assert result["promoted"] is False
    assert "error" in result
    assert "simulated training failure" in result["error"]
    # Verify error report was saved
    report_store = TrainingReportStore(base_dir=report_dir)
    reports = report_store.list_reports()
    assert len(reports) == 1
    assert reports[0]["promotion_decision"]["accepted"] is False
    assert "simulated training failure" in reports[0]["promotion_decision"]["reasons"][0]

def test_coordinator_persists_seen_trajectory_ids_across_restarts(tmp_path):
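    """Seen-trajectory IDs persist on disk, so a fresh coordinator does not retrain on old data."""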
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 5)
    policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )
    benchmark_tasks = [BenchmarkTask(user_input="Test task 0")]
    seen_store = tmp_path / "seen_trajectories.json"
    version_dir = tmp_path / "versions"
    report_dir = tmp_path / "reports"
    coordinator1 = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=benchmark_tasks,
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
        seen_trajectory_store=seen_store,
    )
    result1 = coordinator1.run_cycle()
    assert result1["skipped"] is False
    # New coordinator instance pointing to same store
    coordinator2 = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=benchmark_tasks,
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
        seen_trajectory_store=seen_store,
    )
    result2 = coordinator2.run_cycle()
    assert result2["skipped"] is True
    assert "too few new trajectories" in result2["reason"].lower()

def test_coordinator_dry_run_does_not_promote_or_save_version(tmp_path):
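    """dry_run=True evaluates the challenger but blocks promotion and version saving; the report is still written."""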
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)
    policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )
    version_dir = tmp_path / "versions"
    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
    )
    result = coordinator.run_cycle(dry_run=True)
    assert result["skipped"] is False
    assert result["promoted"] is False
    assert "decision" in result
    assert result["decision"].accepted is True  # policy would accept, but dry_run blocks promotion
    # No version should be saved
    store = RouterVersionStore(base_dir=version_dir)
    assert len(store.list_versions()) == 0
    # Report should still be saved for audit
    report_store = TrainingReportStore(base_dir=report_dir)
    reports = report_store.list_reports()
    assert len(reports) == 1
    assert reports[0].get("dry_run") is True

def test_coordinator_rebuilds_case_index_when_path_provided(tmp_path):
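    """When case_index_path is given, the cycle rebuilds the case index and it resolves a seeded task."""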
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)
    policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )
    case_index_path = tmp_path / "case-index.json"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        case_index_path=case_index_path,
    )
    result = coordinator.run_cycle()
    assert result["skipped"] is False
    assert case_index_path.exists()
    index = CaseIndex.load(case_index_path)
    assert index.best("Test task 0") is not None

def test_coordinator_uses_specified_baseline_version(tmp_path):
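    """baseline_version_id selects a stored router as the comparison baseline instead of the app's current router."""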
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)
    # Save a baseline version with known weights
    baseline_router = SimpleLearningRouter()
    baseline_router._weights = {"call_tool": {"input_length": 0.99}}
    baseline_router._feature_keys = ["input_length"]
    version_dir = tmp_path / "versions"
    store = RouterVersionStore(base_dir=version_dir)
    store.save(baseline_router, version_id="v-baseline", metadata={"note": "baseline"})
    # Change the app's current router to something different
    different_router = SimpleLearningRouter()
    different_router._weights = {"clarify": {"input_length": 0.01}}
    different_router._feature_keys = ["input_length"]
    app.set_router(different_router)
    policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )
    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
    )
    result = coordinator.run_cycle(baseline_version_id="v-baseline")
    assert result["skipped"] is False
    assert "baseline_metrics" in result
    assert "challenger_metrics" in result
    # Verify the report records the baseline version
    report_store = TrainingReportStore(base_dir=report_dir)
    reports = report_store.list_reports()
    assert len(reports) == 1
    assert reports[0].get("baseline_version_id") == "v-baseline"