51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
from memabra.case_index import CaseIndex
|
|
|
|
|
|
def test_case_index_adds_and_retrieves_best_trajectory():
    """A single added trajectory is reported as best for its own task input."""
    cases = CaseIndex()
    cases.add(
        {
            "trajectory_id": "traj-1",
            "task": {"input": "Hello world"},
            "outcome": {"status": "success"},
            "reward": {"total": 1.0},
        }
    )

    best_id = cases.best("Hello world")

    assert best_id == "traj-1"
|
|
|
|
|
|
def test_case_index_returns_none_for_unknown_input():
    """Looking up an input on a freshly created index yields ``None``."""
    empty_index = CaseIndex()

    result = empty_index.best("Unknown input")

    assert result is None
|
|
|
|
|
|
def test_case_index_keeps_higher_reward_for_same_input():
    """With two trajectories for one input, the higher-reward one is best."""
    cases = CaseIndex()
    shared_input = "Same input"

    # The lower-reward trajectory is added first, then the higher-reward one.
    for trajectory_id, total in (("traj-low", 0.5), ("traj-high", 1.5)):
        cases.add(
            {
                "trajectory_id": trajectory_id,
                "task": {"input": shared_input},
                "outcome": {"status": "success"},
                "reward": {"total": total},
            }
        )

    assert cases.best(shared_input) == "traj-high"
|
|
|
|
|
|
def test_case_index_save_and_round_trip(tmp_path):
    """An index saved under ``tmp_path`` and loaded back keeps its best entry."""
    original = CaseIndex()
    persisted_trajectory = {
        "trajectory_id": "traj-save",
        "task": {"input": "Persist me"},
        "outcome": {"status": "success"},
        "reward": {"total": 2.0},
    }
    original.add(persisted_trajectory)

    index_file = tmp_path / "case_index.json"
    original.save(index_file)

    restored = CaseIndex.load(index_file)
    assert restored.best("Persist me") == "traj-save"
|