From bf2ce35d2f6085af0cf05bcabf62ad4262adbad3 Mon Sep 17 00:00:00 2001 From: RAJVEER42 Date: Sat, 23 May 2026 16:32:59 +0530 Subject: [PATCH] fix(safaa): warn when threshold argument cannot be honored SafaaAgent.predict accepts a `threshold` parameter, but only the predict_proba branch consults it. The shipped SGD classifier uses loss='hinge' and therefore has no predict_proba (sklearn's @available_if descriptor raises AttributeError on access), so the hasattr() check is False at runtime and execution falls to the binary predict() path that ignores the threshold entirely. Callers tuning the sensitivity get the default SVM decision boundary with no indication that their argument had no effect. This was working when the original model was trained with a probability-supporting loss; the regression slipped in when the SGD(hinge) model replaced it. Emit a UserWarning when a non-default threshold is passed but the loaded model cannot honor it, and document the constraint in the docstring. Default usage stays warning-free. Tests cover: no warning at default threshold, warning at non-default threshold, warning text mentions predict_proba, warnings at extreme threshold values, predictions remain valid alongside the warning, and a monkeypatched fake classifier proves the threshold actually controls output when predict_proba is available (with a boundary-inclusive check). Signed-off-by: RAJVEER42 --- Safaa/src/safaa/Safaa.py | 21 +++++++- pyproject.toml | 1 + tests/__init__.py | 3 ++ tests/test_safaa.py | 102 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_safaa.py diff --git a/Safaa/src/safaa/Safaa.py b/Safaa/src/safaa/Safaa.py index cfa4cd5..07144e5 100644 --- a/Safaa/src/safaa/Safaa.py +++ b/Safaa/src/safaa/Safaa.py @@ -9,6 +9,7 @@ import os import re import spacy +import warnings from joblib import load, dump from importlib.resources import files, as_file from importlib.metadata import version @@ -235,7 +236,14 @@ def predict(self, data, threshold=0.5): Parameters: data (iterable): The data to predict. threshold (float): The probability threshold for classification. - Defaults to 0.5. + Defaults to 0.5. Only used when the loaded + classifier supports predict_proba (e.g., SGD + with loss='log_loss' or 'modified_huber'). + The default shipped model uses loss='hinge' + and does not produce probabilities, so this + argument is ignored for that model. A + UserWarning is emitted when a non-default + threshold is passed but the model cannot use it. Returns: list: The predictions. @@ -258,6 +266,17 @@ def predict(self, data, threshold=0.5): for prediction in predictions ] + if threshold != 0.5: + warnings.warn( + "The loaded false positive detector does not support " + "probability estimates (predict_proba); the 'threshold' " + "argument is being ignored. Train a classifier with " + "loss='log_loss' or 'modified_huber' if threshold control " + "is required.", + UserWarning, + stacklevel=2, + ) + # Get binary predictions from the model if probability prediction is not # supported return [ diff --git a/pyproject.toml b/pyproject.toml index 830a55d..8e29e89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ psycopg2-binary = '>=2.9' requests = '>=2.28' flake8 = '*' build = '*' +pytest = '*' [project] name = 'safaa' diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d31f995 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: © 2026 RAJVEER42 +# +# SPDX-License-Identifier: LGPL-2.1-only diff --git a/tests/test_safaa.py b/tests/test_safaa.py new file mode 100644 index 0000000..3c09a16 --- /dev/null +++ b/tests/test_safaa.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: © 2026 RAJVEER42 +# +# SPDX-License-Identifier: LGPL-2.1-only + +import warnings + +import pytest + +from safaa.Safaa import SafaaAgent + + +@pytest.fixture(scope="module") +def agent(): + return SafaaAgent() + + +# --------------------------------------------------------------------------- +# predict — threshold argument behavior +# --------------------------------------------------------------------------- + +class _FakeProbaClassifier: + """Test double: predict_proba returns a fixed probability for class 1.""" + + def __init__(self, prob_class_one): + self.prob_class_one = prob_class_one + + def predict_proba(self, X): + n = X.shape[0] + p1 = self.prob_class_one + return [[1.0 - p1, p1] for _ in range(n)] + + def predict(self, X): + # Not used when predict_proba exists; included so introspection + # sees a normal estimator interface. + return [1 if self.prob_class_one >= 0.5 else 0 for _ in range(X.shape[0])] + + +class TestPredictThreshold: + + SAMPLE = "Copyright 2024 Siemens AG" + + def test_no_warning_with_default_threshold_on_hinge_model(self, agent): + # Shipped model is SGD(loss='hinge') which has no predict_proba. + # Calling predict() with the default threshold must NOT warn. + with warnings.catch_warnings(): + warnings.simplefilter("error") # promote any warning to an error + agent.predict([self.SAMPLE]) + + def test_warning_fires_on_non_default_threshold_without_predict_proba(self, agent): + with pytest.warns(UserWarning, match="threshold"): + agent.predict([self.SAMPLE], threshold=0.3) + + def test_warning_message_mentions_predict_proba(self, agent): + with pytest.warns(UserWarning) as record: + agent.predict([self.SAMPLE], threshold=0.7) + assert any("predict_proba" in str(w.message) for w in record) + + def test_warning_fires_even_at_extreme_thresholds(self, agent): + for t in (0.0, 1.0, 0.99, 0.01): + with pytest.warns(UserWarning): + agent.predict([self.SAMPLE], threshold=t) + + def test_no_warning_when_threshold_explicitly_passed_as_default(self, agent): + # User explicitly passes 0.5 — equivalent to default, must not warn + with warnings.catch_warnings(): + warnings.simplefilter("error") + agent.predict([self.SAMPLE], threshold=0.5) + + def test_prediction_still_returns_valid_output_when_warning_fires(self, agent): + with pytest.warns(UserWarning): + result = agent.predict([self.SAMPLE], threshold=0.1) + assert result in (["t"], ["f"]) + + def test_threshold_actually_controls_proba_classifier(self, agent, monkeypatch): + # Swap in a fake classifier whose probability for class 1 is 0.6. + # threshold=0.5 → 0.6 >= 0.5 → 'f' + # threshold=0.7 → 0.6 < 0.7 → 't' + monkeypatch.setattr( + agent, "false_positive_detector", _FakeProbaClassifier(prob_class_one=0.6) + ) + assert agent.predict([self.SAMPLE], threshold=0.5) == ["f"] + assert agent.predict([self.SAMPLE], threshold=0.7) == ["t"] + assert agent.predict([self.SAMPLE], threshold=0.3) == ["f"] + + def test_no_warning_when_proba_classifier_used_with_custom_threshold( + self, agent, monkeypatch + ): + # When the loaded model DOES support predict_proba, a non-default + # threshold is honored and no warning should fire. + monkeypatch.setattr( + agent, "false_positive_detector", _FakeProbaClassifier(prob_class_one=0.6) + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + agent.predict([self.SAMPLE], threshold=0.8) + + def test_threshold_boundary_inclusive(self, agent, monkeypatch): + # Threshold check uses >=, so prediction at the exact boundary should be 'f' + monkeypatch.setattr( + agent, "false_positive_detector", _FakeProbaClassifier(prob_class_one=0.5) + ) + assert agent.predict([self.SAMPLE], threshold=0.5) == ["f"]