Initial standalone memabra release
This commit is contained in:
49
tests/test_dataset.py
Normal file
49
tests/test_dataset.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from memabra.dataset import DatasetBuilder, TrainingSample
|
||||
|
||||
|
||||
def test_dataset_builder_extracts_features_and_label():
    """Build one sample from a fully populated trajectory and check its fields.

    The trajectory carries one candidate in each of the memory, skill and
    tool sets, a single ``direct_answer`` decision and a total reward of
    0.95; the resulting sample is expected to mirror those values.
    """
    trajectory = {
        "task": {"input": "hello world"},
        "candidate_sets": {
            "memory": [{"confidence": 0.8}],
            "skill": [{"success_rate": 0.9}],
            "tool": [{"confidence": 0.7, "risk": 0.2}],
        },
        "decisions": [{"decision_type": "direct_answer"}],
        "reward": {"total": 0.95},
    }

    samples = DatasetBuilder().build([trajectory])
    assert len(samples) == 1

    sample = samples[0]
    assert sample.input_text == "hello world"
    assert sample.label == "direct_answer"
    assert sample.reward == 0.95

    # Expected feature values, checked in the same order as they are listed.
    expected_features = {
        "input_length": 11,  # len("hello world")
        "memory_count": 1,
        "skill_count": 1,
        "tool_count": 1,
        "top_memory_confidence": 0.8,
        "top_skill_success_rate": 0.9,
        "top_tool_confidence": 0.7,
        "top_tool_risk": 0.2,
    }
    for feature_name, expected_value in expected_features.items():
        assert sample.features[feature_name] == expected_value
|
||||
|
||||
|
||||
def test_dataset_builder_handles_empty_candidates():
    """Build a sample from a trajectory whose candidate sets are all empty.

    With no memory, skill or tool candidates, every ``top_*`` feature on
    the single resulting sample is expected to equal 0.0.
    """
    trajectory = {
        "task": {"input": "hi"},
        "candidate_sets": {"memory": [], "skill": [], "tool": []},
        "decisions": [{"decision_type": "clarify"}],
        "reward": {"total": 0.0},
    }

    samples = DatasetBuilder().build([trajectory])
    assert len(samples) == 1

    features = samples[0].features
    zeroed_features = (
        "top_memory_confidence",
        "top_skill_success_rate",
        "top_tool_confidence",
        "top_tool_risk",
    )
    for feature_name in zeroed_features:
        assert features[feature_name] == 0.0
|
||||
Reference in New Issue
Block a user