Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 328 additions & 3 deletions worker/tests/policy/test_runner.py
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey claude,
design alternative for the new tests - lots of them share the same setup - can that test preparation be shared across the tests as a common step? Design and /notify-me

Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from netboxlabs.diode.sdk.diode.v1 import ingester_pb2

from worker.backend import Backend
from worker.exceptions import IngestError, IngestRejected
from worker.models import Config, DiodeConfig, Metadata, Policy, Status
from worker.policy.run import RunStore
from worker.policy.run import RunStatus, RunStore
from worker.policy.runner import PolicyRunner


Expand Down Expand Up @@ -116,6 +117,11 @@ def mock_backend():
return backend


def _extract_callback(mock_backend_class):
    """Recover the ingest_callback closure that PolicyRunner.setup attached to the backend.

    Args:
        mock_backend_class: Stand-in for the backend class. Its ``return_value``
            is treated as the backend instance.

    Returns:
        Whatever is currently stored in ``return_value.ingest_callback``.
    """
    return mock_backend_class.return_value.ingest_callback


def test_initial_status(policy_runner):
    """PolicyRunner starts out in the NEW status."""
    current_status = policy_runner.status
    assert current_status == Status.NEW
Expand Down Expand Up @@ -347,6 +353,16 @@ def test_run_backend_exception(
mock_diode_client.ingest.assert_not_called() # Client ingestion should not be called
assert "Policy test_policy: Backend error" in caplog.text

# Regression guard: a crashing scheduled backend must still be recorded as a
# FAILED run — the run is created before the backend executes.
mock_run_store.create_run.assert_called_once()
failed_updates = [
c
for c in mock_run_store.update_run.call_args_list
if c.kwargs.get("status") == RunStatus.FAILED
]
assert failed_updates, "backend crash must record a FAILED run"


def test_stop_policy_runner(policy_runner):
"""Test stopping the PolicyRunner."""
Expand Down Expand Up @@ -545,7 +561,7 @@ def test_run_with_multiple_chunks(
assert mock_diode_client.ingest.call_count == 2

# Verify log messages for successful ingestion
assert "Successfully ingested 10 entities in 2 chunks" in caplog.text
assert "Successfully ingested 10 entities" in caplog.text


def test_run_chunk_ingestion_error(
Expand Down Expand Up @@ -587,8 +603,317 @@ def test_run_chunk_ingestion_error(
with caplog.at_level("ERROR"):
policy_runner.run(mock_diode_client, mock_backend, sample_policy)

# Should call ingest once and fail on first chunk error (it raises RuntimeError immediately)
# Should call ingest once and fail on first chunk error (it raises IngestRejected immediately)
Comment thread
ldrozdz93 marked this conversation as resolved.
Outdated
assert mock_diode_client.ingest.call_count == 2

# Should log the error
assert "Chunk ingestion failed" in caplog.text


# ---------------------------------------------------------------------------
# New tests: _build_ingest_callback
Comment thread
ldrozdz93 marked this conversation as resolved.
# ---------------------------------------------------------------------------


def test_setup_constructs_backend_directly_and_attaches_callback(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """setup() constructs the backend with no args and attaches ingest_callback after deps are ready.

    The scheduler's ``start`` and ``add_job`` are patched for the duration of
    the ``setup()`` call.
    """
    # Local import: only this test needs the common base class of all mocks.
    from unittest.mock import NonCallableMock

    with patch.object(policy_runner.scheduler, "start"), patch.object(
        policy_runner.scheduler, "add_job"
    ):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    mock_backend_class = mock_load_class.return_value
    # Same check as comparing call_args to ((), {}), but it fails with a
    # readable AssertionError if the class was never called at all.
    mock_backend_class.assert_called_with()

    backend = mock_backend_class.return_value
    assert callable(backend.ingest_callback)
    # Auto-created mock attributes are callable too, so callable() alone cannot
    # tell whether setup() assigned anything. Require a non-mock object.
    # NOTE(review): assumes setup() attaches a plain callable, as the
    # callback tests in this file rely on — confirm.
    assert not isinstance(
        backend.ingest_callback, NonCallableMock
    ), "setup() must attach a real ingest_callback to the backend"


def test_ingest_callback_entities_happy_path(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """Calling the callback with entities= ingests them and records a COMPLETED run."""
    scheduler = policy_runner.scheduler
    with patch.object(scheduler, "start"), patch.object(scheduler, "add_job"):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    ingest_cb = _extract_callback(mock_load_class.return_value)

    diode = mock_diode_client.return_value
    diode.ingest.return_value.errors = []

    # Two device entities, named dev1 and dev2.
    devices = []
    for device_name in ("dev1", "dev2"):
        device = ingester_pb2.Entity()
        device.device.name = device_name
        devices.append(device)

    with patch("worker.policy.runner.apply_run_id_to_entities"), patch(
        "worker.policy.runner.estimate_message_size", return_value=1024
    ), patch("worker.policy.runner.create_message_chunks"):
        outcome = ingest_cb(entities=devices)

    assert outcome is None
    mock_run_store.create_run.assert_called_once()

    # Exactly one ingest call carrying both entities.
    diode.ingest.assert_called_once()
    assert len(diode.ingest.call_args.kwargs["entities"]) == 2

    # Exactly one run update, marking the run COMPLETED.
    mock_run_store.update_run.assert_called_once()
    assert mock_run_store.update_run.call_args.kwargs["status"] == RunStatus.COMPLETED


def test_ingest_callback_error_path(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """Calling the callback with error= records a FAILED run and never calls client.ingest."""
    scheduler = policy_runner.scheduler
    with patch.object(scheduler, "start"), patch.object(scheduler, "add_job"):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    ingest_cb = _extract_callback(mock_load_class.return_value)
    diode = mock_diode_client.return_value

    failure = Exception("vendor unreachable")
    outcome = ingest_cb(error=failure)

    assert outcome is None
    diode.ingest.assert_not_called()

    # The single run update carries FAILED and the very same exception object.
    mock_run_store.update_run.assert_called_once()
    recorded = mock_run_store.update_run.call_args.kwargs
    assert recorded["status"] == RunStatus.FAILED
    assert recorded["error"] is failure


@pytest.mark.parametrize(
    "kwargs",
    [
        pytest.param({}, id="neither"),
        pytest.param({"entities": [], "error": Exception("x")}, id="both"),
    ],
)
def test_ingest_callback_requires_exactly_one_of_entities_or_error(
    kwargs,
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """Passing neither or both of entities/error to the callback raises TypeError."""
    scheduler = policy_runner.scheduler
    with patch.object(scheduler, "start"), patch.object(scheduler, "add_job"):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    ingest_cb = _extract_callback(mock_load_class.return_value)

    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        ingest_cb(**kwargs)


def test_ingest_callback_translates_transport_errors_to_ingest_error(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """A non-IngestError raised by client.ingest surfaces as the base IngestError."""
    scheduler = policy_runner.scheduler
    with patch.object(scheduler, "start"), patch.object(scheduler, "add_job"):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    ingest_cb = _extract_callback(mock_load_class.return_value)
    diode = mock_diode_client.return_value
    # Simulate a transport-level failure from the client.
    diode.ingest.side_effect = RuntimeError("connection refused")

    device = ingester_pb2.Entity()
    device.device.name = "dev1"

    with patch("worker.policy.runner.estimate_message_size", return_value=1024), patch(
        "worker.policy.runner.apply_run_id_to_entities"
    ), pytest.raises(IngestError):
        ingest_cb(entities=[device])

    # The failure is recorded through exactly one FAILED run update.
    mock_run_store.update_run.assert_called_once()
    assert mock_run_store.update_run.call_args.kwargs["status"] == RunStatus.FAILED


def test_ingest_callback_translates_response_errors_to_rejected(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """A response with a non-empty errors list is raised as IngestRejected."""
    scheduler = policy_runner.scheduler
    with patch.object(scheduler, "start"), patch.object(scheduler, "add_job"):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    ingest_cb = _extract_callback(mock_load_class.return_value)
    diode = mock_diode_client.return_value
    # The ingest call itself succeeds, but the response reports an error.
    diode.ingest.return_value.errors = ["bad payload"]

    device = ingester_pb2.Entity()
    device.device.name = "dev1"

    with patch("worker.policy.runner.estimate_message_size", return_value=1024), patch(
        "worker.policy.runner.apply_run_id_to_entities"
    ), pytest.raises(IngestRejected):
        ingest_cb(entities=[device])

    # The rejection is recorded through exactly one FAILED run update.
    mock_run_store.update_run.assert_called_once()
    assert mock_run_store.update_run.call_args.kwargs["status"] == RunStatus.FAILED


def test_ingest_callback_chunks_large_payloads(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """A large payload is split into chunks and each chunk is ingested separately."""
    scheduler = policy_runner.scheduler
    with patch.object(scheduler, "start"), patch.object(scheduler, "add_job"):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    ingest_cb = _extract_callback(mock_load_class.return_value)
    diode = mock_diode_client.return_value
    diode.ingest.return_value.errors = []

    devices = []
    for device_name in ("dev1", "dev2"):
        device = ingester_pb2.Entity()
        device.device.name = device_name
        devices.append(device)
    # One single-entity chunk per device, in order.
    single_entity_chunks = [[device] for device in devices]

    with patch(
        "worker.policy.runner.estimate_message_size", return_value=4 * 1024 * 1024
    ), patch(
        "worker.policy.runner.create_message_chunks", return_value=single_entity_chunks
    ), patch(
        "worker.policy.runner.apply_run_id_to_entities"
    ):
        ingest_cb(entities=devices)

    # One ingest call per chunk, but only a single COMPLETED run update.
    assert diode.ingest.call_count == 2
    mock_run_store.update_run.assert_called_once()
    assert mock_run_store.update_run.call_args.kwargs["status"] == RunStatus.COMPLETED


def test_run_unaffected_by_callback(
    policy_runner, sample_policy, mock_diode_client, mock_backend, mock_run_store
):
    """run() keeps working alongside the callback mechanism: one backend run, one ingest, COMPLETED."""
    policy_runner.name = "test_policy"
    policy_runner.run_store = mock_run_store

    device = ingester_pb2.Entity()
    device.device.name = "device-x"
    mock_backend.run.return_value = [device]
    mock_diode_client.ingest.return_value.errors = []

    size_patch = patch("worker.policy.runner.estimate_message_size", return_value=512)
    with size_patch:
        policy_runner.run(mock_diode_client, mock_backend, sample_policy)

    mock_backend.run.assert_called_once_with("test_policy", sample_policy)
    mock_diode_client.ingest.assert_called_once()
    mock_run_store.update_run.assert_called_once()
    assert mock_run_store.update_run.call_args.kwargs["status"] == RunStatus.COMPLETED


def test_ingest_callback_records_failure_on_apply_run_id_error(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """A failure in apply_run_id_to_entities surfaces as IngestError and records a FAILED run."""
    scheduler = policy_runner.scheduler
    with patch.object(scheduler, "start"), patch.object(scheduler, "add_job"):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    ingest_cb = _extract_callback(mock_load_class.return_value)

    device = ingester_pb2.Entity()
    device.device.name = "dev1"

    with patch(
        "worker.policy.runner.apply_run_id_to_entities",
        side_effect=RuntimeError("entity corrupt"),
    ), patch(
        "worker.policy.runner.estimate_message_size", return_value=1024
    ), pytest.raises(IngestError):
        ingest_cb(entities=[device])

    mock_run_store.update_run.assert_called_once()
    recorded = mock_run_store.update_run.call_args.kwargs
    assert recorded["status"] == RunStatus.FAILED
    # The entity had already been materialised when apply_run_id_to_entities raised.
    assert recorded["entity_count"] == 1


def test_ingest_callback_records_failure_on_iterable_error(
    policy_runner,
    sample_policy,
    sample_diode_config,
    mock_load_class,
    mock_diode_client,
    mock_run_store,
):
    """An entities object whose __iter__ raises records a FAILED run with entity_count=0.

    The mock below fails as soon as iteration is requested (in ``__iter__``),
    not on a later ``next()`` call. The callback is expected to surface this
    as IngestError.
    """
    with patch.object(policy_runner.scheduler, "start"), patch.object(
        policy_runner.scheduler, "add_job"
    ):
        policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store)

    callback = _extract_callback(mock_load_class.return_value)

    # Replace __iter__ so that any attempt to iterate raises ValueError("bad").
    bad_iterable = MagicMock()
    bad_iterable.__iter__ = MagicMock(side_effect=ValueError("bad"))

    with pytest.raises(IngestError):
        callback(entities=bad_iterable)

    mock_run_store.update_run.assert_called_once()
    update_kwargs = mock_run_store.update_run.call_args.kwargs
    assert update_kwargs["status"] == RunStatus.FAILED
    # Iteration failed before any entity was materialised.
    assert update_kwargs["entity_count"] == 0
Loading
Loading