Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 284 additions & 1 deletion worker/tests/policy/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from netboxlabs.diode.sdk.diode.v1 import ingester_pb2

from worker.backend import Backend
from worker.exceptions import IngestRejected, IngestUnavailable
from worker.models import Config, DiodeConfig, Metadata, Policy, Status
from worker.policy.run import RunStore
from worker.policy.run import RunStatus, RunStore
from worker.policy.runner import PolicyRunner


Expand Down Expand Up @@ -116,6 +117,11 @@ def mock_backend():
return backend


def _extract_callback(mock_backend_class):
"""Recover the ingest_callback closure that PolicyRunner.setup passed to backend_class()."""
return mock_backend_class.call_args.kwargs["ingest_callback"]


def test_initial_status(policy_runner):
"""Test initial status of PolicyRunner."""
assert policy_runner.status == Status.NEW
Expand Down Expand Up @@ -592,3 +598,280 @@ def test_run_chunk_ingestion_error(

# Should log the error
assert "Chunk ingestion failed" in caplog.text


# ---------------------------------------------------------------------------
# New tests: _build_ingest_callback
# ---------------------------------------------------------------------------


def test_setup_passes_kwargs_to_backend(
policy_runner,
sample_policy,
sample_diode_config,
mock_load_class,
mock_diode_client,
mock_run_store,
):
"""setup() constructs the backend with ingest_callback= and policy= kwargs."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

mock_backend_class = mock_load_class.return_value
call_kwargs = mock_backend_class.call_args.kwargs
assert "ingest_callback" in call_kwargs
assert callable(call_kwargs["ingest_callback"])
assert call_kwargs["policy"] == sample_policy


def test_ingest_callback_entities_happy_path(
policy_runner,
sample_policy,
sample_diode_config,
mock_load_class,
mock_diode_client,
mock_run_store,
):
"""Callback with entities= ingests them and records COMPLETED run."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

callback = _extract_callback(mock_load_class.return_value)

client_instance = mock_diode_client.return_value
client_instance.ingest.return_value.errors = []

entity1 = ingester_pb2.Entity()
entity1.device.name = "dev1"
entity2 = ingester_pb2.Entity()
entity2.device.name = "dev2"

with patch("worker.policy.runner.apply_run_id_to_entities"), patch(
"worker.policy.runner.estimate_message_size", return_value=1024
), patch("worker.policy.runner.create_message_chunks"):
result = callback(entities=[entity1, entity2])

assert result is None
mock_run_store.create_run.assert_called_once()
client_instance.ingest.assert_called_once()
call_kwargs = client_instance.ingest.call_args.kwargs
assert len(call_kwargs["entities"]) == 2
mock_run_store.update_run.assert_called_once()
update_kwargs = mock_run_store.update_run.call_args.kwargs
assert update_kwargs["status"] == RunStatus.COMPLETED


def test_ingest_callback_error_path(
policy_runner,
sample_policy,
sample_diode_config,
mock_load_class,
mock_diode_client,
mock_run_store,
):
"""Callback with error= records FAILED run and skips client.ingest."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

callback = _extract_callback(mock_load_class.return_value)
client_instance = mock_diode_client.return_value

err = Exception("vendor unreachable")
result = callback(error=err)

assert result is None
client_instance.ingest.assert_not_called()
mock_run_store.update_run.assert_called_once()
update_kwargs = mock_run_store.update_run.call_args.kwargs
assert update_kwargs["status"] == RunStatus.FAILED
assert update_kwargs["error"] is err


@pytest.mark.parametrize(
"kwargs",
[
pytest.param({}, id="neither"),
pytest.param({"entities": [], "error": Exception("x")}, id="both"),
],
)
def test_ingest_callback_requires_exactly_one_of_entities_or_error(
kwargs,
policy_runner,
sample_policy,
sample_diode_config,
mock_load_class,
mock_diode_client,
mock_run_store,
):
"""Callback raises TypeError when neither or both of entities/error are given."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

callback = _extract_callback(mock_load_class.return_value)

with pytest.raises(TypeError):
callback(**kwargs)


def test_ingest_callback_translates_transport_errors_to_unavailable(
policy_runner,
sample_policy,
sample_diode_config,
mock_load_class,
mock_diode_client,
mock_run_store,
):
"""Non-IngestError transport exceptions are wrapped as IngestUnavailable."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

callback = _extract_callback(mock_load_class.return_value)
client_instance = mock_diode_client.return_value
client_instance.ingest.side_effect = RuntimeError("connection refused")

entity = ingester_pb2.Entity()
entity.device.name = "dev1"

with patch("worker.policy.runner.estimate_message_size", return_value=1024), patch(
"worker.policy.runner.apply_run_id_to_entities"
):
with pytest.raises(IngestUnavailable):
callback(entities=[entity])

mock_run_store.update_run.assert_called_once()
update_kwargs = mock_run_store.update_run.call_args.kwargs
assert update_kwargs["status"] == RunStatus.FAILED


def test_ingest_callback_translates_response_errors_to_rejected(
policy_runner,
sample_policy,
sample_diode_config,
mock_load_class,
mock_diode_client,
mock_run_store,
):
"""Response errors (non-empty errors list) are raised as IngestRejected."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

callback = _extract_callback(mock_load_class.return_value)
client_instance = mock_diode_client.return_value
client_instance.ingest.return_value.errors = ["bad payload"]

entity = ingester_pb2.Entity()
entity.device.name = "dev1"

with patch("worker.policy.runner.estimate_message_size", return_value=1024), patch(
"worker.policy.runner.apply_run_id_to_entities"
):
with pytest.raises(IngestRejected):
callback(entities=[entity])

mock_run_store.update_run.assert_called_once()
update_kwargs = mock_run_store.update_run.call_args.kwargs
assert update_kwargs["status"] == RunStatus.FAILED


def test_ingest_callback_chunks_large_payloads(
policy_runner,
sample_policy,
sample_diode_config,
mock_load_class,
mock_diode_client,
mock_run_store,
):
"""Callback splits large payloads into chunks and ingests each separately."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

callback = _extract_callback(mock_load_class.return_value)
client_instance = mock_diode_client.return_value
client_instance.ingest.return_value.errors = []

entity1 = ingester_pb2.Entity()
entity1.device.name = "dev1"
entity2 = ingester_pb2.Entity()
entity2.device.name = "dev2"
chunk_a = [entity1]
chunk_b = [entity2]

with patch(
"worker.policy.runner.estimate_message_size", return_value=4 * 1024 * 1024
), patch(
"worker.policy.runner.create_message_chunks", return_value=[chunk_a, chunk_b]
), patch(
"worker.policy.runner.apply_run_id_to_entities"
):
callback(entities=[entity1, entity2])

assert client_instance.ingest.call_count == 2
mock_run_store.update_run.assert_called_once()
update_kwargs = mock_run_store.update_run.call_args.kwargs
assert update_kwargs["status"] == RunStatus.COMPLETED


def test_run_unaffected_by_callback(
policy_runner, sample_policy, mock_diode_client, mock_backend, mock_run_store
):
"""PolicyRunner.run() is unaffected by the new callback mechanism."""
policy_runner.name = "test_policy"
policy_runner.run_store = mock_run_store

entity = ingester_pb2.Entity()
entity.device.name = "device-x"
mock_backend.run.return_value = [entity]
mock_diode_client.ingest.return_value.errors = []

with patch("worker.policy.runner.estimate_message_size", return_value=512):
policy_runner.run(mock_diode_client, mock_backend, sample_policy)

mock_backend.run.assert_called_once_with("test_policy", sample_policy)
mock_diode_client.ingest.assert_called_once()
mock_run_store.update_run.assert_called_once()
update_kwargs = mock_run_store.update_run.call_args.kwargs
assert update_kwargs["status"] == RunStatus.COMPLETED


def test_ingest_callback_raises_when_client_not_initialised(
policy_runner,
sample_policy,
sample_diode_config,
mock_run_store,
mock_load_class,
mock_diode_client,
):
"""Calling the callback before _diode_client is assigned raises IngestUnavailable."""
with patch.object(policy_runner.scheduler, "start"), patch.object(
policy_runner.scheduler, "add_job"
):
policy_runner.setup("policy-x", sample_diode_config, sample_policy, mock_run_store)

callback = _extract_callback(mock_load_class.return_value)

# Simulate "called before client init" — clear the attribute.
policy_runner._diode_client = None

entity = MagicMock()
with patch("worker.policy.runner.apply_run_id_to_entities"):
with pytest.raises(IngestUnavailable, match="diode client was initialised"):
callback(entities=[entity])

# Pseudo-run was still recorded as FAILED.
mock_run_store.update_run.assert_called()
final_call = mock_run_store.update_run.call_args
assert final_call.kwargs["status"] == RunStatus.FAILED
24 changes: 24 additions & 0 deletions worker/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,27 @@ def test_load_class_attribute_error(mock_import_module):
match=f"Failed to load a class inheriting from 'Backend' in module '{mock_module_name}': Attribute error",
):
load_class(mock_module_name)


def test_backend_init_default_no_args():
"""Zero-arg construction still works (back-compat with older worker)."""
b = Backend()
assert b.ingest_callback is None
assert b.policy is None


def test_backend_init_stores_kwargs():
"""ingest_callback and policy are stored on the instance."""

def cb(**_):
return None

pol = MagicMock(spec=Policy)
b = Backend(ingest_callback=cb, policy=pol)
assert b.ingest_callback is cb
assert b.policy is pol


def test_backend_init_absorbs_unknown_kwargs():
"""Forward-compat: unknown kwargs don't raise."""
Backend(unknown_future_resource="x", another_one=42) # must not raise
39 changes: 39 additions & 0 deletions worker/tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python
# Copyright 2026 NetBox Labs Inc
"""NetBox Labs - worker.exceptions hierarchy tests."""

import pytest

from worker.exceptions import IngestError, IngestRejected, IngestUnavailable


def test_unavailable_is_ingest_error():
"""IngestUnavailable is a subclass of IngestError."""
assert issubclass(IngestUnavailable, IngestError)


def test_rejected_is_ingest_error():
"""IngestRejected is a subclass of IngestError."""
assert issubclass(IngestRejected, IngestError)


def test_ingest_error_chain_catchable_as_base():
"""Subclasses can be caught as the base IngestError."""
with pytest.raises(IngestError):
raise IngestUnavailable("transient")
with pytest.raises(IngestError):
raise IngestRejected("bad payload")


@pytest.mark.parametrize(
"exc_cls,msg",
[
pytest.param(IngestError, "base", id="base"),
pytest.param(IngestUnavailable, "transient", id="unavailable"),
pytest.param(IngestRejected, "permanent", id="rejected"),
],
)
def test_exceptions_carry_message(exc_cls, msg):
"""Each exception carries its constructor message via str()."""
exc = exc_cls(msg)
assert str(exc) == msg
Loading
Loading