Files
memabra/tests/test_tool_adapters.py
2026-04-15 11:06:05 +08:00

67 lines
2.3 KiB
Python

from memabra.router import TaskContext
def test_local_function_tool_adapter_executes_callable():
    """Run a plain Python callable through LocalFunctionToolAdapter.

    The adapter is built around a two-argument ``add`` function and invoked
    with a keyword payload; the returned mapping must report success, carry
    the sum as its output, and have no error.
    """
    from memabra.execution import LocalFunctionToolAdapter

    def add(a: int, b: int) -> int:
        return a + b

    tool = LocalFunctionToolAdapter(func=add)
    outcome = tool.run_tool(
        "add",
        TaskContext(user_input="add 1 and 2"),
        {"a": 1, "b": 2},
    )

    assert outcome["status"] == "success"
    assert outcome["output"] == 3
    assert outcome["error"] is None
def test_subprocess_tool_adapter_executes_command():
    """Run ``echo hello`` through SubprocessToolAdapter.

    The returned mapping must report success, contain ``hello`` in its
    output, have no error, and expose a non-negative ``latency_ms``.
    """
    from memabra.execution import SubprocessToolAdapter

    tool = SubprocessToolAdapter(command="echo hello")
    outcome = tool.run_tool("echo", TaskContext(user_input="say hello"))

    assert outcome["status"] == "success"
    assert "hello" in outcome["output"]
    assert outcome["error"] is None
    assert outcome["latency_ms"] >= 0
def test_tool_registry_resolves_and_runs_tools():
    """Register an adapter under a name and run it via the registry.

    A doubling lambda is registered as ``"double"``; running that name with
    ``x=5`` must yield a successful result whose output is 10.
    """
    from memabra.execution import LocalFunctionToolAdapter, ToolRegistry

    tools = ToolRegistry()
    tools.register("double", LocalFunctionToolAdapter(func=lambda x: x * 2))

    outcome = tools.run_tool(
        "double",
        TaskContext(user_input="double 5"),
        {"x": 5},
    )

    assert outcome["status"] == "success"
    assert outcome["output"] == 10
def test_tool_registry_returns_error_for_unknown_tool():
    """Ask an empty registry for a tool that was never registered.

    The result must have status ``"error"`` and an error message that
    contains ``"not found"`` (compared case-insensitively).
    """
    from memabra.execution import ToolRegistry

    empty_registry = ToolRegistry()
    outcome = empty_registry.run_tool("missing", TaskContext(user_input="missing"))

    assert outcome["status"] == "error"
    assert "not found" in outcome["error"].lower()
def test_tool_executor_uses_registry_and_produces_result_events():
    """Execute a ``call_tool`` route decision through a registry-backed ToolExecutor.

    An ``add`` tool is registered, a RouteDecision selecting it with payload
    ``a=2, b=3`` is executed, and the result must be marked ``"executed"``,
    report 5 as the first tool output, and include both a ``tool_called``
    and a ``tool_result`` event.
    """
    from memabra.execution import LocalFunctionToolAdapter, ToolExecutor, ToolRegistry
    from memabra.router import RouteDecision

    tools = ToolRegistry()
    tools.register("add", LocalFunctionToolAdapter(func=lambda a, b: a + b))
    runner = ToolExecutor(backend=tools)

    route = RouteDecision(
        decision_type="call_tool",
        selected_ids=["add"],
        selected_payloads=[{"a": 2, "b": 3}],
    )
    outcome = runner.execute(
        route,
        TaskContext(user_input="add 2 and 3"),
        trajectory_id="traj-1",
    )

    assert outcome.status == "executed"
    assert outcome.details["results"][0]["output"] == 5
    assert any(evt.event_type == "tool_called" for evt in outcome.events)
    assert any(evt.event_type == "tool_result" for evt in outcome.events)