From 6a3427b766447a195ab4131681d324f0af111089 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 22 Jun 2026 11:43:16 -0700 Subject: [PATCH 1/3] FEAT: Add default implementations of GCG extension protocols (#1902) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auxiliary_attacks/gcg/__init__.py | 14 + .../gcg/default_implementations.py | 331 +++++++++++++ .../gcg/test_default_implementations.py | 454 ++++++++++++++++++ 3 files changed, 799 insertions(+) create mode 100644 pyrit/auxiliary_attacks/gcg/default_implementations.py create mode 100644 tests/unit/auxiliary_attacks/gcg/test_default_implementations.py diff --git a/pyrit/auxiliary_attacks/gcg/__init__.py b/pyrit/auxiliary_attacks/gcg/__init__.py index a10d862fe3..160b2f313a 100644 --- a/pyrit/auxiliary_attacks/gcg/__init__.py +++ b/pyrit/auxiliary_attacks/gcg/__init__.py @@ -47,18 +47,28 @@ # mechanism so all GCG public symbols share one re-export pathway. _LAZY_IMPORTS = { "CandidateFilter": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "CandidateFilter"), + "CrossEntropyLoss": ("pyrit.auxiliary_attacks.gcg.default_implementations", "CrossEntropyLoss"), "GCG": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGContext": ("pyrit.auxiliary_attacks.gcg.generator", "GCGContext"), "GCGGenerator": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"), "GCGResult": ("pyrit.auxiliary_attacks.gcg.generator", "GCGResult"), + "LengthPreservingFilter": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LengthPreservingFilter"), + "LiteralStringInit": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LiteralStringInit"), "LossFunction": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "LossFunction"), "SamplingStrategy": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SamplingStrategy"), + "StandardGCGSampling": ("pyrit.auxiliary_attacks.gcg.default_implementations", "StandardGCGSampling"), "SuffixInitializer": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SuffixInitializer"), "load_goals_and_targets": ("pyrit.auxiliary_attacks.gcg.data", "load_goals_and_targets"), } if TYPE_CHECKING: from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets + from pyrit.auxiliary_attacks.gcg.default_implementations import ( + CrossEntropyLoss, + LengthPreservingFilter, + LiteralStringInit, + StandardGCGSampling, + ) from pyrit.auxiliary_attacks.gcg.extension_protocols import ( CandidateFilter, LossFunction, @@ -91,6 +101,7 @@ def __dir__() -> list[str]: __all__ = [ "CandidateFilter", + "CrossEntropyLoss", "GCG", "GCGAlgorithmConfig", "GCGConfig", @@ -101,8 +112,11 @@ def __dir__() -> list[str]: "GCGOutputConfig", "GCGResult", "GCGStrategyConfig", + "LengthPreservingFilter", + "LiteralStringInit", "LossFunction", "SamplingStrategy", + "StandardGCGSampling", "SuffixInitializer", "load_goals_and_targets", ] diff --git a/pyrit/auxiliary_attacks/gcg/default_implementations.py b/pyrit/auxiliary_attacks/gcg/default_implementations.py new file mode 100644 index 0000000000..3967c128c7 --- /dev/null +++ b/pyrit/auxiliary_attacks/gcg/default_implementations.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Default concrete implementations of the four GCG extension protocols. + +Each class in this module reproduces the byte-identical behavior of the +legacy GCG attack code path it replaces: + +- ``StandardGCGSampling`` reproduces ``GCGPromptManager.sample_control``. +- ``CrossEntropyLoss`` reproduces ``AttackPrompt.target_loss`` and + ``AttackPrompt.control_loss`` combined via the weighted sum applied + inside ``GCGMultiPromptAttack.step``. +- ``LengthPreservingFilter`` reproduces ``MultiPromptAttack.get_filtered_cands``. +- ``LiteralStringInit`` reproduces the literal-string ``control_init`` + parameter threaded through the attack constructors. + +The defaults are *not* wired into ``GCGMultiPromptAttack`` here. They are +shipped ahead of wiring so the strategy objects can already be constructed +and inspected, and so the wiring change is a pure orchestration edit. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +class StandardGCGSampling: + """Top-k by ``-gradient``, uniform pick within top-k at one random position per row. + + The standard GCG sampling rule: for each of ``batch_size`` candidate + rows, pick one of the ``control_length`` positions, then replace the + token at that position with a uniformly-sampled token id from the top-k + smallest-gradient (most-promising) candidates at that position. The + ``temperature`` argument is part of the protocol but is unused by this + sampler, which always samples uniformly within the top-k. + + Reproduces ``GCGPromptManager.sample_control`` from + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py`` byte-for-byte. + """ + + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_tokens: torch.Tensor, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: torch.Tensor, + ) -> torch.Tensor: + """Sample ``batch_size`` candidate suffix token sequences. + + Args: + gradient (torch.Tensor): Aggregated gradient over the control + tokens with shape ``(control_length, vocab_size)``. Mutated + in-place when ``allow_non_ascii`` is False (the disallowed + token positions are set to ``+inf``), matching legacy + behavior. + control_tokens (torch.Tensor): Current suffix token sequence + with shape ``(control_length,)``. + batch_size (int): Number of candidate suffix rows to return. + top_k (int): Number of top gradient positions per control slot + drawn from. + temperature (float): Sampling temperature. Unused by this + implementation; kept to match the protocol signature. + allow_non_ascii (bool): When False, mask the ``non_ascii_tokens`` + positions of ``gradient`` to ``+inf`` so they fall out of + the top-k. + non_ascii_tokens (torch.Tensor): Token ids to exclude when + ``allow_non_ascii`` is False. + + Returns: + torch.Tensor: Candidate suffix token sequences with shape + ``(batch_size, control_length)`` on the same device as + ``gradient``. + """ + if not allow_non_ascii: + gradient[:, non_ascii_tokens.to(gradient.device)] = np.inf + top_indices = (-gradient).topk(top_k, dim=1).indices + control_tokens = control_tokens.to(gradient.device) + original_control_tokens = control_tokens.repeat(batch_size, 1) + new_token_pos = torch.arange( + 0, + len(control_tokens), + len(control_tokens) / batch_size, + device=gradient.device, + ).type(torch.int64) + new_token_val = torch.gather( + top_indices[new_token_pos], + 1, + torch.randint(0, top_k, (batch_size, 1), device=gradient.device), + ) + return original_control_tokens.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val) + + +class CrossEntropyLoss: + """Weighted token-level cross-entropy on the target and control slices. + + Per candidate: ``target_weight * CE(target_slice) + control_weight * + CE(control_slice)``, where each cross-entropy term is reduced over its + slice with ``.mean(dim=-1)`` to give one scalar per candidate. The + ``.mean(dim=-1)`` reduction matches where the legacy orchestrator + applies it: ``GCGMultiPromptAttack.step`` calls + ``target_loss(...).mean(dim=-1)`` outside the per-prompt loss method, + so the ``LossFunction`` protocol places the per-candidate scalar + reduction inside the implementation. + + When ``control_weight == 0`` the control term is skipped entirely, + matching the legacy ``if control_weight != 0:`` guard inside ``step``. + The same skip is applied when ``target_weight == 0`` for symmetry. + + Reproduces ``AttackPrompt.target_loss`` + ``AttackPrompt.control_loss`` + from ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``, + combined per ``GCGMultiPromptAttack.step`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. + """ + + def __init__(self, *, target_weight: float = 1.0, control_weight: float = 0.0) -> None: + """Initialize the cross-entropy loss with target / control weights. + + Args: + target_weight (float): Weight on the target-slice cross-entropy. + Defaults to 1.0. + control_weight (float): Weight on the control-slice + cross-entropy. Defaults to 0.0 (target-only signal). + + Raises: + ValueError: If either weight is negative, or if both are zero. + """ + if target_weight < 0 or control_weight < 0: + raise ValueError( + "CrossEntropyLoss target_weight and control_weight must be >= 0, " + f"got target_weight={target_weight}, control_weight={control_weight}." + ) + if target_weight == 0 and control_weight == 0: + raise ValueError( + "CrossEntropyLoss requires at least one of target_weight or control_weight to be > 0; " + "with both at 0 the loss is identically zero and provides no signal." + ) + self._target_weight = target_weight + self._control_weight = control_weight + + def compute_loss( + self, + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + """Compute the per-candidate weighted cross-entropy loss. + + Args: + logits (torch.Tensor): Model logits for the candidate batch + with shape ``(batch_size, seq_len, vocab_size)``. + token_ids (torch.Tensor): Input token ids the model was run on + with shape ``(batch_size, seq_len)``. + target_slice (slice): Slice into the sequence dimension that + identifies the target tokens. + control_slice (slice): Slice into the sequence dimension that + identifies the control (suffix) tokens. + + Returns: + torch.Tensor: Per-candidate scalar loss with shape + ``(batch_size,)``. + """ + criterion = nn.CrossEntropyLoss(reduction="none") + total: torch.Tensor | None = None + + if self._target_weight > 0: + target_loss_slice = slice(target_slice.start - 1, target_slice.stop - 1) + target_term = criterion( + logits[:, target_loss_slice, :].transpose(1, 2), + token_ids[:, target_slice], + ).mean(dim=-1) + total = self._target_weight * target_term + + if self._control_weight > 0: + control_loss_slice = slice(control_slice.start - 1, control_slice.stop - 1) + control_term = criterion( + logits[:, control_loss_slice, :].transpose(1, 2), + token_ids[:, control_slice], + ).mean(dim=-1) + weighted_control = self._control_weight * control_term + total = weighted_control if total is None else total + weighted_control + + # Constructor guarantees at least one weight is > 0, so ``total`` is + # always assigned. The check is kept for the type checker. + if total is None: + raise RuntimeError( + "CrossEntropyLoss.compute_loss produced no terms; " + "this indicates a corrupted instance with both weights at 0." + ) + return total + + +class LengthPreservingFilter: + """Decodes each candidate token row and drops any whose decoded string + either (a) equals ``current_control`` or (b) re-tokenizes to a different + token count, padding dropped rows by repeating the last accepted + candidate. + + The ``filter`` constructor parameter selects between filtering (legacy + ``filter_cand=True`` branch) and passthrough decode-only mode (legacy + ``filter_cand=False`` branch). + + Also performs the legacy out-of-vocab clamping: tokens above + ``tokenizer.vocab_size`` are replaced in-place by the id of ``"!"``, + matching the safety pass at the top of ``get_filtered_cands``. + + Reproduces ``MultiPromptAttack.get_filtered_cands`` from + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. + """ + + def __init__(self, *, filter: bool = True) -> None: + """Initialize the filter. + + Args: + filter (bool): When True, drop candidates that equal + ``current_control`` or re-tokenize to a different length, + padding the result with the last accepted candidate. When + False, decode every row and return them all unchanged. + Defaults to True. + """ + self._filter = filter + + def filter_candidates( + self, + *, + candidate_tokens: torch.Tensor, + tokenizer: Any, + current_control: str, + ) -> list[str]: + """Decode and filter a batch of candidate suffix token tensors. + + Args: + candidate_tokens (torch.Tensor): Sampled candidate suffixes + with shape ``(batch_size, control_length)``. Mutated + in-place by the out-of-vocab clamp, matching legacy + behavior. + tokenizer (Any): HuggingFace-style tokenizer. ``tokenizer.decode`` + renders each row to text; ``tokenizer(text, + add_special_tokens=False).input_ids`` is used to detect + re-tokenization drift; ``tokenizer("!").input_ids[0]`` + provides the replacement id for out-of-vocab clamping. + current_control (str): Current suffix string. When ``filter`` + is True, candidates that decode to this string are dropped. + + Returns: + list[str]: Decoded candidate suffix strings of length exactly + ``candidate_tokens.shape[0]``. + """ + logger.info("Masking out of range token_id.") + vocab_size = tokenizer.vocab_size + candidate_tokens[candidate_tokens > vocab_size] = tokenizer("!").input_ids[0] + + candidates: list[str] = [] + for i in range(candidate_tokens.shape[0]): + decoded_str = tokenizer.decode( + candidate_tokens[i], skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + if self._filter: + if decoded_str != current_control and len( + tokenizer(decoded_str, add_special_tokens=False).input_ids + ) == len(candidate_tokens[i]): + candidates.append(decoded_str) + else: + candidates.append(decoded_str) + + if self._filter: + candidates = candidates + [candidates[-1]] * (len(candidate_tokens) - len(candidates)) + return candidates + + +class LiteralStringInit: + """Returns the configured literal suffix verbatim; ignores the tokenizer. + + Encapsulates the current ``control_init`` plumbing — a literal string + threaded through ``AttackPrompt.__init__``, ``PromptManager.__init__``, + ``MultiPromptAttack.__init__``, and the per-strategy ``*Attack`` + constructors — so that custom initializers that do need the tokenizer + (for example, a random vocabulary sampler) can be swapped in without + changing those constructor signatures. + + Reproduces the literal-string ``control_init`` parameter assignment + (``self.control = control_init``) inside ``AttackPrompt.__init__`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. + """ + + def __init__(self, *, suffix: str) -> None: + """Initialize the literal-string suffix initializer. + + Args: + suffix (str): The literal suffix string to return on every + call to ``make_initial_suffix``. Must be non-empty. + + Raises: + ValueError: If ``suffix`` is the empty string. + """ + if not suffix: + raise ValueError("LiteralStringInit.suffix must be a non-empty string.") + self._suffix = suffix + + def make_initial_suffix(self, *, tokenizer: Any) -> str: + """Return the configured suffix string. + + Args: + tokenizer (Any): Ignored. Present to match the protocol + signature so custom initializers that need vocabulary + access can be substituted without changing call sites. + + Returns: + str: The literal suffix string supplied at construction. + """ + return self._suffix + + +__all__ = [ + "CrossEntropyLoss", + "LengthPreservingFilter", + "LiteralStringInit", + "StandardGCGSampling", +] diff --git a/tests/unit/auxiliary_attacks/gcg/test_default_implementations.py b/tests/unit/auxiliary_attacks/gcg/test_default_implementations.py new file mode 100644 index 0000000000..8b89745052 --- /dev/null +++ b/tests/unit/auxiliary_attacks/gcg/test_default_implementations.py @@ -0,0 +1,454 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``pyrit.auxiliary_attacks.gcg.default_implementations``. + +These tests verify byte-identical parity between the four default +implementations and the legacy GCG attack code paths they reproduce: + +- ``StandardGCGSampling`` vs ``GCGPromptManager.sample_control`` +- ``CrossEntropyLoss`` vs the weighted sum of ``AttackPrompt.target_loss`` + and ``AttackPrompt.control_loss`` applied inside + ``GCGMultiPromptAttack.step`` +- ``LengthPreservingFilter`` vs ``MultiPromptAttack.get_filtered_cands`` +- ``LiteralStringInit`` vs the literal-string ``control_init`` assignment + inside ``AttackPrompt.__init__`` + +Mocking patterns follow the conventions established in +``tests/unit/auxiliary_attacks/gcg/test_gcg_core.py`` (``object.__new__`` +to skip the real ``__init__``, ``MagicMock`` tokenizers). +""" + +from unittest.mock import MagicMock + +import pytest + +torch = pytest.importorskip("torch", reason="GCG default implementations require torch") + +attack_manager_mod = pytest.importorskip( + "pyrit.auxiliary_attacks.gcg.attack.base.attack_manager", + reason="GCG optional dependencies (torch, mlflow, etc.) not installed", +) +gcg_attack_mod = pytest.importorskip( + "pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack", + reason="GCG optional dependencies not installed", +) + +import pyrit.auxiliary_attacks.gcg as gcg_pkg # noqa: E402 +from pyrit.auxiliary_attacks.gcg import ( # noqa: E402 + CrossEntropyLoss, + LengthPreservingFilter, + LiteralStringInit, + StandardGCGSampling, +) +from pyrit.auxiliary_attacks.gcg import default_implementations as defaults_module # noqa: E402 +from pyrit.auxiliary_attacks.gcg.config import GCGAlgorithmConfig # noqa: E402 + +AttackPrompt = attack_manager_mod.AttackPrompt +MultiPromptAttack = attack_manager_mod.MultiPromptAttack +GCGPromptManager = gcg_attack_mod.GCGPromptManager + + +DEFAULT_NAMES = ( + "CrossEntropyLoss", + "LengthPreservingFilter", + "LiteralStringInit", + "StandardGCGSampling", +) + + +class TestPackageReExports: + """Verify the four default classes are re-exported from the package root.""" + + @pytest.mark.parametrize("name", DEFAULT_NAMES) + def test_default_is_reexported_with_identity(self, name: str) -> None: + package_attr = getattr(gcg_pkg, name) + module_attr = getattr(defaults_module, name) + assert package_attr is module_attr, ( + f"{name} re-exported from pyrit.auxiliary_attacks.gcg must be the same " + f"object as pyrit.auxiliary_attacks.gcg.default_implementations.{name}" + ) + + @pytest.mark.parametrize("name", DEFAULT_NAMES) + def test_default_in_package_dunder_all(self, name: str) -> None: + assert name in gcg_pkg.__all__ + + +class TestStandardGCGSampling: + """Parity: ``StandardGCGSampling`` vs ``GCGPromptManager.sample_control``.""" + + def _make_legacy_prompt_manager( + self, + *, + control_tokens: torch.Tensor, + non_ascii_tokens: torch.Tensor, + ) -> GCGPromptManager: + # Mirrors the construction pattern used by TestSampleControl in + # test_gcg_core.py: skip __init__ and seed just the attributes that + # sample_control reads. + prompt_manager = object.__new__(GCGPromptManager) + prompt_manager._nonascii_toks = non_ascii_tokens + prompt_manager._prompts = [MagicMock()] + prompt_manager._prompts[0].control_toks = control_tokens.clone() + return prompt_manager + + def test_sample_candidates_matches_legacy_with_ascii_only(self) -> None: + """Legacy reference: ``GCGPromptManager.sample_control(grad, batch_size, + topk=top_k, temp=1.0, allow_non_ascii=False)`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. + """ + n_control_tokens = 5 + vocab_size = 50 + batch_size = 4 + top_k = 8 + + torch.manual_seed(2026) + gradient_template = torch.randn(n_control_tokens, vocab_size) + control_tokens = torch.randint(0, vocab_size, (n_control_tokens,)) + non_ascii_tokens = torch.tensor([2, 7, 13]) + + # Legacy path + prompt_manager = self._make_legacy_prompt_manager( + control_tokens=control_tokens, non_ascii_tokens=non_ascii_tokens + ) + torch.manual_seed(12345) + legacy_out = prompt_manager.sample_control( + gradient_template.clone(), + batch_size, + topk=top_k, + temp=1.0, + allow_non_ascii=False, + ) + + # Default path + default = StandardGCGSampling() + torch.manual_seed(12345) + default_out = default.sample_candidates( + gradient=gradient_template.clone(), + control_tokens=control_tokens.clone(), + batch_size=batch_size, + top_k=top_k, + temperature=1.0, + allow_non_ascii=False, + non_ascii_tokens=non_ascii_tokens, + ) + + assert torch.equal(default_out, legacy_out) + + def test_sample_candidates_matches_legacy_with_non_ascii_allowed(self) -> None: + """Legacy reference: same as above but with ``allow_non_ascii=True`` + (the no-mask branch where the gradient is not mutated). + """ + n_control_tokens = 6 + vocab_size = 40 + batch_size = 5 + top_k = 10 + + torch.manual_seed(2027) + gradient_template = torch.randn(n_control_tokens, vocab_size) + control_tokens = torch.randint(0, vocab_size, (n_control_tokens,)) + non_ascii_tokens = torch.tensor([1, 4]) + + prompt_manager = self._make_legacy_prompt_manager( + control_tokens=control_tokens, non_ascii_tokens=non_ascii_tokens + ) + torch.manual_seed(54321) + legacy_out = prompt_manager.sample_control( + gradient_template.clone(), + batch_size, + topk=top_k, + temp=1.0, + allow_non_ascii=True, + ) + + default = StandardGCGSampling() + torch.manual_seed(54321) + default_out = default.sample_candidates( + gradient=gradient_template.clone(), + control_tokens=control_tokens.clone(), + batch_size=batch_size, + top_k=top_k, + temperature=1.0, + allow_non_ascii=True, + non_ascii_tokens=non_ascii_tokens, + ) + + assert torch.equal(default_out, legacy_out) + + +class TestCrossEntropyLoss: + """Parity: ``CrossEntropyLoss`` vs ``AttackPrompt.target_loss`` + + ``AttackPrompt.control_loss``. + """ + + def _make_legacy_prompt( + self, + *, + target_slice: slice, + control_slice: slice, + ) -> AttackPrompt: + # Mirrors TestTargetAndControlLoss in test_gcg_core.py: skip + # __init__ and seed only the slice attributes that the loss methods + # consult. + prompt = object.__new__(AttackPrompt) + prompt._target_slice = target_slice + prompt._control_slice = control_slice + return prompt + + def test_compute_loss_matches_legacy_weighted_sum(self) -> None: + """Legacy reference: + ``target_weight * AttackPrompt.target_loss(logits, ids).mean(dim=-1)`` + ``+ control_weight * AttackPrompt.control_loss(logits, ids).mean(dim=-1)``, + per ``GCGMultiPromptAttack.step`` in + ``pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py``. + """ + batch_size = 4 + seq_len = 10 + vocab_size = 30 + target_slice = slice(5, 8) + control_slice = slice(2, 5) + target_weight = 1.0 + control_weight = 0.1 + + torch.manual_seed(99) + logits = torch.randn(batch_size, seq_len, vocab_size) + token_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + + prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice) + legacy_target = prompt.target_loss(logits, token_ids).mean(dim=-1) + legacy_control = prompt.control_loss(logits, token_ids).mean(dim=-1) + legacy_total = target_weight * legacy_target + control_weight * legacy_control + + default = CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight) + default_total = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=target_slice, + control_slice=control_slice, + ) + + assert torch.equal(default_total, legacy_total) + + def test_compute_loss_target_only_matches_legacy_target_loss(self) -> None: + """With ``control_weight=0`` the legacy ``step`` skips the control + term (``if control_weight != 0:`` guard at line 211). The default + must produce the same per-candidate value as + ``target_weight * target_loss(...).mean(dim=-1)`` alone. + """ + target_slice = slice(4, 7) + control_slice = slice(1, 4) + + torch.manual_seed(7) + logits = torch.randn(3, 9, 25) + token_ids = torch.randint(0, 25, (3, 9)) + + prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice) + legacy_total = 1.0 * prompt.target_loss(logits, token_ids).mean(dim=-1) + + default = CrossEntropyLoss(target_weight=1.0, control_weight=0.0) + default_total = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=target_slice, + control_slice=control_slice, + ) + + assert torch.equal(default_total, legacy_total) + + def test_compute_loss_control_only_matches_legacy_control_loss(self) -> None: + """With ``target_weight=0`` the default must produce the same value + as ``control_weight * control_loss(...).mean(dim=-1)`` alone. + """ + target_slice = slice(4, 7) + control_slice = slice(1, 4) + + torch.manual_seed(13) + logits = torch.randn(3, 9, 25) + token_ids = torch.randint(0, 25, (3, 9)) + + prompt = self._make_legacy_prompt(target_slice=target_slice, control_slice=control_slice) + legacy_total = 0.5 * prompt.control_loss(logits, token_ids).mean(dim=-1) + + default = CrossEntropyLoss(target_weight=0.0, control_weight=0.5) + default_total = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=target_slice, + control_slice=control_slice, + ) + + assert torch.equal(default_total, legacy_total) + + def test_init_rejects_both_weights_zero(self) -> None: + with pytest.raises(ValueError, match="at least one"): + CrossEntropyLoss(target_weight=0.0, control_weight=0.0) + + def test_init_rejects_negative_target_weight(self) -> None: + with pytest.raises(ValueError, match=">= 0"): + CrossEntropyLoss(target_weight=-0.5, control_weight=1.0) + + def test_init_rejects_negative_control_weight(self) -> None: + with pytest.raises(ValueError, match=">= 0"): + CrossEntropyLoss(target_weight=1.0, control_weight=-0.5) + + def test_compute_loss_returns_batch_sized_tensor(self) -> None: + batch_size = 4 + logits = torch.randn(batch_size, 10, 20) + token_ids = torch.randint(0, 20, (batch_size, 10)) + + default = CrossEntropyLoss(target_weight=1.0, control_weight=0.1) + out = default.compute_loss( + logits=logits, + token_ids=token_ids, + target_slice=slice(5, 8), + control_slice=slice(2, 5), + ) + + assert out.shape == (batch_size,) + + +def _make_filter_tokenizer() -> MagicMock: + """Build a fresh, deterministic, stateless mock tokenizer for filter tests. + + Behavior: + - ``decode(tensor)`` -> ``"x" * int(tensor[0].item())`` — string length + is keyed off the first token id, so each row maps to a distinct + predictable string. + - ``tokenizer(text, ...).input_ids`` has length ``len(text)`` — so the + retokenized length check is fully predictable from the decoded + string. + - ``tokenizer("!").input_ids[0] == 0`` — provides the clamp + replacement id. + - ``vocab_size == 100``. + """ + tokenizer = MagicMock() + tokenizer.vocab_size = 100 + + def decode_fn(ids, **_kwargs): + return "x" * int(ids[0].item()) + + tokenizer.decode.side_effect = decode_fn + + def call_tokenizer(text, **_kwargs): + result = MagicMock() + if text == "!": + result.input_ids = [0] + else: + result.input_ids = list(range(len(text))) + return result + + tokenizer.side_effect = call_tokenizer + return tokenizer + + +class TestLengthPreservingFilter: + """Parity: ``LengthPreservingFilter`` vs + ``MultiPromptAttack.get_filtered_cands``. + """ + + def _make_legacy_attack(self, *, tokenizer: MagicMock) -> MultiPromptAttack: + # Mirrors TestGetFilteredCands in test_gcg_core.py: skip __init__ + # and only attach the workers list that get_filtered_cands reads. + attack = object.__new__(MultiPromptAttack) + worker = MagicMock() + worker.tokenizer = tokenizer + attack.workers = [worker] + return attack + + def test_filter_candidates_matches_legacy_filtered(self) -> None: + """Legacy reference: + ``MultiPromptAttack.get_filtered_cands(0, control_cand, + filter_cand=True, curr_control=...)`` in + ``pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py``. + + With the helper tokenizer: + - Row 0 ``[3, 0, 1]`` -> decode ``"xxx"`` (len 3); retok len 3 == + control_length 3 -> KEEP. + - Row 1 ``[5, 0, 0]`` -> decode ``"xxxxx"`` (len 5); retok len 5 + != 3 -> DROP. + - Row 2 ``[2, 0, 1]`` -> decode ``"xx"`` (len 2); retok len 2 != + 3 -> DROP. + Pad-with-last gives ``["xxx", "xxx", "xxx"]``. + """ + candidate_template = torch.tensor([[3, 0, 1], [5, 0, 0], [2, 0, 1]]) + + legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer()) + legacy_out = legacy_attack.get_filtered_cands( + 0, candidate_template.clone(), filter_cand=True, curr_control="never_matches" + ) + + default = LengthPreservingFilter(filter=True) + default_out = default.filter_candidates( + candidate_tokens=candidate_template.clone(), + tokenizer=_make_filter_tokenizer(), + current_control="never_matches", + ) + + assert default_out == legacy_out + assert legacy_out == ["xxx", "xxx", "xxx"] + + def test_filter_candidates_matches_legacy_unfiltered(self) -> None: + """Legacy reference: ``get_filtered_cands(0, control_cand, + filter_cand=False)``. Every row is decoded and returned unchanged. + """ + candidate_template = torch.tensor([[3, 0, 1], [5, 0, 0], [2, 0, 1]]) + + legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer()) + legacy_out = legacy_attack.get_filtered_cands(0, candidate_template.clone(), filter_cand=False) + + default = LengthPreservingFilter(filter=False) + default_out = default.filter_candidates( + candidate_tokens=candidate_template.clone(), + tokenizer=_make_filter_tokenizer(), + current_control="ignored_when_filter_false", + ) + + assert default_out == legacy_out + assert legacy_out == ["xxx", "xxxxx", "xx"] + + def test_filter_candidates_clamps_out_of_vocab_tokens(self) -> None: + """Both code paths apply the legacy vocab-clamp in-place: tokens + above ``vocab_size`` are replaced by the id of ``"!"`` before any + decoding happens. + """ + candidate_template = torch.tensor([[150, 0, 1], [3, 0, 1]]) # 150 > vocab_size=100 + + legacy_input = candidate_template.clone() + legacy_attack = self._make_legacy_attack(tokenizer=_make_filter_tokenizer()) + legacy_attack.get_filtered_cands(0, legacy_input, filter_cand=False) + + default_input = candidate_template.clone() + default = LengthPreservingFilter(filter=False) + default.filter_candidates( + candidate_tokens=default_input, + tokenizer=_make_filter_tokenizer(), + current_control="", + ) + + assert torch.equal(default_input, legacy_input) + assert default_input[0, 0].item() == 0 + + +class TestLiteralStringInit: + """Parity: ``LiteralStringInit`` vs the literal-string ``control_init`` + assignment inside ``AttackPrompt.__init__`` (``self.control = + control_init``). + """ + + def test_make_initial_suffix_returns_default_control_init(self) -> None: + """Legacy reference: ``GCGAlgorithmConfig.control_init`` (default + ``_DEFAULT_CONTROL_INIT``) is assigned to ``self.control`` in + ``AttackPrompt.__init__``. + """ + default_suffix = GCGAlgorithmConfig().control_init + initializer = LiteralStringInit(suffix=default_suffix) + assert initializer.make_initial_suffix(tokenizer=MagicMock()) == default_suffix + + def test_make_initial_suffix_ignores_tokenizer(self) -> None: + suffix = "custom suffix string" + initializer = LiteralStringInit(suffix=suffix) + assert initializer.make_initial_suffix(tokenizer=None) == suffix + + def test_init_rejects_empty_suffix(self) -> None: + with pytest.raises(ValueError, match="non-empty"): + LiteralStringInit(suffix="") From abe86e5ed8a23261a289ad52ad6775eb1c6cc999 Mon Sep 17 00:00:00 2001 From: Copilot <223556219+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:08:29 -0700 Subject: [PATCH 2/3] Wire GCG extension protocol path Add optional protocol fields on GCGAlgorithmConfig and wire GCGMultiPromptAttack.step to dispatch through protocol objects with default fallbacks that preserve legacy behavior when unset. Also wire suffix initialization through config and add unit coverage for config validation, manager wiring, default parity, and custom dispatch. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../gcg/attack/gcg/gcg_attack.py | 136 +++++++- pyrit/auxiliary_attacks/gcg/config.py | 44 +++ .../gcg/extension_protocols.py | 11 +- pyrit/auxiliary_attacks/gcg/generator.py | 37 ++- .../unit/auxiliary_attacks/gcg/test_config.py | 76 ++++- .../auxiliary_attacks/gcg/test_gcg_core.py | 309 ++++++++++++++++++ .../auxiliary_attacks/gcg/test_generator.py | 112 ++++++- 7 files changed, 702 insertions(+), 23 deletions(-) diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index 4df1ae9205..c967ca0841 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -17,6 +17,12 @@ get_embedding_matrix, get_embeddings, ) +from pyrit.auxiliary_attacks.gcg.default_implementations import ( + CrossEntropyLoss, + LengthPreservingFilter, + StandardGCGSampling, +) +from pyrit.auxiliary_attacks.gcg.extension_protocols import CandidateFilter, LossFunction, SamplingStrategy logger = logging.getLogger(__name__) @@ -125,6 +131,93 @@ def sample_control( class GCGMultiPromptAttack(MultiPromptAttack): """GCG-specific multi-prompt attack that implements the GCG optimization step.""" + def __init__( + self, + goals: list[str], + targets: list[str], + workers: list[Any], + control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[Any] | None = None, + *, + sampling: SamplingStrategy | None = None, + loss: LossFunction | None = None, + candidate_filter: CandidateFilter | None = None, + ) -> None: + super().__init__( + goals, + targets, + workers, + control_init, + test_prefixes, + logfile, + managers, + test_goals, + test_targets, + test_workers, + ) + self._sampling = sampling + self._loss = loss + self._candidate_filter = candidate_filter + + def _resolve_sampling(self) -> SamplingStrategy: + sampling = getattr(self, "_sampling", None) + if sampling is not None: + return sampling + return StandardGCGSampling() + + def _resolve_loss(self, *, target_weight: float, control_weight: float) -> LossFunction: + loss = getattr(self, "_loss", None) + if loss is not None: + return loss + return CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight) + + def _resolve_candidate_filter(self, *, filter_cand: bool) -> CandidateFilter: + candidate_filter = getattr(self, "_candidate_filter", None) + if candidate_filter is not None: + return candidate_filter + return LengthPreservingFilter(filter=filter_cand) + + def _sample_control_candidates( + self, + *, + worker_index: int, + gradient: torch.Tensor, + batch_size: int, + topk: int, + temp: float, + allow_non_ascii: bool, + ) -> torch.Tensor: + sampler = self._resolve_sampling() + prompt_manager = self.prompts[worker_index] + return sampler.sample_candidates( + gradient=gradient, + control_tokens=prompt_manager.control_toks, + batch_size=batch_size, + top_k=topk, + temperature=temp, + allow_non_ascii=allow_non_ascii, + non_ascii_tokens=prompt_manager.disallowed_toks, + ) + + def _filter_control_candidates( + self, + *, + worker_index: int, + control_cand: torch.Tensor, + filter_cand: bool, + ) -> list[str]: + candidate_filter = self._resolve_candidate_filter(filter_cand=filter_cand) + return candidate_filter.filter_candidates( + candidate_tokens=control_cand, + tokenizer=self.workers[worker_index].tokenizer, + current_control=self.control_str, + ) + def step( self, *, @@ -158,6 +251,7 @@ def step( """ main_device = self.models[0].device control_cands = [] + loss_function = self._resolve_loss(target_weight=target_weight, control_weight=control_weight) for j, worker in enumerate(self.workers): worker(self.prompts[j], "grad", worker.model) @@ -171,10 +265,19 @@ def step( grad = torch.zeros_like(new_grad) if grad.shape != new_grad.shape: with torch.no_grad(): - control_cand = self.prompts[j - 1].sample_control(grad, batch_size, topk, temp, allow_non_ascii) + control_cand = self._sample_control_candidates( + worker_index=j - 1, + gradient=grad, + batch_size=batch_size, + topk=topk, + temp=temp, + allow_non_ascii=allow_non_ascii, + ) control_cands.append( - self.get_filtered_cands( - j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str + self._filter_control_candidates( + worker_index=j - 1, + control_cand=control_cand, + filter_cand=filter_cand, ) ) grad = new_grad @@ -182,9 +285,20 @@ def step( grad += new_grad with torch.no_grad(): - control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii) + control_cand = self._sample_control_candidates( + worker_index=j, + gradient=grad, + batch_size=batch_size, + topk=topk, + temp=temp, + allow_non_ascii=allow_non_ascii, + ) control_cands.append( - self.get_filtered_cands(j, control_cand, filter_cand=filter_cand, curr_control=self.control_str) + self._filter_control_candidates( + worker_index=j, + control_cand=control_cand, + filter_cand=filter_cand, + ) ) del grad, control_cand gc.collect() @@ -205,14 +319,14 @@ def step( worker(self.prompts[k][i], "logits", worker.model, cand, return_ids=True) logits, ids = zip(*[worker.results.get() for worker in self.workers]) loss[j * batch_size : (j + 1) * batch_size] += sum( - target_weight * self.prompts[k][i].target_loss(logit, id).mean(dim=-1).to(main_device) + loss_function.compute_loss( + logits=logit, + token_ids=id, + target_slice=self.prompts[k][i]._target_slice, + control_slice=self.prompts[k][i]._control_slice, + ).to(main_device) for k, (logit, id) in enumerate(zip(logits, ids)) ) - if control_weight != 0: - loss[j * batch_size : (j + 1) * batch_size] += sum( - control_weight * self.prompts[k][i].control_loss(logit, id).mean(dim=-1).to(main_device) - for k, (logit, id) in enumerate(zip(logits, ids)) - ) del logits, ids gc.collect() diff --git a/pyrit/auxiliary_attacks/gcg/config.py b/pyrit/auxiliary_attacks/gcg/config.py index 097a9087af..c2debada6e 100644 --- a/pyrit/auxiliary_attacks/gcg/config.py +++ b/pyrit/auxiliary_attacks/gcg/config.py @@ -25,6 +25,13 @@ if TYPE_CHECKING: from pathlib import Path + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) + _DEFAULT_CONTROL_INIT: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !" @@ -147,6 +154,18 @@ class GCGAlgorithmConfig: random_seed (int): Seed for ``torch``/``numpy``/``random``. Defaults to 42. control_init (str): Initial suffix string the optimization starts from. Defaults to twenty space-separated ``!`` tokens. + sampling (SamplingStrategy | None): Optional strategy object that + samples candidate suffix token sequences from the aggregated + gradient. ``None`` uses the built-in default implementation. + loss (LossFunction | None): Optional loss object used to score each + candidate suffix. ``None`` uses the built-in weighted + cross-entropy default that preserves legacy behavior. + candidate_filter (CandidateFilter | None): Optional candidate-filter + object that decodes/prunes sampled candidate token sequences. + ``None`` uses the built-in length-preserving filter. + suffix_init (SuffixInitializer | None): Optional initializer object + that produces the initial suffix string at attack construction + time. ``None`` uses ``control_init`` verbatim. """ n_steps: int = 500 @@ -161,6 +180,10 @@ class GCGAlgorithmConfig: filter_cand: bool = True random_seed: int = 42 control_init: str = _DEFAULT_CONTROL_INIT + sampling: SamplingStrategy | None = None + loss: LossFunction | None = None + candidate_filter: CandidateFilter | None = None + suffix_init: SuffixInitializer | None = None def __post_init__(self) -> None: if self.n_steps <= 0: @@ -183,6 +206,27 @@ def __post_init__(self) -> None: ) if not self.control_init: raise ValueError("GCGAlgorithmConfig.control_init must be a non-empty string.") + self._validate_extensions() + + def _validate_extensions(self) -> None: + from pyrit.auxiliary_attacks.gcg.extension_protocols import ( + CandidateFilter, + LossFunction, + SamplingStrategy, + SuffixInitializer, + ) + + checks = ( + ("sampling", self.sampling, SamplingStrategy), + ("loss", self.loss, LossFunction), + ("candidate_filter", self.candidate_filter, CandidateFilter), + ("suffix_init", self.suffix_init, SuffixInitializer), + ) + for field_name, value, protocol in checks: + if value is not None and not isinstance(value, protocol): + raise ValueError( + f"GCGAlgorithmConfig.{field_name} must satisfy {protocol.__name__}, got {type(value)!r}." + ) @dataclass diff --git a/pyrit/auxiliary_attacks/gcg/extension_protocols.py b/pyrit/auxiliary_attacks/gcg/extension_protocols.py index f9f1a3013e..973fb22a2b 100644 --- a/pyrit/auxiliary_attacks/gcg/extension_protocols.py +++ b/pyrit/auxiliary_attacks/gcg/extension_protocols.py @@ -16,12 +16,11 @@ - ``SuffixInitializer`` — how the initial suffix string fed into the optimization loop is constructed. -The module is **typing surface only**. It ships no concrete implementations, -no defaults, and no wiring into ``GCGAlgorithmConfig`` or -``GCGMultiPromptAttack``. The default behaviors that match the current attack -code will land as concrete classes in a follow-up PR; the optional -``GCGAlgorithmConfig`` fields that select between defaults and custom -implementations will land in the PR after that. +The module is **typing surface only**. Concrete defaults live in +``default_implementations.py``, and orchestration wiring lives in +``GCGAlgorithmConfig`` + ``GCGMultiPromptAttack``. Keeping this module purely +protocol definitions preserves a stable extension API that can be imported +without pulling in heavy runtime dependencies. Tensor-typed signatures are kept lazy via ``from __future__ import annotations`` plus a ``TYPE_CHECKING`` import for ``torch`` so that diff --git a/pyrit/auxiliary_attacks/gcg/generator.py b/pyrit/auxiliary_attacks/gcg/generator.py index 4c812594e9..12ef46040c 100644 --- a/pyrit/auxiliary_attacks/gcg/generator.py +++ b/pyrit/auxiliary_attacks/gcg/generator.py @@ -38,6 +38,7 @@ import logging import time from dataclasses import dataclass, field +from functools import partial from typing import Any, overload import numpy as np @@ -212,6 +213,18 @@ def _build_identifier(self) -> ComponentIdentifier: "topk": self._algorithm.topk, "target_weight": self._algorithm.target_weight, "control_weight": self._algorithm.control_weight, + "sampling_impl": ( + type(self._algorithm.sampling).__name__ if self._algorithm.sampling is not None else "default" + ), + "loss_impl": type(self._algorithm.loss).__name__ if self._algorithm.loss is not None else "default", + "candidate_filter_impl": ( + type(self._algorithm.candidate_filter).__name__ + if self._algorithm.candidate_filter is not None + else "default" + ), + "suffix_init_impl": ( + type(self._algorithm.suffix_init).__name__ if self._algorithm.suffix_init is not None else "default" + ), "transfer": self._strategy.transfer, "progressive_goals": self._strategy.progressive_goals, "progressive_models": self._strategy.progressive_models, @@ -257,7 +270,12 @@ async def _perform_async(self, *, context: GCGContext) -> GCGResult: managers = { "AP": attack_lib.GCGAttackPrompt, "PM": attack_lib.GCGPromptManager, - "MPA": attack_lib.GCGMultiPromptAttack, + "MPA": partial( + attack_lib.GCGMultiPromptAttack, + sampling=self._algorithm.sampling, + loss=self._algorithm.loss, + candidate_filter=self._algorithm.candidate_filter, + ), } context.attack = self._create_attack( params=params, @@ -400,6 +418,7 @@ def _create_attack( logfile_path: str, ) -> Any: """Build the right attack object based on the strategy flags.""" + control_init = self._resolve_control_init(workers=workers) if self._strategy.transfer: return ProgressiveMultiPromptAttack( train_goals, @@ -407,7 +426,7 @@ def _create_attack( workers, progressive_models=self._strategy.progressive_models, progressive_goals=self._strategy.progressive_goals, - control_init=self._algorithm.control_init, + control_init=control_init, logfile=logfile_path, managers=managers, test_goals=test_goals, @@ -421,7 +440,7 @@ def _create_attack( train_goals, train_targets, workers, - control_init=self._algorithm.control_init, + control_init=control_init, logfile=logfile_path, managers=managers, test_goals=test_goals, @@ -432,6 +451,18 @@ def _create_attack( mpa_n_steps=self._algorithm.n_steps, ) + def _resolve_control_init(self, *, workers: list[Any]) -> str: + """Resolve the initial suffix string for a run. + + Uses the configured ``suffix_init`` extension when provided; otherwise + falls back to the legacy literal ``control_init`` value. + """ + if self._algorithm.suffix_init is None: + return self._algorithm.control_init + if not workers: + raise ValueError("Cannot resolve suffix_init without at least one worker tokenizer.") + return self._algorithm.suffix_init.make_initial_suffix(tokenizer=workers[0].tokenizer) + @staticmethod def _read_result(*, logfile_path: str, memory_labels: dict[str, str]) -> GCGResult: """Pull final-step values out of the JSON log written during the run.""" diff --git a/tests/unit/auxiliary_attacks/gcg/test_config.py b/tests/unit/auxiliary_attacks/gcg/test_config.py index da0a7f6a9a..922b0ffedd 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_config.py +++ b/tests/unit/auxiliary_attacks/gcg/test_config.py @@ -9,7 +9,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest @@ -28,6 +28,49 @@ _LLAMA_2 = "meta-llama/Llama-2-7b-chat-hf" +class _SamplingStub: + def sample_candidates( + self, + *, + gradient: Any, + control_tokens: Any, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: Any, + ) -> Any: + return control_tokens + + +class _LossStub: + def compute_loss( + self, + *, + logits: Any, + token_ids: Any, + target_slice: slice, + control_slice: slice, + ) -> Any: + return logits + + +class _FilterStub: + def filter_candidates( + self, + *, + candidate_tokens: Any, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return [current_control] + + +class _SuffixInitStub: + def make_initial_suffix(self, *, tokenizer: Any) -> str: + return "stub suffix" + + def _minimal_config() -> GCGConfig: return GCGConfig(models=[GCGModelConfig(name=_LLAMA_2)]) @@ -42,6 +85,10 @@ def test_minimal_config_constructs_with_defaults() -> None: assert config.test_models == [] assert config.algorithm.n_steps == 500 assert config.algorithm.batch_size == 512 + assert config.algorithm.sampling is None + assert config.algorithm.loss is None + assert config.algorithm.candidate_filter is None + assert config.algorithm.suffix_init is None assert config.strategy.transfer is False assert config.output.verbose is True assert config.hf_token is None @@ -100,6 +147,33 @@ def test_algorithm_empty_control_init_raises() -> None: GCGAlgorithmConfig(control_init="") +@pytest.mark.parametrize( + "field_name,value", + [ + ("sampling", object()), + ("loss", object()), + ("candidate_filter", object()), + ("suffix_init", object()), + ], +) +def test_algorithm_extension_type_validation(field_name: str, value: object) -> None: + with pytest.raises(ValueError, match=rf"GCGAlgorithmConfig\.{field_name} must satisfy"): + GCGAlgorithmConfig(**{field_name: value}) + + +def test_algorithm_accepts_protocol_implementations() -> None: + config = GCGAlgorithmConfig( + sampling=_SamplingStub(), + loss=_LossStub(), + candidate_filter=_FilterStub(), + suffix_init=_SuffixInitStub(), + ) + assert config.sampling is not None + assert config.loss is not None + assert config.candidate_filter is not None + assert config.suffix_init is not None + + @pytest.mark.parametrize("field_name", ["n_train_data", "n_test_data"]) def test_data_negative_count_raises(field_name: str) -> None: with pytest.raises(ValueError, match=f"GCGDataConfig.{field_name} must be >= 0"): diff --git a/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py b/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py index c3858bf357..3df82d2d41 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py +++ b/tests/unit/auxiliary_attacks/gcg/test_gcg_core.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from typing import Any from unittest.mock import MagicMock import pytest @@ -25,6 +26,7 @@ "pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack", reason="GCG optional dependencies not installed", ) +GCGMultiPromptAttack = gcg_attack_mod.GCGMultiPromptAttack GCGPromptManager = gcg_attack_mod.GCGPromptManager token_gradients = gcg_attack_mod.token_gradients @@ -501,3 +503,310 @@ def test_raises_when_tokenizer_has_no_chat_template(self) -> None: with patch.object(attack_manager_mod.AutoTokenizer, "from_pretrained", return_value=bare_tokenizer): with pytest.raises(ValueError, match="no chat_template configured"): get_workers(params) + + +class _Queue: + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def get(self) -> Any: + return self._items.pop(0) + + +class _WorkerStub: + def __init__( + self, + *, + gradient: torch.Tensor, + logits: torch.Tensor, + token_ids: torch.Tensor, + tokenizer: MagicMock, + ) -> None: + self.model = MagicMock() + self.model.device = "cpu" + self.tokenizer = tokenizer + self.results = _Queue([gradient, (logits, token_ids)]) + self.calls: list[tuple] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> None: + self.calls.append((args, kwargs)) + + +class _PromptManagerStub: + def __init__( + self, + *, + prompt: AttackPrompt, + control_tokens: torch.Tensor, + disallowed_tokens: torch.Tensor, + control_str: str, + ) -> None: + self._prompts = [prompt] + self._control_tokens = control_tokens + self._disallowed_tokens = disallowed_tokens + self.control_str = control_str + + def __len__(self) -> int: + return len(self._prompts) + + def __getitem__(self, i: int) -> AttackPrompt: + return self._prompts[i] + + @property + def control_toks(self) -> torch.Tensor: + return self._control_tokens + + @property + def disallowed_toks(self) -> torch.Tensor: + return self._disallowed_tokens + + +class _SpySampling: + def __init__(self, *, sampled_tokens: torch.Tensor) -> None: + self.sampled_tokens = sampled_tokens + self.calls: list[dict] = [] + + def sample_candidates( + self, + *, + gradient: torch.Tensor, + control_tokens: torch.Tensor, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: torch.Tensor, + ) -> torch.Tensor: + self.calls.append( + { + "gradient": gradient.clone(), + "control_tokens": control_tokens.clone(), + "batch_size": batch_size, + "top_k": top_k, + "temperature": temperature, + "allow_non_ascii": allow_non_ascii, + "non_ascii_tokens": non_ascii_tokens.clone(), + } + ) + return self.sampled_tokens.clone() + + +class _SpyLoss: + def __init__(self, *, losses: torch.Tensor) -> None: + self.losses = losses + self.calls: list[dict] = [] + + def compute_loss( + self, + *, + logits: torch.Tensor, + token_ids: torch.Tensor, + target_slice: slice, + control_slice: slice, + ) -> torch.Tensor: + self.calls.append( + { + "logits": logits.clone(), + "token_ids": token_ids.clone(), + "target_slice": target_slice, + "control_slice": control_slice, + } + ) + return self.losses.to(logits.device) + + +class _SpyFilter: + def __init__(self, *, candidates: list[str]) -> None: + self.candidates = list(candidates) + self.calls: list[dict] = [] + + def filter_candidates( + self, + *, + candidate_tokens: torch.Tensor, + tokenizer: MagicMock, + current_control: str, + ) -> list[str]: + self.calls.append( + { + "candidate_tokens": candidate_tokens.clone(), + "tokenizer": tokenizer, + "current_control": current_control, + } + ) + return list(self.candidates) + + +class TestGCGMultiPromptAttackStepWiring: + @staticmethod + def _make_tokenizer() -> MagicMock: + tokenizer = MagicMock() + tokenizer.vocab_size = 100 + + def decode_fn(ids, **_kwargs): + values = ids.tolist() if hasattr(ids, "tolist") else list(ids) + return " ".join(str(int(v)) for v in values) + + def call_fn(text, **_kwargs): + output = MagicMock() + if text == "!": + output.input_ids = [0] + else: + output.input_ids = [int(piece) for piece in text.split()] if text else [] + return output + + tokenizer.decode.side_effect = decode_fn + tokenizer.side_effect = call_fn + return tokenizer + + @staticmethod + def _make_prompt(*, target_slice: slice, control_slice: slice) -> AttackPrompt: + prompt = object.__new__(AttackPrompt) + prompt._target_slice = target_slice + prompt._control_slice = control_slice + return prompt + + @staticmethod + def _make_attack( + *, + worker: _WorkerStub, + prompt_manager: _PromptManagerStub, + sampling: object | None = None, + loss: object | None = None, + candidate_filter: object | None = None, + ) -> GCGMultiPromptAttack: + attack = object.__new__(GCGMultiPromptAttack) + attack.workers = [worker] + attack.models = [worker.model] + attack.prompts = [prompt_manager] + attack._sampling = sampling + attack._loss = loss + attack._candidate_filter = candidate_filter + return attack + + def test_step_default_path_matches_legacy_behavior(self) -> None: + gradient = torch.tensor( + [ + [0.3, -0.4, 0.8, -0.2, 0.1, 0.5], + [-0.3, 0.2, -0.8, 0.4, 0.1, 0.7], + [0.2, 0.6, -0.1, -0.5, 0.4, -0.2], + ], + dtype=torch.float32, + ) + logits = torch.randn(1, 8, 10) + token_ids = torch.randint(0, 10, (1, 8)) + control_tokens = torch.tensor([1, 2, 3], dtype=torch.long) + disallowed_tokens = torch.tensor([], dtype=torch.long) + target_slice = slice(4, 6) + control_slice = slice(1, 4) + current_control = "99 99 99" + tokenizer = self._make_tokenizer() + + worker = _WorkerStub(gradient=gradient.clone(), logits=logits, token_ids=token_ids, tokenizer=tokenizer) + prompt = self._make_prompt(target_slice=target_slice, control_slice=control_slice) + prompt_manager = _PromptManagerStub( + prompt=prompt, + control_tokens=control_tokens, + disallowed_tokens=disallowed_tokens, + control_str=current_control, + ) + attack = self._make_attack(worker=worker, prompt_manager=prompt_manager) + + target_weight = 1.3 + control_weight = 0.2 + torch.manual_seed(2026) + actual_control, actual_loss = attack.step( + batch_size=1, + topk=3, + temp=1.0, + allow_non_ascii=True, + target_weight=target_weight, + control_weight=control_weight, + verbose=True, + filter_cand=True, + ) + + legacy_prompt_manager = object.__new__(GCGPromptManager) + legacy_prompt_for_sampling = MagicMock() + legacy_prompt_for_sampling.control_toks = control_tokens.clone() + legacy_prompt_manager._prompts = [legacy_prompt_for_sampling] + legacy_prompt_manager._nonascii_toks = disallowed_tokens + + legacy_attack = object.__new__(MultiPromptAttack) + legacy_worker = MagicMock() + legacy_worker.tokenizer = tokenizer + legacy_attack.workers = [legacy_worker] + + legacy_prompt_for_loss = self._make_prompt(target_slice=target_slice, control_slice=control_slice) + normalized_gradient = gradient / gradient.norm(dim=-1, keepdim=True) + torch.manual_seed(2026) + legacy_control_cand = legacy_prompt_manager.sample_control( + normalized_gradient.clone(), + 1, + topk=3, + temp=1.0, + allow_non_ascii=True, + ) + legacy_controls = legacy_attack.get_filtered_cands( + 0, + legacy_control_cand, + filter_cand=True, + curr_control=current_control, + ) + legacy_loss = target_weight * legacy_prompt_for_loss.target_loss(logits, token_ids).mean( + dim=-1 + ) + control_weight * legacy_prompt_for_loss.control_loss(logits, token_ids).mean(dim=-1) + + assert actual_control == legacy_controls[0] + assert actual_loss == pytest.approx(legacy_loss[0].item()) + + def test_step_uses_custom_protocol_implementations_when_supplied(self) -> None: + gradient = torch.randn(3, 6) + logits = torch.randn(2, 8, 10) + token_ids = torch.randint(0, 10, (2, 8)) + control_tokens = torch.tensor([1, 2, 3], dtype=torch.long) + disallowed_tokens = torch.tensor([5], dtype=torch.long) + tokenizer = self._make_tokenizer() + + worker = _WorkerStub(gradient=gradient.clone(), logits=logits, token_ids=token_ids, tokenizer=tokenizer) + prompt = self._make_prompt(target_slice=slice(4, 6), control_slice=slice(1, 4)) + prompt_manager = _PromptManagerStub( + prompt=prompt, + control_tokens=control_tokens, + disallowed_tokens=disallowed_tokens, + control_str="current control", + ) + + sampled_tokens = torch.tensor([[8, 8, 8], [9, 9, 9]], dtype=torch.long) + sampling = _SpySampling(sampled_tokens=sampled_tokens) + candidate_filter = _SpyFilter(candidates=["candidate-A", "candidate-B"]) + custom_losses = torch.tensor([3.0, 0.5], dtype=torch.float32) + loss = _SpyLoss(losses=custom_losses) + attack = self._make_attack( + worker=worker, + prompt_manager=prompt_manager, + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ) + + selected_control, normalized_loss = attack.step( + batch_size=2, + topk=4, + temp=0.8, + allow_non_ascii=False, + target_weight=0.0, + control_weight=1.0, + verbose=True, + filter_cand=True, + ) + + assert selected_control == "candidate-B" + assert normalized_loss == pytest.approx(0.5) + assert len(sampling.calls) == 1 + assert len(candidate_filter.calls) == 1 + assert len(loss.calls) == 1 + assert sampling.calls[0]["batch_size"] == 2 + assert sampling.calls[0]["top_k"] == 4 + assert sampling.calls[0]["allow_non_ascii"] is False + assert candidate_filter.calls[0]["current_control"] == "current control" diff --git a/tests/unit/auxiliary_attacks/gcg/test_generator.py b/tests/unit/auxiliary_attacks/gcg/test_generator.py index f410aa5079..2652a5558b 100644 --- a/tests/unit/auxiliary_attacks/gcg/test_generator.py +++ b/tests/unit/auxiliary_attacks/gcg/test_generator.py @@ -6,8 +6,9 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING -from unittest.mock import MagicMock, patch +from functools import partial +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -238,6 +239,113 @@ def test_augmentation_modifies_at_least_some_targets(self) -> None: assert num_changed > 0 +class TestExtensionWiring: + def test_create_attack_uses_suffix_initializer_when_configured(self) -> None: + class _SuffixInitStub: + def __init__(self) -> None: + self.calls: list[object] = [] + + def make_initial_suffix(self, *, tokenizer: object) -> str: + self.calls.append(tokenizer) + return "initialized suffix" + + suffix_init = _SuffixInitStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig(suffix_init=suffix_init), + ) + worker = MagicMock() + worker.tokenizer = MagicMock() + + with patch.object(generator_mod, "IndividualPromptAttack") as mock_individual: + gen._create_attack( + params=MagicMock(), + managers={"MPA": MagicMock()}, + train_goals=["g"], + train_targets=["t"], + test_goals=[], + test_targets=[], + workers=[worker], + test_workers=[], + logfile_path="out.json", + ) + + assert suffix_init.calls == [worker.tokenizer] + assert mock_individual.call_args.kwargs["control_init"] == "initialized suffix" + + async def test_perform_async_binds_algorithm_extensions_into_mpa_factory(self, tmp_path: Path) -> None: + class _SamplingStub: + def sample_candidates( + self, + *, + gradient: Any, + control_tokens: Any, + batch_size: int, + top_k: int, + temperature: float, + allow_non_ascii: bool, + non_ascii_tokens: Any, + ) -> Any: + return control_tokens + + class _LossStub: + def compute_loss( + self, + *, + logits: Any, + token_ids: Any, + target_slice: slice, + control_slice: slice, + ) -> Any: + return logits + + class _FilterStub: + def filter_candidates( + self, + *, + candidate_tokens: Any, + tokenizer: Any, + current_control: str, + ) -> list[str]: + return [current_control] + + sampling = _SamplingStub() + loss = _LossStub() + candidate_filter = _FilterStub() + gen = GCGGenerator( + models=[GCGModelConfig(name=_LLAMA_2)], + algorithm=GCGAlgorithmConfig( + sampling=sampling, + loss=loss, + candidate_filter=candidate_filter, + ), + output=GCGOutputConfig(result_prefix=str(tmp_path / "gcg")), + ) + context = GCGContext( + goals=["g"], + targets=["t"], + workers=[MagicMock()], + test_workers=[], + ) + fake_attack = MagicMock() + + with ( + patch.object(gen, "_create_attack", return_value=fake_attack) as mock_create_attack, + patch.object(gen, "_build_logfile_path", return_value=str(tmp_path / "result.json")), + patch.object(gen, "_read_result", return_value=GCGResult(final_suffix="x")), + patch("pyrit.auxiliary_attacks.gcg.generator.asyncio.to_thread", new=AsyncMock(return_value=None)), + ): + await gen._perform_async(context=context) + + managers = mock_create_attack.call_args.kwargs["managers"] + mpa_factory = managers["MPA"] + assert isinstance(mpa_factory, partial) + assert mpa_factory.func is generator_mod.attack_lib.GCGMultiPromptAttack + assert mpa_factory.keywords["sampling"] is sampling + assert mpa_factory.keywords["loss"] is loss + assert mpa_factory.keywords["candidate_filter"] is candidate_filter + + class TestReadResult: def test_reads_final_suffix_and_loss(self, tmp_path: Path) -> None: log_path = tmp_path / "result.json" From 9588263bc32f9ceaf41f91aced74fb802d03e496 Mon Sep 17 00:00:00 2001 From: Copilot <223556219+Copilot@users.noreply.github.com> Date: Thu, 25 Jun 2026 06:39:33 -0700 Subject: [PATCH 3/3] Fix GCG step logging for custom controls Avoid failing the GCG step when logging tokenized control length for custom candidate-filter outputs that are not directly re-tokenizable by the active tokenizer. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index c967ca0841..ea0677527e 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -218,6 +218,12 @@ def _filter_control_candidates( current_control=self.control_str, ) + def _get_control_length(self, *, control: str) -> int | None: + try: + return len(self.workers[0].tokenizer(control).input_ids[1:]) + except (AttributeError, TypeError, ValueError): + return None + def step( self, *, @@ -343,7 +349,9 @@ def step( del control_cands, loss gc.collect() - logger.info(f"Current length: {len(self.workers[0].tokenizer(next_control).input_ids[1:])}") + current_length = self._get_control_length(control=next_control) + if current_length is not None: + logger.info(f"Current length: {current_length}") logger.info(next_control) return next_control, cand_loss.item() / len(self.prompts[0]) / len(self.workers)