50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
from memabra.dataset import DatasetBuilder, TrainingSample
|
|
|
|
|
|
def test_dataset_builder_extracts_features_and_label():
    """Build one sample from a fully populated trajectory and check every field.

    The trajectory carries one memory, one skill and one tool candidate, a
    single decision and a total reward; the resulting sample must expose the
    input text, the decision type as label, the reward and the feature values
    listed below.
    """
    candidate_sets = {
        "memory": [{"confidence": 0.8}],
        "skill": [{"success_rate": 0.9}],
        "tool": [{"confidence": 0.7, "risk": 0.2}],
    }
    trajectory = {
        "task": {"input": "hello world"},
        "candidate_sets": candidate_sets,
        "decisions": [{"decision_type": "direct_answer"}],
        "reward": {"total": 0.95},
    }

    samples = DatasetBuilder().build([trajectory])

    assert len(samples) == 1
    sample = samples[0]
    assert sample.input_text == "hello world"
    assert sample.label == "direct_answer"
    assert sample.reward == 0.95

    # Feature name -> value expected for the trajectory above.
    expected_features = {
        "input_length": 11,
        "memory_count": 1,
        "skill_count": 1,
        "tool_count": 1,
        "top_memory_confidence": 0.8,
        "top_skill_success_rate": 0.9,
        "top_tool_confidence": 0.7,
        "top_tool_risk": 0.2,
    }
    for feature_name, expected_value in expected_features.items():
        assert sample.features[feature_name] == expected_value
|
|
|
|
|
|
def test_dataset_builder_handles_empty_candidates():
    """Build one sample from a trajectory whose candidate lists are all empty.

    With no memory, skill or tool candidates, each "top_*" feature of the
    resulting sample is expected to equal 0.0.
    """
    trajectory = {
        "task": {"input": "hi"},
        "candidate_sets": {"memory": [], "skill": [], "tool": []},
        "decisions": [{"decision_type": "clarify"}],
        "reward": {"total": 0.0},
    }

    samples = DatasetBuilder().build([trajectory])

    assert len(samples) == 1

    # Every "top_*" feature is checked against the same 0.0 value.
    zeroed_features = (
        "top_memory_confidence",
        "top_skill_success_rate",
        "top_tool_confidence",
        "top_tool_risk",
    )
    for feature_name in zeroed_features:
        assert samples[0].features[feature_name] == 0.0
|