116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
import json
|
|
from pathlib import Path
|
|
|
|
from memabra.router import SimpleLearningRouter
|
|
from memabra.router_versioning import RouterVersionStore
|
|
|
|
|
|
def test_save_and_load_router_version(tmp_path):
    """A saved router version round-trips its weights and feature keys through load()."""
    version_store = RouterVersionStore(base_dir=tmp_path)

    original = SimpleLearningRouter()
    original._weights = {"call_tool": {"input_length": 0.5, "tool_count": 1.2}}
    original._feature_keys = ["input_length", "tool_count"]

    version_store.save(original, version_id="v1", metadata={"avg_reward": 0.75})
    restored = version_store.load("v1")

    # Both the learned weights and the feature-key list must survive persistence.
    assert restored._weights == original._weights
    assert restored._feature_keys == original._feature_keys
|
|
|
|
|
|
def test_list_versions_returns_metadata(tmp_path):
    """list_versions() reports every saved version, ordered v1 then v2, with its metadata."""
    version_store = RouterVersionStore(base_dir=tmp_path)

    learner = SimpleLearningRouter()
    learner._weights = {"inject_memory": {"memory_count": 0.8}}
    learner._feature_keys = ["memory_count"]

    # (version_id, avg_reward) pairs, saved in this order.
    expected = [("v1", 0.75), ("v2", 0.82)]
    for version_id, reward in expected:
        version_store.save(learner, version_id=version_id, metadata={"avg_reward": reward})

    listed = version_store.list_versions()

    assert len(listed) == 2
    for entry, (version_id, reward) in zip(listed, expected):
        assert entry["version_id"] == version_id
        assert entry["metadata"]["avg_reward"] == reward
|
|
|
|
|
|
def test_rollback_changes_current_version(tmp_path):
    """rollback() repoints the current version and records the rollback source and time."""
    version_store = RouterVersionStore(base_dir=tmp_path)

    learner = SimpleLearningRouter()
    learner._weights = {"a": {"x": 1.0}}
    learner._feature_keys = ["x"]

    for version_id in ("v1", "v2"):
        version_store.save(learner, version_id=version_id)
    # Before rolling back, the latest save is the current version.
    assert version_store.get_current()["current_version_id"] == "v2"

    version_store.rollback("v1")

    pointer = version_store.get_current()
    assert pointer["current_version_id"] == "v1"
    assert pointer.get("rollback_from") == "v2"
    assert "rolled_back_at" in pointer
|
|
|
|
|
|
def test_save_tracks_active_router_metadata(tmp_path):
    """Metadata passed to save() is surfaced on the current-version record."""
    version_store = RouterVersionStore(base_dir=tmp_path)

    learner = SimpleLearningRouter()
    learner._weights = {"a": {"x": 1.0}}
    learner._feature_keys = ["x"]

    promotion_metadata = {
        "promotion_source": "online_learning",
        "benchmark_summary": {"reward_delta": 0.1},
    }
    version_store.save(learner, version_id="v1", metadata=promotion_metadata)

    pointer = version_store.get_current()
    assert pointer["current_version_id"] == "v1"
    assert pointer["promotion_source"] == "online_learning"
    assert pointer["benchmark_summary"]["reward_delta"] == 0.1
    # This is the first save, so there is no earlier version to point back to.
    assert pointer.get("prior_version_id") is None
|
|
|
|
|
|
def test_save_records_prior_version_id(tmp_path):
    """A second save records the previously current version as prior_version_id."""
    version_store = RouterVersionStore(base_dir=tmp_path)

    learner = SimpleLearningRouter()
    learner._weights = {"a": {"x": 1.0}}
    learner._feature_keys = ["x"]

    version_store.save(learner, version_id="v1")
    version_store.save(learner, version_id="v2")

    pointer = version_store.get_current()
    assert pointer["current_version_id"] == "v2"
    assert pointer["prior_version_id"] == "v1"
|
|
|
|
|
|
def test_load_without_version_uses_current(tmp_path):
    """load() with no version_id restores the currently active version.

    Checks both weights and feature keys, matching the explicit-version
    round-trip test, so a partially restored router cannot pass.
    """
    store = RouterVersionStore(base_dir=tmp_path)

    router = SimpleLearningRouter()
    router._weights = {"call_tool": {"input_length": 0.5}}
    router._feature_keys = ["input_length"]

    store.save(router, version_id="v1")
    # No version_id: "v1" is the only (and therefore current) version.
    loaded = store.load()

    assert loaded._weights == router._weights
    assert loaded._feature_keys == router._feature_keys
|
|
|
|
|
|
def test_app_save_and_load_learning_router(tmp_path):
    """A learning router saved through the app can be restored into a fresh app."""
    # Local import kept from the original; the unused MemabraApp name was dropped.
    from memabra.app import build_demo_app

    app = build_demo_app(base_dir=tmp_path / "artifacts")

    router = SimpleLearningRouter()
    router._weights = {"clarify": {"input_length": 0.1}}
    router._feature_keys = ["input_length"]
    app.runner.router = router

    # Versions are stored outside the artifacts dir, so base_dir is passed on both save and load.
    version_dir = tmp_path / "router-versions"
    app.save_learning_router(version_id="v-test", base_dir=version_dir, metadata={"note": "test"})

    loaded_app = build_demo_app(base_dir=tmp_path / "artifacts")
    loaded_app.load_learning_router(version_id="v-test", base_dir=version_dir)

    restored = loaded_app.runner.router
    # Guard against the fresh app sharing the original router object; its state
    # must come from the saved version, not from in-memory aliasing.
    assert restored is not router
    assert restored._weights == router._weights
    assert restored._feature_keys == router._feature_keys
|