Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pyrit/auxiliary_attacks/gcg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,28 @@
# mechanism so all GCG public symbols share one re-export pathway.
_LAZY_IMPORTS = {
"CandidateFilter": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "CandidateFilter"),
"CrossEntropyLoss": ("pyrit.auxiliary_attacks.gcg.default_implementations", "CrossEntropyLoss"),
"GCG": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"),
"GCGContext": ("pyrit.auxiliary_attacks.gcg.generator", "GCGContext"),
"GCGGenerator": ("pyrit.auxiliary_attacks.gcg.generator", "GCGGenerator"),
"GCGResult": ("pyrit.auxiliary_attacks.gcg.generator", "GCGResult"),
"LengthPreservingFilter": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LengthPreservingFilter"),
"LiteralStringInit": ("pyrit.auxiliary_attacks.gcg.default_implementations", "LiteralStringInit"),
"LossFunction": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "LossFunction"),
"SamplingStrategy": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SamplingStrategy"),
"StandardGCGSampling": ("pyrit.auxiliary_attacks.gcg.default_implementations", "StandardGCGSampling"),
"SuffixInitializer": ("pyrit.auxiliary_attacks.gcg.extension_protocols", "SuffixInitializer"),
"load_goals_and_targets": ("pyrit.auxiliary_attacks.gcg.data", "load_goals_and_targets"),
}

if TYPE_CHECKING:
from pyrit.auxiliary_attacks.gcg.data import load_goals_and_targets
from pyrit.auxiliary_attacks.gcg.default_implementations import (
CrossEntropyLoss,
LengthPreservingFilter,
LiteralStringInit,
StandardGCGSampling,
)
from pyrit.auxiliary_attacks.gcg.extension_protocols import (
CandidateFilter,
LossFunction,
Expand Down Expand Up @@ -91,6 +101,7 @@ def __dir__() -> list[str]:

__all__ = [
"CandidateFilter",
"CrossEntropyLoss",
"GCG",
"GCGAlgorithmConfig",
"GCGConfig",
Expand All @@ -101,8 +112,11 @@ def __dir__() -> list[str]:
"GCGOutputConfig",
"GCGResult",
"GCGStrategyConfig",
"LengthPreservingFilter",
"LiteralStringInit",
"LossFunction",
"SamplingStrategy",
"StandardGCGSampling",
"SuffixInitializer",
"load_goals_and_targets",
]
146 changes: 134 additions & 12 deletions pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
get_embedding_matrix,
get_embeddings,
)
from pyrit.auxiliary_attacks.gcg.default_implementations import (
CrossEntropyLoss,
LengthPreservingFilter,
StandardGCGSampling,
)
from pyrit.auxiliary_attacks.gcg.extension_protocols import CandidateFilter, LossFunction, SamplingStrategy

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -125,6 +131,99 @@ def sample_control(
class GCGMultiPromptAttack(MultiPromptAttack):
"""GCG-specific multi-prompt attack that implements the GCG optimization step."""

def __init__(
self,
goals: list[str],
targets: list[str],
workers: list[Any],
control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
test_prefixes: list[str] | None = None,
logfile: str | None = None,
managers: dict[str, Any] | None = None,
test_goals: list[str] | None = None,
test_targets: list[str] | None = None,
test_workers: list[Any] | None = None,
*,
sampling: SamplingStrategy | None = None,
loss: LossFunction | None = None,
candidate_filter: CandidateFilter | None = None,
) -> None:
super().__init__(
goals,
targets,
workers,
control_init,
test_prefixes,
logfile,
managers,
test_goals,
test_targets,
test_workers,
)
self._sampling = sampling
self._loss = loss
self._candidate_filter = candidate_filter

def _resolve_sampling(self) -> SamplingStrategy:
sampling = getattr(self, "_sampling", None)
if sampling is not None:
return sampling
return StandardGCGSampling()

def _resolve_loss(self, *, target_weight: float, control_weight: float) -> LossFunction:
loss = getattr(self, "_loss", None)
if loss is not None:
return loss
return CrossEntropyLoss(target_weight=target_weight, control_weight=control_weight)

def _resolve_candidate_filter(self, *, filter_cand: bool) -> CandidateFilter:
candidate_filter = getattr(self, "_candidate_filter", None)
if candidate_filter is not None:
return candidate_filter
return LengthPreservingFilter(filter=filter_cand)

def _sample_control_candidates(
self,
*,
worker_index: int,
gradient: torch.Tensor,
batch_size: int,
topk: int,
temp: float,
allow_non_ascii: bool,
) -> torch.Tensor:
sampler = self._resolve_sampling()
prompt_manager = self.prompts[worker_index]
return sampler.sample_candidates(
gradient=gradient,
control_tokens=prompt_manager.control_toks,
batch_size=batch_size,
top_k=topk,
temperature=temp,
allow_non_ascii=allow_non_ascii,
non_ascii_tokens=prompt_manager.disallowed_toks,
)

def _filter_control_candidates(
self,
*,
worker_index: int,
control_cand: torch.Tensor,
filter_cand: bool,
) -> list[str]:
candidate_filter = self._resolve_candidate_filter(filter_cand=filter_cand)
return candidate_filter.filter_candidates(
candidate_tokens=control_cand,
tokenizer=self.workers[worker_index].tokenizer,
current_control=self.control_str,
)

def _get_control_length(self, *, control: str) -> int | None:
try:
return len(self.workers[0].tokenizer(control).input_ids[1:])
except (AttributeError, TypeError, ValueError):
return None

def step(
self,
*,
Expand Down Expand Up @@ -158,6 +257,7 @@ def step(
"""
main_device = self.models[0].device
control_cands = []
loss_function = self._resolve_loss(target_weight=target_weight, control_weight=control_weight)

for j, worker in enumerate(self.workers):
worker(self.prompts[j], "grad", worker.model)
Expand All @@ -171,20 +271,40 @@ def step(
grad = torch.zeros_like(new_grad)
if grad.shape != new_grad.shape:
with torch.no_grad():
control_cand = self.prompts[j - 1].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
control_cand = self._sample_control_candidates(
worker_index=j - 1,
gradient=grad,
batch_size=batch_size,
topk=topk,
temp=temp,
allow_non_ascii=allow_non_ascii,
)
control_cands.append(
self.get_filtered_cands(
j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str
self._filter_control_candidates(
worker_index=j - 1,
control_cand=control_cand,
filter_cand=filter_cand,
)
)
grad = new_grad
else:
grad += new_grad

with torch.no_grad():
control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
control_cand = self._sample_control_candidates(
worker_index=j,
gradient=grad,
batch_size=batch_size,
topk=topk,
temp=temp,
allow_non_ascii=allow_non_ascii,
)
control_cands.append(
self.get_filtered_cands(j, control_cand, filter_cand=filter_cand, curr_control=self.control_str)
self._filter_control_candidates(
worker_index=j,
control_cand=control_cand,
filter_cand=filter_cand,
)
)
del grad, control_cand
gc.collect()
Expand All @@ -205,14 +325,14 @@ def step(
worker(self.prompts[k][i], "logits", worker.model, cand, return_ids=True)
logits, ids = zip(*[worker.results.get() for worker in self.workers])
loss[j * batch_size : (j + 1) * batch_size] += sum(
target_weight * self.prompts[k][i].target_loss(logit, id).mean(dim=-1).to(main_device)
loss_function.compute_loss(
logits=logit,
token_ids=id,
target_slice=self.prompts[k][i]._target_slice,
control_slice=self.prompts[k][i]._control_slice,
).to(main_device)
for k, (logit, id) in enumerate(zip(logits, ids))
)
if control_weight != 0:
loss[j * batch_size : (j + 1) * batch_size] += sum(
control_weight * self.prompts[k][i].control_loss(logit, id).mean(dim=-1).to(main_device)
for k, (logit, id) in enumerate(zip(logits, ids))
)
del logits, ids
gc.collect()

Expand All @@ -229,7 +349,9 @@ def step(
del control_cands, loss
gc.collect()

logger.info(f"Current length: {len(self.workers[0].tokenizer(next_control).input_ids[1:])}")
current_length = self._get_control_length(control=next_control)
if current_length is not None:
logger.info(f"Current length: {current_length}")
logger.info(next_control)

return next_control, cand_loss.item() / len(self.prompts[0]) / len(self.workers)
44 changes: 44 additions & 0 deletions pyrit/auxiliary_attacks/gcg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
if TYPE_CHECKING:
from pathlib import Path

from pyrit.auxiliary_attacks.gcg.extension_protocols import (
CandidateFilter,
LossFunction,
SamplingStrategy,
SuffixInitializer,
)

_DEFAULT_CONTROL_INIT: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"


Expand Down Expand Up @@ -147,6 +154,18 @@ class GCGAlgorithmConfig:
random_seed (int): Seed for ``torch``/``numpy``/``random``. Defaults to 42.
control_init (str): Initial suffix string the optimization starts from.
Defaults to twenty space-separated ``!`` tokens.
sampling (SamplingStrategy | None): Optional strategy object that
samples candidate suffix token sequences from the aggregated
gradient. ``None`` uses the built-in default implementation.
loss (LossFunction | None): Optional loss object used to score each
candidate suffix. ``None`` uses the built-in weighted
cross-entropy default that preserves legacy behavior.
candidate_filter (CandidateFilter | None): Optional candidate-filter
object that decodes/prunes sampled candidate token sequences.
``None`` uses the built-in length-preserving filter.
suffix_init (SuffixInitializer | None): Optional initializer object
that produces the initial suffix string at attack construction
time. ``None`` uses ``control_init`` verbatim.
"""

n_steps: int = 500
Expand All @@ -161,6 +180,10 @@ class GCGAlgorithmConfig:
filter_cand: bool = True
random_seed: int = 42
control_init: str = _DEFAULT_CONTROL_INIT
sampling: SamplingStrategy | None = None
loss: LossFunction | None = None
candidate_filter: CandidateFilter | None = None
suffix_init: SuffixInitializer | None = None

def __post_init__(self) -> None:
if self.n_steps <= 0:
Expand All @@ -183,6 +206,27 @@ def __post_init__(self) -> None:
)
if not self.control_init:
raise ValueError("GCGAlgorithmConfig.control_init must be a non-empty string.")
self._validate_extensions()

def _validate_extensions(self) -> None:
from pyrit.auxiliary_attacks.gcg.extension_protocols import (
CandidateFilter,
LossFunction,
SamplingStrategy,
SuffixInitializer,
)

checks = (
("sampling", self.sampling, SamplingStrategy),
("loss", self.loss, LossFunction),
("candidate_filter", self.candidate_filter, CandidateFilter),
("suffix_init", self.suffix_init, SuffixInitializer),
)
for field_name, value, protocol in checks:
if value is not None and not isinstance(value, protocol):
raise ValueError(
f"GCGAlgorithmConfig.{field_name} must satisfy {protocol.__name__}, got {type(value)!r}."
)


@dataclass
Expand Down
Loading
Loading