Files
memabra/tests/test_router_versioning.py
2026-04-15 11:06:05 +08:00

116 lines
3.9 KiB
Python

import json
from pathlib import Path
from memabra.router import SimpleLearningRouter
from memabra.router_versioning import RouterVersionStore
def test_save_and_load_router_version(tmp_path):
    """A router saved under an explicit version id loads back with equal state.

    Saves a router as ``"v1"`` in a store rooted at ``tmp_path`` and checks that
    ``load("v1")`` returns an object whose ``_weights`` and ``_feature_keys``
    compare equal to those of the router that was saved.
    """
    # Arrange: a store in the per-test temp dir and a router with known state.
    version_store = RouterVersionStore(base_dir=tmp_path)
    original = SimpleLearningRouter()
    original._weights = {"call_tool": {"input_length": 0.5, "tool_count": 1.2}}
    original._feature_keys = ["input_length", "tool_count"]

    # Act: persist, then read the same version back.
    version_store.save(original, version_id="v1", metadata={"avg_reward": 0.75})
    restored = version_store.load("v1")

    # Assert: the round trip preserved both pieces of router state.
    assert restored._weights == original._weights
    assert restored._feature_keys == original._feature_keys
def test_list_versions_returns_metadata(tmp_path):
    """``list_versions`` reports each saved version together with its metadata.

    Two versions are saved with different ``avg_reward`` values; the listing is
    expected to contain exactly two entries, ``"v1"`` first and ``"v2"`` second,
    each exposing ``version_id`` and the metadata passed to ``save``.
    """
    version_store = RouterVersionStore(base_dir=tmp_path)
    learner = SimpleLearningRouter()
    learner._weights = {"inject_memory": {"memory_count": 0.8}}
    learner._feature_keys = ["memory_count"]

    # (version id, avg_reward) pairs, in the order they are saved and expected back.
    expected = (("v1", 0.75), ("v2", 0.82))
    for version_id, reward in expected:
        version_store.save(learner, version_id=version_id, metadata={"avg_reward": reward})

    listed = version_store.list_versions()

    assert len(listed) == 2
    for position, (version_id, reward) in enumerate(expected):
        assert listed[position]["version_id"] == version_id
        assert listed[position]["metadata"]["avg_reward"] == reward
def test_rollback_changes_current_version(tmp_path):
    """Rolling back repoints the current version and records where it came from.

    After saving ``"v1"`` then ``"v2"``, the current version is ``"v2"``.
    ``rollback("v1")`` must make ``"v1"`` current, record ``"v2"`` under
    ``rollback_from``, and include a ``rolled_back_at`` entry.
    """
    version_store = RouterVersionStore(base_dir=tmp_path)
    learner = SimpleLearningRouter()
    learner._weights = {"a": {"x": 1.0}}
    learner._feature_keys = ["x"]

    for version_id in ("v1", "v2"):
        version_store.save(learner, version_id=version_id)

    # Sanity check: the most recent save is the active version before rollback.
    assert version_store.get_current()["current_version_id"] == "v2"

    version_store.rollback("v1")
    active = version_store.get_current()

    assert active["current_version_id"] == "v1"
    assert active.get("rollback_from") == "v2"
    assert "rolled_back_at" in active
def test_save_tracks_active_router_metadata(tmp_path):
    """Metadata given to ``save`` is visible on the current-version record.

    With a single saved version, ``get_current()`` must expose the
    ``promotion_source`` and ``benchmark_summary`` that were passed as
    metadata, and must not report a prior version.
    """
    version_store = RouterVersionStore(base_dir=tmp_path)
    learner = SimpleLearningRouter()
    learner._weights = {"a": {"x": 1.0}}
    learner._feature_keys = ["x"]

    promotion_metadata = {
        "promotion_source": "online_learning",
        "benchmark_summary": {"reward_delta": 0.1},
    }
    version_store.save(learner, version_id="v1", metadata=promotion_metadata)

    active = version_store.get_current()

    assert active["current_version_id"] == "v1"
    assert active["promotion_source"] == "online_learning"
    assert active["benchmark_summary"]["reward_delta"] == 0.1
    # First save ever: there is no earlier version to point back to.
    assert active.get("prior_version_id") is None
def test_save_records_prior_version_id(tmp_path):
    """A second save records the previously current version as its prior.

    After saving ``"v1"`` and then ``"v2"``, the current record must name
    ``"v2"`` as current and ``"v1"`` as ``prior_version_id``.
    """
    version_store = RouterVersionStore(base_dir=tmp_path)
    learner = SimpleLearningRouter()
    learner._weights = {"a": {"x": 1.0}}
    learner._feature_keys = ["x"]

    for version_id in ("v1", "v2"):
        version_store.save(learner, version_id=version_id)

    active = version_store.get_current()

    assert active["current_version_id"] == "v2"
    assert active["prior_version_id"] == "v1"
def test_load_without_version_uses_current(tmp_path):
    """Calling ``load()`` with no version id returns the saved router's weights.

    Only ``"v1"`` is saved, so a bare ``load()`` is expected to yield a router
    whose ``_weights`` equal those of the router that was saved.
    """
    version_store = RouterVersionStore(base_dir=tmp_path)
    original = SimpleLearningRouter()
    original._weights = {"call_tool": {"input_length": 0.5}}
    original._feature_keys = ["input_length"]
    version_store.save(original, version_id="v1")

    # No version id supplied on purpose.
    restored = version_store.load()

    assert restored._weights == original._weights
def test_app_save_and_load_learning_router(tmp_path):
    """The app-level save/load helpers round-trip a learning router.

    A demo app is given a router with known state and asked to save it as
    ``"v-test"`` into a dedicated version directory. A second, freshly built
    demo app then loads that version, and its runner's router must carry the
    same ``_weights`` and ``_feature_keys``.
    """
    # Function-scope import: memabra.app is only imported when this test runs.
    # Only build_demo_app is needed here.
    from memabra.app import build_demo_app

    artifacts_dir = tmp_path / "artifacts"
    version_dir = tmp_path / "router-versions"

    app = build_demo_app(base_dir=artifacts_dir)
    router = SimpleLearningRouter()
    router._weights = {"clarify": {"input_length": 0.1}}
    router._feature_keys = ["input_length"]
    # Swap the demo app's router for the one with known state.
    app.runner.router = router

    app.save_learning_router(version_id="v-test", base_dir=version_dir, metadata={"note": "test"})

    # A separately built app, so the loaded state must come from the saved version.
    loaded_app = build_demo_app(base_dir=artifacts_dir)
    loaded_app.load_learning_router(version_id="v-test", base_dir=version_dir)

    assert loaded_app.runner.router._weights == router._weights
    assert loaded_app.runner.router._feature_keys == router._feature_keys