diff --git a/worker/pyproject.toml b/worker/pyproject.toml index d1f68de8..b60ec926 100644 --- a/worker/pyproject.toml +++ b/worker/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "netboxlabs-orb-worker" -version = "1.0.0" # Overwritten during the build process +version = "1.4.0" # Overwritten during the build process description = "NetBox Labs, Worker backend for Orb Agent" readme = "README.md" requires-python = ">=3.10" diff --git a/worker/tests/policy/test_runner.py b/worker/tests/policy/test_runner.py index 7365c2e0..774555cd 100644 --- a/worker/tests/policy/test_runner.py +++ b/worker/tests/policy/test_runner.py @@ -9,8 +9,9 @@ from netboxlabs.diode.sdk.diode.v1 import ingester_pb2 from worker.backend import Backend +from worker.exceptions import IngestRejected, IngestUnavailable from worker.models import Config, DiodeConfig, Metadata, Policy, Status -from worker.policy.run import RunStore +from worker.policy.run import RunStatus, RunStore from worker.policy.runner import PolicyRunner @@ -116,6 +117,11 @@ def mock_backend(): return backend +def _extract_callback(mock_backend_class): + """Recover the ingest_callback closure that PolicyRunner.setup passed to backend_class().""" + return mock_backend_class.call_args.kwargs["ingest_callback"] + + def test_initial_status(policy_runner): """Test initial status of PolicyRunner.""" assert policy_runner.status == Status.NEW @@ -587,8 +593,426 @@ def test_run_chunk_ingestion_error( with caplog.at_level("ERROR"): policy_runner.run(mock_diode_client, mock_backend, sample_policy) - # Should call ingest once and fail on first chunk error (it raises RuntimeError immediately) + # Should call ingest once and fail on first chunk error (it raises IngestRejected immediately) assert mock_diode_client.ingest.call_count == 2 # Should log the error assert "Chunk ingestion failed" in caplog.text + + +# --------------------------------------------------------------------------- +# New tests: _build_ingest_callback +# --------------------------------------------------------------------------- + + +def test_setup_passes_kwargs_to_backend( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """setup() constructs the backend with ingest_callback= kwarg (no policy= per ADR-0008).""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + mock_backend_class = mock_load_class.return_value + call_kwargs = mock_backend_class.call_args.kwargs + assert "ingest_callback" in call_kwargs + assert callable(call_kwargs["ingest_callback"]) + assert "policy" not in call_kwargs + + +@pytest.mark.parametrize( + "init_signature, expects_ingest_callback", + [ + pytest.param( + "def __init__(self): self.ingest_callback = None", + False, + id="legacy-zero-arg", + ), + pytest.param( + "def __init__(self, **kwargs): self.ingest_callback = kwargs.get('ingest_callback')", + True, + id="kwargs-absorber", + ), + pytest.param( + "def __init__(self, *, ingest_callback=None, **kwargs): self.ingest_callback = ingest_callback", + True, + id="named-kwarg", + ), + ], +) +def test_construct_backend_introspection(init_signature, expects_ingest_callback): + """_construct_backend only passes ingest_callback when the class accepts it.""" + from worker.policy.runner import _construct_backend + + namespace: dict = {} + exec( # noqa: S102 — synthesizing tiny class fixture under test + "class _Stub:\n" + " " + init_signature + "\n", + namespace, + ) + stub_class = namespace["_Stub"] + sentinel = object() + instance = _construct_backend(stub_class, ingest_callback=sentinel) + if expects_ingest_callback: + assert instance.ingest_callback is sentinel + else: + assert instance.ingest_callback is None + + +def test_ingest_callback_entities_happy_path( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """Callback with entities= ingests them and records COMPLETED run.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + + client_instance = mock_diode_client.return_value + client_instance.ingest.return_value.errors = [] + + entity1 = ingester_pb2.Entity() + entity1.device.name = "dev1" + entity2 = ingester_pb2.Entity() + entity2.device.name = "dev2" + + with patch("worker.policy.runner.apply_run_id_to_entities"), patch( + "worker.policy.runner.estimate_message_size", return_value=1024 + ), patch("worker.policy.runner.create_message_chunks"): + result = callback(entities=[entity1, entity2]) + + assert result is None + mock_run_store.create_run.assert_called_once() + client_instance.ingest.assert_called_once() + call_kwargs = client_instance.ingest.call_args.kwargs + assert len(call_kwargs["entities"]) == 2 + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.COMPLETED + + +def test_ingest_callback_error_path( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """Callback with error= records FAILED run and skips client.ingest.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + client_instance = mock_diode_client.return_value + + err = Exception("vendor unreachable") + result = callback(error=err) + + assert result is None + client_instance.ingest.assert_not_called() + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.FAILED + assert update_kwargs["error"] is err + + +@pytest.mark.parametrize( + "kwargs", + [ + pytest.param({}, id="neither"), + pytest.param({"entities": [], "error": Exception("x")}, id="both"), + ], +) +def test_ingest_callback_requires_exactly_one_of_entities_or_error( + kwargs, + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """Callback raises TypeError when neither or both of entities/error are given.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + + with pytest.raises(TypeError): + callback(**kwargs) + + +def test_ingest_callback_translates_transport_errors_to_unavailable( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """Non-IngestError transport exceptions are wrapped as IngestUnavailable.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + client_instance = mock_diode_client.return_value + client_instance.ingest.side_effect = RuntimeError("connection refused") + + entity = ingester_pb2.Entity() + entity.device.name = "dev1" + + with patch("worker.policy.runner.estimate_message_size", return_value=1024), patch( + "worker.policy.runner.apply_run_id_to_entities" + ): + with pytest.raises(IngestUnavailable): + callback(entities=[entity]) + + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.FAILED + + +def test_ingest_callback_translates_response_errors_to_rejected( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """Response errors (non-empty errors list) are raised as IngestRejected.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + client_instance = mock_diode_client.return_value + client_instance.ingest.return_value.errors = ["bad payload"] + + entity = ingester_pb2.Entity() + entity.device.name = "dev1" + + with patch("worker.policy.runner.estimate_message_size", return_value=1024), patch( + "worker.policy.runner.apply_run_id_to_entities" + ): + with pytest.raises(IngestRejected): + callback(entities=[entity]) + + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.FAILED + + +def test_ingest_callback_chunks_large_payloads( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """Callback splits large payloads into chunks and ingests each separately.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + client_instance = mock_diode_client.return_value + client_instance.ingest.return_value.errors = [] + + entity1 = ingester_pb2.Entity() + entity1.device.name = "dev1" + entity2 = ingester_pb2.Entity() + entity2.device.name = "dev2" + chunk_a = [entity1] + chunk_b = [entity2] + + with patch( + "worker.policy.runner.estimate_message_size", return_value=4 * 1024 * 1024 + ), patch( + "worker.policy.runner.create_message_chunks", return_value=[chunk_a, chunk_b] + ), patch( + "worker.policy.runner.apply_run_id_to_entities" + ): + callback(entities=[entity1, entity2]) + + assert client_instance.ingest.call_count == 2 + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.COMPLETED + + +def test_run_unaffected_by_callback( + policy_runner, sample_policy, mock_diode_client, mock_backend, mock_run_store +): + """PolicyRunner.run() is unaffected by the new callback mechanism.""" + policy_runner.name = "test_policy" + policy_runner.run_store = mock_run_store + + entity = ingester_pb2.Entity() + entity.device.name = "device-x" + mock_backend.run.return_value = [entity] + mock_diode_client.ingest.return_value.errors = [] + + with patch("worker.policy.runner.estimate_message_size", return_value=512): + policy_runner.run(mock_diode_client, mock_backend, sample_policy) + + mock_backend.run.assert_called_once_with("test_policy", sample_policy) + mock_diode_client.ingest.assert_called_once() + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.COMPLETED + + +def test_ingest_callback_raises_when_not_ready( + policy_runner, + sample_policy, + sample_diode_config, + mock_run_store, + mock_load_class, + mock_diode_client, +): + """Calling the callback before _callback_ready is True raises IngestUnavailable.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy-x", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + + # Simulate "called before worker finished constructing" — clear the readiness flag. + policy_runner._callback_ready = False + + entity = MagicMock() + with pytest.raises(IngestUnavailable, match="before the worker finished constructing"): + callback(entities=[entity]) + + # No pseudo-run must have been created (guard fires before create_run). + mock_run_store.create_run.assert_not_called() + + +def test_ingest_callback_records_failure_on_apply_run_id_error( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """apply_run_id_to_entities failure inside try: records FAILED run as IngestUnavailable.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + + entity = ingester_pb2.Entity() + entity.device.name = "dev1" + + with patch( + "worker.policy.runner.apply_run_id_to_entities", + side_effect=RuntimeError("entity corrupt"), + ), patch("worker.policy.runner.estimate_message_size", return_value=1024): + with pytest.raises(IngestUnavailable): + callback(entities=[entity]) + + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.FAILED + # The entity was materialised before apply_run_id_to_entities raised. + assert update_kwargs["entity_count"] == 1 + + +def test_ingest_callback_records_failure_on_iterable_error( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """An iterable that raises on first next() records FAILED run with entity_count=0.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + + bad_iterable = MagicMock() + bad_iterable.__iter__ = MagicMock(side_effect=ValueError("bad")) + + with pytest.raises(IngestUnavailable): + callback(entities=bad_iterable) + + mock_run_store.update_run.assert_called_once() + update_kwargs = mock_run_store.update_run.call_args.kwargs + assert update_kwargs["status"] == RunStatus.FAILED + # Iterable failed before any entity was materialised. + assert update_kwargs["entity_count"] == 0 + + +def test_setup_sets_callback_ready_flag( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """After setup() returns, _callback_ready is True.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + assert policy_runner._callback_ready is True + + +def test_stop_clears_callback_ready_flag( + policy_runner, + sample_policy, + sample_diode_config, + mock_load_class, + mock_diode_client, + mock_run_store, +): + """After stop(), _callback_ready is False and the callback raises IngestUnavailable.""" + with patch.object(policy_runner.scheduler, "start"), patch.object( + policy_runner.scheduler, "add_job" + ): + policy_runner.setup("policy1", sample_diode_config, sample_policy, mock_run_store) + + callback = _extract_callback(mock_load_class.return_value) + + with patch.object(policy_runner.scheduler, "shutdown"): + policy_runner.stop() + + assert policy_runner._callback_ready is False + + entity = MagicMock() + with pytest.raises(IngestUnavailable, match="before the worker finished constructing"): + callback(entities=[entity]) diff --git a/worker/tests/test_backend.py b/worker/tests/test_backend.py index 8c2a18b6..827a2801 100644 --- a/worker/tests/test_backend.py +++ b/worker/tests/test_backend.py @@ -88,3 +88,35 @@ def test_load_class_attribute_error(mock_import_module): match=f"Failed to load a class inheriting from 'Backend' in module '{mock_module_name}': Attribute error", ): load_class(mock_module_name) + + +def test_backend_init_default_no_args(): + """Zero-arg construction still works (back-compat with older worker).""" + b = Backend() + assert b.ingest_callback is None + + +def test_backend_init_stores_ingest_callback(): + """ingest_callback is stored on the instance.""" + + def cb(**_): + return None + + b = Backend(ingest_callback=cb) + assert b.ingest_callback is cb + + +def test_backend_init_absorbs_unknown_kwargs(): + """Forward-compat: unknown kwargs don't raise.""" + Backend(unknown_future_resource="x", another_one=42) # must not raise + + +def test_backend_run_accepts_kwargs(): + """run() signature absorbs **kwargs (passive forward-compat door).""" + b = Backend() + # The base implementation raises NotImplementedError; the point is + # that calling with kwargs reaches the body without a TypeError. + try: + b.run("policy", MagicMock(spec=Policy), future_kwarg="x") + except NotImplementedError: + pass diff --git a/worker/tests/test_exceptions.py b/worker/tests/test_exceptions.py new file mode 100644 index 00000000..d66d071d --- /dev/null +++ b/worker/tests/test_exceptions.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# Copyright 2026 NetBox Labs Inc +"""NetBox Labs - worker.exceptions hierarchy tests.""" + +import pytest + +from worker.exceptions import IngestError, IngestRejected, IngestUnavailable + + +def test_unavailable_is_ingest_error(): + """IngestUnavailable is a subclass of IngestError.""" + assert issubclass(IngestUnavailable, IngestError) + + +def test_rejected_is_ingest_error(): + """IngestRejected is a subclass of IngestError.""" + assert issubclass(IngestRejected, IngestError) + + +def test_ingest_error_chain_catchable_as_base(): + """Subclasses can be caught as the base IngestError.""" + with pytest.raises(IngestError): + raise IngestUnavailable("transient") + with pytest.raises(IngestError): + raise IngestRejected("bad payload") + + +@pytest.mark.parametrize( + "exc_cls,msg", + [ + pytest.param(IngestError, "base", id="base"), + pytest.param(IngestUnavailable, "transient", id="unavailable"), + pytest.param(IngestRejected, "permanent", id="rejected"), + ], +) +def test_exceptions_carry_message(exc_cls, msg): + """Each exception carries its constructor message via str().""" + exc = exc_cls(msg) + assert str(exc) == msg diff --git a/worker/worker/backend.py b/worker/worker/backend.py index 3571c204..f6eec1fa 100644 --- a/worker/worker/backend.py +++ b/worker/worker/backend.py @@ -14,6 +14,36 @@ class Backend: """Backend Class.""" + def __init__( + self, + *, + ingest_callback=None, + **kwargs, + ) -> None: + """ + Construct the Backend. + + Worker passes ``ingest_callback`` at construction starting with the + minor release this docstring ships in. Older worker versions + construct ``Backend()`` with zero args; integrations that override + ``__init__`` should accept ``**kwargs`` so both paths keep working. + + Args: + ---- + ingest_callback: Optional callable that ingests entities or + reports errors outside of the ``run()`` cycle. **Do not + invoke from ``__init__`` or ``setup()`` — the callback is + only usable starting after the worker finishes constructing + the Backend (i.e. after ``setup()`` returns); calling it + earlier raises ``IngestUnavailable``.** See + ``worker.exceptions`` for the exception hierarchy it may + raise. + **kwargs: Forward-compat door for additional resources worker + may pass in future versions; silently ignored by default. + + """ + self.ingest_callback = ingest_callback + def setup(self) -> Metadata: """ Set up the backend. @@ -25,7 +55,12 @@ def setup(self) -> Metadata: """ raise NotImplementedError("The 'setup' method must be implemented.") - def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]: + def run( + self, + policy_name: str, + policy: Policy, + **kwargs, + ) -> Iterable[Entity]: """ Run the backend. @@ -33,10 +68,16 @@ def run(self, policy_name: str, policy: Policy) -> Iterable[Entity]: ---- policy_name (str): The name of the policy. policy (Policy): The policy to run. + **kwargs: Passive forward-compat door. The worker passes nothing + through it in v1; future minor releases may add per-tick + context (e.g. ``source="scheduled"|"trigger"``, ``run_id``). + Concrete backends are encouraged to declare ``**kwargs`` so + additive kwargs ride into the contract without a coordinated + upgrade. Returns: ------- - Iterable[Entity]: The entities produced by the backend + Iterable[Entity]: The entities produced by the backend. """ raise NotImplementedError("The 'run' method must be implemented.") diff --git a/worker/worker/exceptions.py b/worker/worker/exceptions.py new file mode 100644 index 00000000..71bb4eb9 --- /dev/null +++ b/worker/worker/exceptions.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# Copyright 2026 NetBox Labs Inc +"""Exception hierarchy raised by the worker ingest callback.""" + + +class IngestError(Exception): + """ + Base for pipeline-side ingestion failures. + + Integrations should catch this when they want uniform handling. New + subclasses are added under this base in future minor releases. + """ + + +class IngestUnavailable(IngestError): + """ + Transient pipeline failure — Diode unreachable, queue full, rate-limited. + + Retry-friendly. The integration MAY retry with backoff. + """ + + +class IngestRejected(IngestError): + """ + Permanent pipeline rejection for this call. + + Reasons include bad payload, instance retired, or policy removed. + The integration should NOT retry; the call will fail again. + """ diff --git a/worker/worker/policy/runner.py b/worker/worker/policy/runner.py index 0f5baa57..7083c389 100644 --- a/worker/worker/policy/runner.py +++ b/worker/worker/policy/runner.py @@ -2,6 +2,7 @@ # Copyright 2025 NetBox Labs Inc """Orb Worker Policy Runner.""" +import inspect import logging import time from datetime import datetime, timedelta @@ -19,14 +20,40 @@ from worker.backend import Backend, load_class from worker.entity_metadata import apply_run_id_to_entities +from worker.exceptions import IngestError, IngestRejected, IngestUnavailable from worker.metrics import get_metric from worker.models import DiodeConfig, Policy, Status from worker.package_finder import maybe_evict from worker.policy.run import RunStatus, RunStore +# Diode message-size cap (per chunk). Stays in sync with the reconciler's +# 4 MiB gRPC ceiling minus a safety margin. +MAX_INGEST_MESSAGE_BYTES = 3 * 1024 * 1024 + logger = logging.getLogger(__name__) +def _construct_backend(backend_class, *, ingest_callback): + """ + Construct a Backend, passing ``ingest_callback`` only when accepted. + + Uses ``inspect.signature`` to detect whether ``backend_class.__init__`` + has a matching named parameter or ``VAR_KEYWORD`` absorber. Legacy + backends with ``__init__(self)`` get constructed zero-arg and never + see the kwarg, so the worker can ship this contract bump without a + coordinated upgrade across every integration package. + """ + sig = inspect.signature(backend_class) + params = sig.parameters + accepts_var_kw = any( + p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + init_kwargs: dict[str, object] = {} + if accepts_var_kw or "ingest_callback" in params: + init_kwargs["ingest_callback"] = ingest_callback + return backend_class(**init_kwargs) + + class PolicyRunner: """Policy Runner class.""" @@ -38,6 +65,8 @@ def __init__(self): self.status = Status.NEW self.scheduler = BackgroundScheduler() self.run_store = None + self._diode_client = None + self._callback_ready = False def setup( self, name: str, diode_config: DiodeConfig, policy: Policy, run_store: RunStore @@ -65,7 +94,12 @@ def setup( # Debug logging for backend loading logger.debug(f"Loading backend class: {policy.config.package}") backend_class = load_class(policy.config.package) - backend = backend_class() + + # Build the ingest callback closure. It captures `self` and reads + # `self._diode_client` lazily, so it is safe to construct before the + # client is assigned below. + ingest_callback = self._build_ingest_callback(self.name) + backend = _construct_backend(backend_class, ingest_callback=ingest_callback) logger.debug(f"Backend class loaded successfully: {backend_class.__name__}") metadata = backend.setup() @@ -101,6 +135,7 @@ def setup( self.metadata = metadata self.policy = policy self.run_store = run_store + self._diode_client = client self.scheduler.start() @@ -123,10 +158,126 @@ def setup( self.status = Status.RUNNING + # Callback is now safe to invoke — every dependency the closure reads + # (run_store, metadata, _diode_client) is attached. Integrations may push + # entities via ingest_callback from this point onward. + self._callback_ready = True + active_policies = get_metric("active_policies") if active_policies: active_policies.add(1, {"policy": self.name}) + def _build_ingest_callback(self, policy_name: str): + """ + Build a closure used to ingest entities outside the scheduled run() cycle. + + The returned callable signature: + cb(entities=None, *, error=None, **kwargs) -> None + + Exactly one of ``entities`` / ``error`` must be supplied. + On the ``entities`` path: a pseudo-run is created in the RunStore, + entities are chunked and ingested via the same path run() uses, and + response/transport errors are translated into IngestRejected / + IngestUnavailable. On the ``error`` path: a failed pseudo-run is + recorded; no client.ingest call is made; returns None. + """ + + def ingest_callback( + entities=None, + *, + error: Exception | None = None, + **kwargs, + ) -> None: + # kwargs is reserved for forward-compat (run_id, source, etc.); currently ignored. + if (entities is None) == (error is None): + raise TypeError( + "ingest_callback requires exactly one of 'entities' or 'error'" + ) + if not self._callback_ready: + raise IngestUnavailable( + "ingest_callback invoked before the worker finished constructing " + "this Backend (likely called from Backend.__init__ or " + "Backend.setup() — defer until after setup() returns)" + ) + run = self.run_store.create_run( + policy_name=policy_name, + metadata={ + "name": self.metadata.name, + "app_name": self.metadata.app_name, + "app_version": self.metadata.app_version, + "source": "ingest_callback", + }, + ) + if error is not None: + self.run_store.update_run( + policy_name=policy_name, + run_id=run.id, + status=RunStatus.FAILED, + error=error, + entity_count=0, + ) + return + entities_list: list = [] + try: + entities_list = list(entities) + apply_run_id_to_entities(entities_list, run.id) + metadata = { + "policy_name": policy_name, + "worker_backend": self.metadata.name, + "run_id": run.id, + } + self._send_entities(self._diode_client, entities_list, metadata) + except IngestError as exc: + self.run_store.update_run( + policy_name=policy_name, + run_id=run.id, + status=RunStatus.FAILED, + error=exc, + entity_count=len(entities_list), + ) + raise + except Exception as exc: + logger.exception( + "Unexpected exception in ingest_callback; " + "translating to IngestUnavailable" + ) + self.run_store.update_run( + policy_name=policy_name, + run_id=run.id, + status=RunStatus.FAILED, + error=exc, + entity_count=len(entities_list), + ) + raise IngestUnavailable(str(exc)) from exc + self.run_store.update_run( + policy_name=policy_name, + run_id=run.id, + status=RunStatus.COMPLETED, + error=None, + entity_count=len(entities_list), + ) + + return ingest_callback + + def _send_entities(self, client, entities_list: list, metadata: dict) -> int: + """ + Send entities to the Diode client, chunking if the payload exceeds MAX_INGEST_MESSAGE_BYTES. + + Returns the number of chunks actually sent (1 if not chunked). + """ + size_bytes = estimate_message_size(entities_list) + if size_bytes > MAX_INGEST_MESSAGE_BYTES: + chunks = create_message_chunks(entities_list) + for chunk in chunks: + response = client.ingest(entities=chunk, metadata=metadata) + if response.errors: + raise IngestRejected(f"Chunk ingestion failed: {response.errors}") + return len(chunks) + response = client.ingest(entities=entities_list, metadata=metadata) + if response.errors: + raise IngestRejected(f"Entities ingestion failed: {response.errors}") + return 1 + def run( self, client: DiodeClient | DiodeDryRunClient | DiodeOTLPClient, @@ -174,20 +325,7 @@ def run( "worker_backend": self.metadata.name, "run_id": run.id, } - chunk_num = 1 - size_bytes = estimate_message_size(entities) - - if size_bytes > (3.0 * 1024 * 1024): - chunks = create_message_chunks(entities) - chunk_num = len(chunks) - for chunk in chunks: - response = client.ingest(entities=chunk, metadata=metadata) - if response.errors: - raise RuntimeError(f"Chunk ingestion failed: {response.errors}") - else: - response = client.ingest(entities=entities, metadata=metadata) - if response.errors: - raise RuntimeError(f"Entities ingestion failed: {response.errors}") + chunk_num = self._send_entities(client, entities, metadata) logger.info( f"Policy {self.name}: Successfully ingested {entity_count} entities in {chunk_num} chunks" ) @@ -251,6 +389,7 @@ def run( def stop(self): """Stop the policy runner.""" + self._callback_ready = False self.scheduler.shutdown(wait=False) self.status = Status.FINISHED active_policies = get_metric("active_policies")