Initial standalone memabra release
This commit is contained in:
348
tests/test_online_learning.py
Normal file
348
tests/test_online_learning.py
Normal file
@@ -0,0 +1,348 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memabra.app import build_demo_app
|
||||
from memabra.benchmarks import BenchmarkTask
|
||||
from memabra.dataset import DatasetBuilder
|
||||
from memabra.evaluator import Evaluator
|
||||
from memabra.online_learning import OnlineLearningCoordinator
|
||||
from memabra.promotion import PromotionPolicy
|
||||
from memabra.router_versioning import RouterVersionStore
|
||||
|
||||
|
||||
def _seed_trajectories(app, count: int):
    """Run *count* local-channel demo tasks so the app records trajectories."""
    for index in range(count):
        app.run_task(f"Test task {index}", channel="local")
|
||||
|
||||
|
||||
def test_coordinator_skips_when_too_few_new_trajectories(tmp_path):
    """A cycle must be skipped when fewer trajectories exist than the minimum."""
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 2)

    # Only 2 trajectories were seeded, but 5 are required per cycle.
    strict_policy = PromotionPolicy(
        min_reward_delta=0.01,
        max_error_rate_increase=0.05,
        max_latency_increase_ms=100.0,
        required_task_count=1,
    )
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=strict_policy,
        benchmark_tasks=[BenchmarkTask(user_input="test")],
        min_new_trajectories=5,
    )

    outcome = coordinator.run_cycle()

    assert outcome["skipped"] is True
    assert "too few new trajectories" in outcome["reason"].lower()
|
||||
|
||||
|
||||
def test_coordinator_rejects_when_policy_fails(tmp_path):
    """An unsatisfiable promotion policy must reject the challenger router."""
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    # Seed enough trajectories for training and benchmarking
    _seed_trajectories(app, 10)

    coordinator = OnlineLearningCoordinator(
        app=app,
        # Thresholds no challenger can satisfy: reward must improve by a full
        # point while error rate and latency may not regress at all.
        policy=PromotionPolicy(
            min_reward_delta=1.0,
            max_error_rate_increase=0.0,
            max_latency_increase_ms=0.0,
            required_task_count=1,
        ),
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=tmp_path / "versions",
    )

    outcome = coordinator.run_cycle()

    # The cycle ran to completion but the challenger was turned down.
    assert outcome["skipped"] is False
    assert outcome["promoted"] is False
    assert "decision" in outcome
    assert outcome["decision"].accepted is False
|
||||
|
||||
|
||||
def test_coordinator_accepts_and_saves_version_when_policy_passes(tmp_path):
    """An always-pass policy must promote, save a version, and save a report."""
    from memabra.training_reports import TrainingReportStore

    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)

    # Thresholds are so lenient that any challenger is accepted.
    lenient_policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    version_dir = tmp_path / "versions"
    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=lenient_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
    )

    outcome = coordinator.run_cycle()

    assert outcome["skipped"] is False
    assert outcome["promoted"] is True
    assert "version_id" in outcome
    assert outcome["decision"].accepted is True

    # Exactly one router version must have landed in the version store,
    # and it must carry the id reported by the cycle.
    saved_versions = RouterVersionStore(base_dir=version_dir).list_versions()
    assert len(saved_versions) == 1
    assert saved_versions[0]["version_id"] == outcome["version_id"]

    # The training report must reference the same promoted version.
    saved_reports = TrainingReportStore(base_dir=report_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0]["promoted_version_id"] == outcome["version_id"]
|
||||
|
||||
|
||||
def test_coordinator_saves_report_on_rejection(tmp_path):
    """Even a rejected cycle must leave an audit report behind."""
    from memabra.training_reports import TrainingReportStore

    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)

    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        # Impossible thresholds guarantee the challenger is rejected.
        policy=PromotionPolicy(
            min_reward_delta=1.0,
            max_error_rate_increase=0.0,
            max_latency_increase_ms=0.0,
            required_task_count=1,
        ),
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        report_store_base_dir=report_dir,
    )

    outcome = coordinator.run_cycle()
    assert outcome["promoted"] is False

    # The rejection itself must be recorded in the report store.
    saved_reports = TrainingReportStore(base_dir=report_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0]["promotion_decision"]["accepted"] is False
|
||||
|
||||
|
||||
def test_coordinator_catches_training_exception_and_returns_error_report(tmp_path):
    """A crash inside training must be caught by the coordinator, surfaced in
    the cycle result's ``error`` field, and persisted as a rejected promotion
    in the report store.
    """
    from unittest.mock import patch

    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)

    # Lenient policy: the failure we assert on must come from training,
    # not from a policy rejection.
    policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        report_store_base_dir=report_dir,
    )

    # Force a training failure. patch.object restores the original method on
    # exit even if the body raises, replacing the manual save/try/finally
    # dance (and the obscure throwing-generator lambda) the old version used.
    with patch.object(
        DatasetBuilder,
        "build",
        side_effect=RuntimeError("simulated training failure"),
    ):
        result = coordinator.run_cycle()

    # The coordinator must report the failure rather than propagate it.
    assert result["skipped"] is False
    assert result["promoted"] is False
    assert "error" in result
    assert "simulated training failure" in result["error"]

    # Verify the error report was saved as a rejected promotion whose first
    # reason carries the underlying exception message.
    from memabra.training_reports import TrainingReportStore

    report_store = TrainingReportStore(base_dir=report_dir)
    reports = report_store.list_reports()
    assert len(reports) == 1
    assert reports[0]["promotion_decision"]["accepted"] is False
    assert "simulated training failure" in reports[0]["promotion_decision"]["reasons"][0]
|
||||
|
||||
|
||||
def test_coordinator_persists_seen_trajectory_ids_across_restarts(tmp_path):
    """Seen-trajectory bookkeeping must survive a coordinator restart."""
    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 5)

    # Both coordinator instances share identical construction arguments,
    # including the on-disk seen-trajectory store.
    shared_kwargs = dict(
        app=app,
        policy=PromotionPolicy(
            min_reward_delta=-1.0,
            max_error_rate_increase=1.0,
            max_latency_increase_ms=10000.0,
            required_task_count=1,
        ),
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=tmp_path / "versions",
        report_store_base_dir=tmp_path / "reports",
        seen_trajectory_store=tmp_path / "seen_trajectories.json",
    )

    # First instance consumes all seeded trajectories.
    first = OnlineLearningCoordinator(**shared_kwargs)
    first_result = first.run_cycle()
    assert first_result["skipped"] is False

    # A fresh instance reading the same store must see nothing new.
    second = OnlineLearningCoordinator(**shared_kwargs)
    second_result = second.run_cycle()
    assert second_result["skipped"] is True
    assert "too few new trajectories" in second_result["reason"].lower()
|
||||
|
||||
|
||||
def test_coordinator_dry_run_does_not_promote_or_save_version(tmp_path):
    """dry_run must block promotion and version persistence but still report."""
    from memabra.training_reports import TrainingReportStore

    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)

    # Always-accept thresholds, so only dry_run can stop the promotion.
    lenient_policy = PromotionPolicy(
        min_reward_delta=-1.0,
        max_error_rate_increase=1.0,
        max_latency_increase_ms=10000.0,
        required_task_count=1,
    )

    version_dir = tmp_path / "versions"
    report_dir = tmp_path / "reports"
    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=lenient_policy,
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
    )

    outcome = coordinator.run_cycle(dry_run=True)

    assert outcome["skipped"] is False
    # The policy accepted the challenger, yet dry_run suppressed promotion.
    assert outcome["promoted"] is False
    assert "decision" in outcome
    assert outcome["decision"].accepted is True

    # The version store must remain empty.
    assert len(RouterVersionStore(base_dir=version_dir).list_versions()) == 0

    # The audit report is still written and flagged as a dry run.
    saved_reports = TrainingReportStore(base_dir=report_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0].get("dry_run") is True
|
||||
|
||||
|
||||
def test_coordinator_rebuilds_case_index_when_path_provided(tmp_path):
    """Supplying case_index_path must cause the index file to be (re)built."""
    from memabra.case_index import CaseIndex

    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)

    index_path = tmp_path / "case-index.json"
    coordinator = OnlineLearningCoordinator(
        app=app,
        # Lenient thresholds so the cycle completes without rejection.
        policy=PromotionPolicy(
            min_reward_delta=-1.0,
            max_error_rate_increase=1.0,
            max_latency_increase_ms=10000.0,
            required_task_count=1,
        ),
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        case_index_path=index_path,
    )

    outcome = coordinator.run_cycle()

    assert outcome["skipped"] is False
    # The index file exists on disk and resolves a seeded task.
    assert index_path.exists()
    rebuilt_index = CaseIndex.load(index_path)
    assert rebuilt_index.best("Test task 0") is not None
|
||||
|
||||
|
||||
def test_coordinator_uses_specified_baseline_version(tmp_path):
    """run_cycle must benchmark against an explicitly named baseline version."""
    from memabra.router import SimpleLearningRouter
    from memabra.training_reports import TrainingReportStore

    app = build_demo_app(base_dir=tmp_path / "demo-artifacts")
    _seed_trajectories(app, 10)

    version_dir = tmp_path / "versions"
    report_dir = tmp_path / "reports"

    # Persist a baseline router with distinctive weights under a known id.
    baseline = SimpleLearningRouter()
    baseline._weights = {"call_tool": {"input_length": 0.99}}
    baseline._feature_keys = ["input_length"]
    version_store = RouterVersionStore(base_dir=version_dir)
    version_store.save(baseline, version_id="v-baseline", metadata={"note": "baseline"})

    # Install a clearly different router as the app's current one, so the
    # test can tell whether the named baseline (not the live router) is used.
    current = SimpleLearningRouter()
    current._weights = {"clarify": {"input_length": 0.01}}
    current._feature_keys = ["input_length"]
    app.set_router(current)

    coordinator = OnlineLearningCoordinator(
        app=app,
        policy=PromotionPolicy(
            min_reward_delta=-1.0,
            max_error_rate_increase=1.0,
            max_latency_increase_ms=10000.0,
            required_task_count=1,
        ),
        benchmark_tasks=[BenchmarkTask(user_input="Test task 0")],
        min_new_trajectories=1,
        version_store_base_dir=version_dir,
        report_store_base_dir=report_dir,
    )

    outcome = coordinator.run_cycle(baseline_version_id="v-baseline")

    assert outcome["skipped"] is False
    assert "baseline_metrics" in outcome
    assert "challenger_metrics" in outcome

    # The saved report must record which baseline version was used.
    saved_reports = TrainingReportStore(base_dir=report_dir).list_reports()
    assert len(saved_reports) == 1
    assert saved_reports[0].get("baseline_version_id") == "v-baseline"
|
||||
Reference in New Issue
Block a user