Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion Safaa/src/safaa/Safaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import re
import spacy
import warnings
from joblib import load, dump
from importlib.resources import files, as_file
from importlib.metadata import version
Expand Down Expand Up @@ -235,7 +236,14 @@ def predict(self, data, threshold=0.5):
Parameters:
data (iterable): The data to predict.
threshold (float): The probability threshold for classification.
Defaults to 0.5.
Defaults to 0.5. Only used when the loaded
classifier supports predict_proba (e.g., SGD
with loss='log_loss' or 'modified_huber').
The default shipped model uses loss='hinge'
and does not produce probabilities, so this
argument is ignored for that model. A
UserWarning is emitted when a non-default
threshold is passed but the model cannot use it.

Returns:
list: The predictions.
Expand All @@ -258,6 +266,17 @@ def predict(self, data, threshold=0.5):
for prediction in predictions
]

if threshold != 0.5:
warnings.warn(
"The loaded false positive detector does not support "
"probability estimates (predict_proba); the 'threshold' "
"argument is being ignored. Train a classifier with "
"loss='log_loss' or 'modified_huber' if threshold control "
"is required.",
UserWarning,
stacklevel=2,
)

# Get binary predictions from the model if probability prediction is not
# supported
return [
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ psycopg2-binary = '>=2.9'
requests = '>=2.28'
flake8 = '*'
build = '*'
pytest = '*'

[project]
name = 'safaa'
Expand Down
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: © 2026 RAJVEER42 <irajveer.bishnoi2310@gmail.com>
#
# SPDX-License-Identifier: LGPL-2.1-only
102 changes: 102 additions & 0 deletions tests/test_safaa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: © 2026 RAJVEER42 <irajveer.bishnoi2310@gmail.com>
#
# SPDX-License-Identifier: LGPL-2.1-only

import warnings

import pytest

from safaa.Safaa import SafaaAgent


@pytest.fixture(scope="module")
def agent():
    """Provide one SafaaAgent instance shared by every test in this module."""
    shared_agent = SafaaAgent()
    return shared_agent


# ---------------------------------------------------------------------------
# predict — threshold argument behavior
# ---------------------------------------------------------------------------

class _FakeProbaClassifier:
    """Stand-in estimator that reports one constant probability for class 1.

    Every sample receives the same ``[P(class 0), P(class 1)]`` pair, which
    makes threshold-dependent behaviour deterministic in the tests.
    """

    def __init__(self, prob_class_one):
        # Probability reported for class 1 on every sample.
        self.prob_class_one = prob_class_one

    def predict_proba(self, X):
        """Return one ``[1 - p, p]`` row for each of the ``X.shape[0]`` samples."""
        sample_count = X.shape[0]
        positive = self.prob_class_one
        rows = []
        for _ in range(sample_count):
            rows.append([1.0 - positive, positive])
        return rows

    def predict(self, X):
        """Return a hard label per sample: 1 when ``p >= 0.5``, otherwise 0."""
        # Kept only so the double exposes the usual estimator interface;
        # it is not expected to be called when predict_proba is available.
        labels = []
        for _ in range(X.shape[0]):
            labels.append(1 if self.prob_class_one >= 0.5 else 0)
        return labels


class TestPredictThreshold:
    """Checks how ``SafaaAgent.predict`` treats its ``threshold`` argument.

    Two situations are exercised: the detector supplied by the ``agent``
    fixture (described in this file as SGD(loss='hinge'), which has no
    ``predict_proba``), and a ``_FakeProbaClassifier`` swapped in to drive
    the probability-based path.
    """

    SAMPLE = "Copyright 2024 Siemens AG"

    # Regex handed to warnings.filterwarnings(); that API matches against the
    # start of the warning text, hence the leading ".*".
    THRESHOLD_WARNING = r".*threshold"

    @classmethod
    def _fail_on_threshold_warning(cls):
        """Promote only the threshold UserWarning to an error.

        Must be called inside ``warnings.catch_warnings()`` so the filter is
        dropped afterwards. A blanket ``simplefilter("error")`` would also
        turn unrelated third-party warnings (deprecations and the like) into
        test failures that have nothing to do with the threshold logic.
        """
        warnings.filterwarnings(
            "error", message=cls.THRESHOLD_WARNING, category=UserWarning
        )

    def test_no_warning_with_default_threshold_on_hinge_model(self, agent):
        """Default threshold on the shipped model must not emit the warning."""
        # Shipped model is SGD(loss='hinge') which has no predict_proba.
        with warnings.catch_warnings():
            self._fail_on_threshold_warning()
            agent.predict([self.SAMPLE])

    def test_warning_fires_on_non_default_threshold_without_predict_proba(self, agent):
        """A non-default threshold without predict_proba emits a UserWarning."""
        with pytest.warns(UserWarning, match="threshold"):
            agent.predict([self.SAMPLE], threshold=0.3)

    def test_warning_message_mentions_predict_proba(self, agent):
        """The warning text names predict_proba so users know what is missing."""
        with pytest.warns(UserWarning) as record:
            agent.predict([self.SAMPLE], threshold=0.7)
        assert any("predict_proba" in str(w.message) for w in record)

    def test_warning_fires_even_at_extreme_thresholds(self, agent):
        """Boundary and near-boundary thresholds all trigger the warning."""
        for t in (0.0, 1.0, 0.99, 0.01):
            with pytest.warns(UserWarning):
                agent.predict([self.SAMPLE], threshold=t)

    def test_no_warning_when_threshold_explicitly_passed_as_default(self, agent):
        """Explicitly passing 0.5 is equivalent to the default: no warning."""
        with warnings.catch_warnings():
            self._fail_on_threshold_warning()
            agent.predict([self.SAMPLE], threshold=0.5)

    def test_prediction_still_returns_valid_output_when_warning_fires(self, agent):
        """The warning is advisory: predict still returns a normal result."""
        with pytest.warns(UserWarning):
            result = agent.predict([self.SAMPLE], threshold=0.1)
        assert result in (["t"], ["f"])

    def test_threshold_actually_controls_proba_classifier(self, agent, monkeypatch):
        """With predict_proba available, the threshold decides the label."""
        # Swap in a fake classifier whose probability for class 1 is 0.6.
        #   threshold=0.5 -> 0.6 >= 0.5 -> 'f'
        #   threshold=0.7 -> 0.6 <  0.7 -> 't'
        monkeypatch.setattr(
            agent, "false_positive_detector", _FakeProbaClassifier(prob_class_one=0.6)
        )
        assert agent.predict([self.SAMPLE], threshold=0.5) == ["f"]
        assert agent.predict([self.SAMPLE], threshold=0.7) == ["t"]
        assert agent.predict([self.SAMPLE], threshold=0.3) == ["f"]

    def test_no_warning_when_proba_classifier_used_with_custom_threshold(
        self, agent, monkeypatch
    ):
        """A model with predict_proba honours the threshold without warning."""
        monkeypatch.setattr(
            agent, "false_positive_detector", _FakeProbaClassifier(prob_class_one=0.6)
        )
        with warnings.catch_warnings():
            self._fail_on_threshold_warning()
            agent.predict([self.SAMPLE], threshold=0.8)

    def test_threshold_boundary_inclusive(self, agent, monkeypatch):
        """A probability exactly equal to the threshold is classified 'f'."""
        # Threshold check uses >=, so the exact boundary falls on the 'f' side.
        monkeypatch.setattr(
            agent, "false_positive_detector", _FakeProbaClassifier(prob_class_one=0.5)
        )
        assert agent.predict([self.SAMPLE], threshold=0.5) == ["f"]