92 lines
2.4 KiB
Python
92 lines
2.4 KiB
Python
from memabra.candidate_types import CandidateObject
|
|
from memabra.dataset import TrainingSample
|
|
from memabra.router import SimpleLearningRouter, TaskContext
|
|
|
|
|
|
def test_learning_router_fits_and_predicts():
    """Fit the router on two labelled samples, then check it reproduces both labels.

    One sample describes a tool-only situation labelled ``call_tool``; the other
    describes a memory-only situation labelled ``inject_memory``. After fitting,
    a tool-only candidate set must route to ``call_tool`` and a memory-only
    candidate set must route to ``inject_memory``.
    """
    router = SimpleLearningRouter()

    tool_sample = TrainingSample(
        input_text="run tool",
        features=dict(
            input_length=8,
            memory_count=0,
            skill_count=0,
            tool_count=1,
            top_memory_confidence=0.0,
            top_skill_success_rate=0.0,
            top_tool_confidence=0.9,
            top_tool_risk=0.1,
        ),
        label="call_tool",
        reward=1.0,
    )
    memory_sample = TrainingSample(
        input_text="remember",
        features=dict(
            input_length=8,
            memory_count=1,
            skill_count=0,
            tool_count=0,
            top_memory_confidence=0.9,
            top_skill_success_rate=0.0,
            top_tool_confidence=0.0,
            top_tool_risk=0.0,
        ),
        label="inject_memory",
        reward=1.0,
    )
    router.fit([tool_sample, memory_sample])

    def make_candidate(candidate_id, kind, title, risk):
        # Both candidates share the same quality profile; only identity,
        # kind, title and risk differ between them.
        return CandidateObject(
            id=candidate_id,
            type=kind,
            title=title,
            summary="s",
            triggers=[],
            confidence=0.9,
            success_rate=0.9,
            freshness=0.9,
            cost=0.0,
            risk=risk,
        )

    def route(text, *, memories, skills, tools):
        # Thin wrapper so each scenario reads as "which candidates are offered".
        return router.choose(
            TaskContext(user_input=text),
            memory_candidates=memories,
            skill_candidates=skills,
            tool_candidates=tools,
        )

    # Only a tool is on offer: the fitted router should call it.
    tool_only = make_candidate("t1", "tool", "t", 0.1)
    tool_decision = route("run tool", memories=[], skills=[], tools=[tool_only])
    assert tool_decision.decision_type == "call_tool"

    # Only a memory is on offer: the fitted router should inject it.
    memory_only = make_candidate("m1", "memory", "m", 0.0)
    memory_decision = route("remember", memories=[memory_only], skills=[], tools=[])
    assert memory_decision.decision_type == "inject_memory"
|
|
|
|
|
|
def test_learning_router_falls_back_to_clarify_when_untrained():
    """A router that was never fitted, offered no candidates, should ask to clarify."""
    untrained = SimpleLearningRouter()

    # Every candidate pool is empty, so there is nothing to route to.
    no_candidates = dict(
        memory_candidates=[],
        skill_candidates=[],
        tool_candidates=[],
    )
    outcome = untrained.choose(TaskContext(user_input="hi"), **no_candidates)

    assert outcome.decision_type == "clarify"
|