diff --git a/.semaphore/publish-test-pypi.yml b/.semaphore/publish-test-pypi.yml index d482eb61c..abbe1aeea 100644 --- a/.semaphore/publish-test-pypi.yml +++ b/.semaphore/publish-test-pypi.yml @@ -33,6 +33,7 @@ blocks: - name: Verify commands: - checkout + - export MISE_PYTHON_GITHUB_ATTESTATIONS=false - sem-version python 3.11 - export CK_VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])") - tools/test-released-wheels.sh $CK_VERSION test diff --git a/.semaphore/semaphore.yml b/.semaphore/semaphore.yml index 1fd44be4d..c706b501f 100644 --- a/.semaphore/semaphore.yml +++ b/.semaphore/semaphore.yml @@ -8,7 +8,7 @@ execution_time_limit: global_job_config: env_vars: - name: LIBRDKAFKA_VERSION - value: v2.14.2 + value: v2.14.2-aws-iam-dev prologue: commands: - checkout diff --git a/CHANGELOG.md b/CHANGELOG.md index d905e27ae..615bbdcea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ ## v2.xx.x +### Features + +- New optional install `confluent-kafka[oauthbearer-aws]` provides AWS IAM-based + OAUTHBEARER authentication via AWS STS `GetWebIdentityToken`. Activate by setting `sasl.oauthbearer.method=oidc`, +`sasl.oauthbearer.metadata.authentication.type=aws_iam`, and +`sasl.oauthbearer.config="region=...,audience=..."`. See [`examples/oauth_oidc_ccloud_aws_iam.py`](examples/oauth_oidc_ccloud_aws_iam.py) +for a worked example. + ### Fixes - Fix Encryption fails when expanded union types have two references to the same record (#2262) diff --git a/README.md b/README.md index 51f422be0..13f00f4a7 100644 --- a/README.md +++ b/README.md @@ -277,6 +277,9 @@ pip install "confluent-kafka[protobuf,schemaregistry]" # Protobuf # With Data Contract rules (includes CSFLE support) pip install "confluent-kafka[avro,schemaregistry,rules]" + +# With AWS IAM OAUTHBEARER authentication (mints JWTs via AWS STS GetWebIdentityToken) +pip install "confluent-kafka[oauthbearer-aws]" ``` **Note:** Pre-built Linux wheels do not include SASL Kerberos/GSSAPI support. For Kerberos, see the source installation instructions in [INSTALL.md](INSTALL.md). @@ -304,6 +307,20 @@ When using Data Contract rules (including CSFLE) add the `rules`extra, e.g.: pip install "confluent-kafka[avro,schemaregistry,rules]" ``` +To authenticate to a Kafka cluster using AWS IAM (when running on EC2, EKS, ECS, +Fargate, or Lambda with an IAM role attached), add the `oauthbearer-aws` extra: + +```bash +pip install "confluent-kafka[oauthbearer-aws]" +``` + +Activation is config-only — set `sasl.oauthbearer.method=oidc`, +`sasl.oauthbearer.metadata.authentication.type=aws_iam`, and +`sasl.oauthbearer.config="region=...,audience=..."`. The client mints fresh +JWTs via AWS STS on every token refresh — no static credentials, no Python-side +imports. See [`examples/oauth_oidc_ccloud_aws_iam.py`](examples/oauth_oidc_ccloud_aws_iam.py) +for a worked example. + **Install from source** For source install, see the *Install from source* section in [INSTALL.md](INSTALL.md). diff --git a/examples/docker/Dockerfile.alpine b/examples/docker/Dockerfile.alpine index 2a6156bf5..882b4497f 100644 --- a/examples/docker/Dockerfile.alpine +++ b/examples/docker/Dockerfile.alpine @@ -30,7 +30,7 @@ FROM alpine:3.12 COPY . /usr/src/confluent-kafka-python -ENV LIBRDKAFKA_VERSION="v2.14.2" +ENV LIBRDKAFKA_VERSION="v2.14.2-aws-iam-dev" ENV KCAT_VERSION="master" ENV CKP_VERSION="master" diff --git a/examples/oauth_oidc_ccloud_aws_iam.py b/examples/oauth_oidc_ccloud_aws_iam.py new file mode 100644 index 000000000..e53582070 --- /dev/null +++ b/examples/oauth_oidc_ccloud_aws_iam.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end example for AWS IAM OAUTHBEARER authentication. + +Activation is config-only: Set `sasl.oauthbearer.method=oidc`, +`sasl.oauthbearer.metadata.authentication.type=aws_iam`, and +`sasl.oauthbearer.config="region=...,audience=..."`. + +Install: + pip install 'confluent-kafka[oauthbearer-aws]' + +Runs on AWS compute (EC2 / EKS / ECS / Fargate / Lambda) with an IAM role +attached — boto3's default credential chain resolves it, no static keys. + +To run: + python oauth_oidc_ccloud_aws_iam.py \\ + -b pkc-xxxx.aws.confluent.cloud:9092 \\ + --region eu-north-1 \\ + --audience https://confluent.cloud/oidc \\ + --extensions logicalCluster=lkc-abc,identityPoolId=pool-xyz +""" + +import argparse +import logging +import time +import uuid + +from confluent_kafka import Consumer, Producer +from confluent_kafka.admin import AdminClient, NewTopic +from confluent_kafka.serialization import StringSerializer + + +def common_config(args): + """SASL config shared by Producer, Consumer, and AdminClient.""" + conf = { + 'bootstrap.servers': args.bootstrap_servers, + 'security.protocol': 'SASL_SSL', + 'sasl.mechanisms': 'OAUTHBEARER', + 'sasl.oauthbearer.method': 'oidc', + 'sasl.oauthbearer.metadata.authentication.type': 'aws_iam', + 'sasl.oauthbearer.config': f'region={args.region},' + f'audience={args.audience},' + f'duration_seconds={args.duration_seconds}', + 'debug': 'security', + } + + if args.extensions: + conf['sasl.oauthbearer.extensions'] = args.extensions + + return conf + + +def consumer_config(args, group_id): + cfg = common_config(args) + cfg['group.id'] = group_id + cfg['auto.offset.reset'] = 'earliest' + cfg['enable.auto.offset.store'] = False # commit offsets manually + return cfg + + +def create_topic(admin_conf, topic_name, num_partitions=1, replication_factor=3): + admin = AdminClient(admin_conf) + futures = admin.create_topics( + [ + NewTopic(topic_name, num_partitions=num_partitions, replication_factor=replication_factor), + ] + ) + for topic, future in futures.items(): + try: + future.result() + print(f"[admin] Topic '{topic}' created " f"({num_partitions} partition(s), RF={replication_factor})") + except Exception as exc: + print(f"[admin] Failed to create topic '{topic}': {exc}") + raise + + +def delivery_report(err, msg): + if err is not None: + print(f"[producer] Delivery failed: {err}") + return + print( + f"[producer] Produced to {msg.topic()} [{msg.partition()}] " + f"at offset {msg.offset()}: {msg.value().decode('utf-8')}" + ) + + +def main(args): + # Unique topic + group per run so the example is self-contained. + topic_name = f"aws-iam-{uuid.uuid4()}" + group_id = f"aws-iam-consumer-{uuid.uuid4()}" + + p_conf = common_config(args) + c_conf = consumer_config(args, group_id) + a_conf = common_config(args) + + logging.basicConfig(level=logging.INFO) + + print("\n=== AWS IAM OAUTHBEARER end-to-end example ===") + print(f"bootstrap.servers: {args.bootstrap_servers}") + print(f"region: {args.region}") + print(f"audience: {args.audience}") + print(f"duration_seconds: {args.duration_seconds} " f"(auto-refresh at ~{int(args.duration_seconds * 0.8)}s)") + print(f"run-for: {args.run_for}s") + print(f"topic (generated): {topic_name}") + print(f"group.id (generated): {group_id}\n") + + create_topic(a_conf, topic_name) + + producer = Producer(p_conf) + consumer = Consumer(c_conf) + consumer.subscribe([topic_name]) + serializer = StringSerializer('utf_8') + + start = time.time() + end_at = start + args.run_for + produced = 0 + consumed = 0 + + print( + f"[loop] Producing/consuming for {args.run_for}s — " + f"watch the debug=security logs for token-refresh events.\n" + ) + + try: + while time.time() < end_at: + elapsed = time.time() - start + msg = f"hello-from-aws-iam T+{elapsed:.1f}s" + + producer.produce( + topic_name, + value=serializer(msg), + on_delivery=delivery_report, + ) + producer.poll(0) + produced += 1 + + received = consumer.poll(1.0) + if received is None: + pass # poll timeout, no message yet + elif received.error() is not None: + print(f"[consumer] error: {received.error()}") + else: + consumer.store_offsets(received) + consumed += 1 + print( + f"[consumer] Received from " + f"{received.topic()} [{received.partition()}] " + f"at offset {received.offset()}: " + f"{received.value().decode('utf-8')}" + ) + + time.sleep(args.interval) + except KeyboardInterrupt: + print("\n[main] Interrupted — flushing.") + finally: + print(f"\n[summary] Produced {produced}, consumed {consumed} " f"in {time.time() - start:.1f}s. Flushing...") + producer.flush(timeout=10) + consumer.close() + print("[summary] Done.") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='End-to-end OAUTHBEARER example via AWS IAM autowire ' '(produce + consume + admin).', + ) + parser.add_argument('-b', dest='bootstrap_servers', required=True, help='Bootstrap broker(s) (host[:port])') + parser.add_argument('--region', required=True, help='AWS region (e.g. us-east-1)') + parser.add_argument( + '--audience', + required=True, + help='OIDC audience claim the broker expects ' '(e.g. https://confluent.cloud/oidc)', + ) + parser.add_argument( + '--extensions', + default=None, + help='Optional sasl.oauthbearer.extensions value ' '(comma-separated key=value pairs)', + ) + parser.add_argument( + '--duration-seconds', + dest='duration_seconds', + type=int, + default=60, + help='STS DurationSeconds (default 60 = AWS minimum); ' 'librdkafka auto-refreshes at ~80%% of it.', + ) + parser.add_argument( + '--run-for', dest='run_for', type=int, default=120, help='Run duration in seconds (default 120).' + ) + parser.add_argument('--interval', type=float, default=5.0, help='Seconds between produce calls (default 5).') + + main(parser.parse_args()) diff --git a/pyproject.toml b/pyproject.toml index c8149ed38..efeb0fc67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "confluent-kafka" -version = "2.14.2" +version = "2.14.2.dev5" description = "Confluent's Python client for Apache Kafka" classifiers = [ "Development Status :: 5 - Production/Stable", @@ -107,6 +107,7 @@ optional-dependencies.avro = { file = ["requirements/requirements-avro.txt", "re optional-dependencies.json = { file = ["requirements/requirements-json.txt", "requirements/requirements-schemaregistry.txt"] } optional-dependencies.json-fast = { file = ["requirements/requirements-json.txt", "requirements/requirements-json-fast.txt", "requirements/requirements-schemaregistry.txt"] } optional-dependencies.protobuf = { file = ["requirements/requirements-protobuf.txt", "requirements/requirements-schemaregistry.txt"] } +optional-dependencies.oauthbearer-aws = { file = ["requirements/requirements-oauthbearer-aws.txt"] } optional-dependencies.dev = { file = [ "requirements/requirements-docs.txt", "requirements/requirements-examples.txt", @@ -115,7 +116,8 @@ optional-dependencies.dev = { file = [ "requirements/requirements-rules.txt", "requirements/requirements-avro.txt", "requirements/requirements-json.txt", - "requirements/requirements-protobuf.txt"] } + "requirements/requirements-protobuf.txt", + "requirements/requirements-oauthbearer-aws.txt"] } optional-dependencies.docs = { file = [ "requirements/requirements-docs.txt", "requirements/requirements-schemaregistry.txt", @@ -129,7 +131,8 @@ optional-dependencies.tests = { file = [ "requirements/requirements-rules.txt", "requirements/requirements-avro.txt", "requirements/requirements-json.txt", - "requirements/requirements-protobuf.txt"] } + "requirements/requirements-protobuf.txt", + "requirements/requirements-oauthbearer-aws.txt"] } optional-dependencies.examples = { file = ["requirements/requirements-examples.txt"] } optional-dependencies.soaktest = { file = ["requirements/requirements-soaktest.txt"] } optional-dependencies.all = { file = [ @@ -141,7 +144,8 @@ optional-dependencies.all = { file = [ "requirements/requirements-rules.txt", "requirements/requirements-avro.txt", "requirements/requirements-json.txt", - "requirements/requirements-protobuf.txt"] } + "requirements/requirements-protobuf.txt", + "requirements/requirements-oauthbearer-aws.txt"] } [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index d2514d3a7..0acf81a06 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -7,4 +7,5 @@ -r requirements-examples.txt -r requirements-tests.txt -r requirements-docs.txt --r requirements-soaktest.txt \ No newline at end of file +-r requirements-soaktest.txt +-r requirements-oauthbearer-aws.txt \ No newline at end of file diff --git a/requirements/requirements-oauthbearer-aws.txt b/requirements/requirements-oauthbearer-aws.txt new file mode 100644 index 000000000..3c332eafc --- /dev/null +++ b/requirements/requirements-oauthbearer-aws.txt @@ -0,0 +1 @@ +boto3>=1.42.25 diff --git a/requirements/requirements-tests-install.txt b/requirements/requirements-tests-install.txt index 649aee81d..772e90fe2 100644 --- a/requirements/requirements-tests-install.txt +++ b/requirements/requirements-tests-install.txt @@ -4,4 +4,5 @@ -r requirements-avro.txt -r requirements-protobuf.txt -r requirements-json.txt +-r requirements-oauthbearer-aws.txt tests/trivup/trivup-0.14.0.tar.gz \ No newline at end of file diff --git a/src/confluent_kafka/_util/librdkafka_string_parser.py b/src/confluent_kafka/_util/librdkafka_string_parser.py new file mode 100644 index 000000000..adeeb4ce4 --- /dev/null +++ b/src/confluent_kafka/_util/librdkafka_string_parser.py @@ -0,0 +1,165 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared ``key=value`` string parser with librdkafka-faithful semantics. + +Logically equivalent to librdkafka's ``rd_string_split`` (``src/rdstring.c``) +plus ``rd_kafka_conf_kv_split`` (``src/rdkafka_conf.c``), so OAUTHBEARER +config / extension strings tokenize identically to the native client — most +importantly an ``identityPoolId`` that is itself a comma-separated list whose +commas are backslash-quoted. + +Semantics (mirroring ``rd_string_split``): + +* Fields are separated by a single ``sep`` character. +* ``\\`` escapes the next character: ``\\t`` / ``\\n`` / ``\\r`` / ``\\0`` map to + TAB / LF / CR / NUL; any other escaped character (including an escaped + separator or ``\\\\``) is kept literally. A dangling trailing ``\\`` is dropped. +* Leading whitespace is stripped only when *unescaped*; trailing whitespace is + stripped *unconditionally* (even when it was escaped) — an asymmetry copied + verbatim from librdkafka. "Whitespace" is the ASCII set (C ``isspace``): + space, ``\\t``, ``\\n``, ``\\v``, ``\\f``, ``\\r`` — deliberately not Unicode. +* Empty fields are skipped when ``skip_empty`` is set. +* ``key=value`` splits on the FIRST ``=`` (the value may contain further ``=``); + a field with no ``=`` or an empty key is an error. + +This is the cross-language port of .NET's ``LibrdkafkaStringParser``; its tests +(``tests/_util/test_librdkafka_string_parser.py``) port librdkafka's own +``ut_string_split`` vectors verbatim to lock parity. +""" + +from typing import List, Optional, Tuple + +__all__ = ["split", "parse_key_values"] + + +def _is_ascii_space(c: str) -> bool: + """ASCII ``isspace``: space, ``\\t``(0x09)..``\\r``(0x0D). + + Matches librdkafka's C ``isspace`` behaviour; intentionally NOT + :meth:`str.isspace`, which also matches Unicode whitespace (e.g. U+00A0) + and would diverge from the native client. + """ + return c == " " or ("\t" <= c <= "\r") + + +def split(raw: str, sep: str, skip_empty: bool) -> List[str]: + """Split ``raw`` into fields on ``sep``, applying librdkafka's ``\\``-escaping + and whitespace trimming. Logically equivalent to ``rd_string_split``. + + :param raw: The input string to tokenize. + :param sep: The single field-separator character (only its first character + is used). ``','`` for comma-separated values, etc. + :param skip_empty: When true, empty fields (consecutive separators, or + whitespace-only fields) are omitted from the result. + :raises TypeError: ``raw`` is ``None``. + """ + if raw is None: + raise TypeError("raw must not be None") + + # rd_string_split takes a single separator character. + sep_char = sep[0] if sep else "" + + fields: List[str] = [] + field: List[str] = [] + next_esc = False + n = len(raw) + idx = 0 + + while True: + at_end = idx >= n + is_esc = next_esc + + if not at_end: + c = raw[idx] + + # An unescaped backslash is consumed and escapes the next char. + if not is_esc and c == "\\": + next_esc = True + idx += 1 + continue + + next_esc = False + + # Strip leading whitespace (only when unescaped). + if not is_esc and not field and _is_ascii_space(c): + idx += 1 + continue + + # Content char: any escaped char, or any non-separator char. + if is_esc or c != sep_char: + if is_esc: + # Common escape substitutions; an unknown escape (e.g. an + # escaped separator or "\\") keeps the character as-is. + if c == "t": + c = "\t" + elif c == "n": + c = "\n" + elif c == "r": + c = "\r" + elif c == "0": + c = "\0" + field.append(c) + idx += 1 + continue + + # Otherwise c is an unescaped separator: fall through to finish + # the current field. + + # Finish the current field (reached on a separator or end-of-input). + while field and _is_ascii_space(field[-1]): + field.pop() # strip trailing whitespace (unconditional) + + if not field and skip_empty: + if at_end: + break + idx += 1 # advance past the separator + continue + + fields.append("".join(field)) + field = [] + + if at_end: + break + idx += 1 # advance past the separator + + return fields + + +def parse_key_values( + raw: str, + sep: str, + context_label: Optional[str] = None, +) -> List[Tuple[str, str]]: + """Split ``raw`` via :func:`split` (skipping empty fields) and parse each + field into a ``(key, value)`` pair on its first ``=``. Logically equivalent + to applying librdkafka's ``rd_kafka_conf_kv_split`` over every field. + + :param raw: The raw ``key=value`` string to parse. + :param sep: The single field-separator character (e.g. ``','``). + :param context_label: Label woven into the error message to identify which + config a malformed entry came from. When ``None``, the message falls + back to a generic ``key=value`` phrasing. + :raises TypeError: ``raw`` is ``None``. + :raises ValueError: a field has no ``=``, or has an empty key. + """ + pairs: List[Tuple[str, str]] = [] + for field in split(raw, sep, skip_empty=True): + # Split on the FIRST '='. eq <= 0 means no '=' or an empty key. + eq = field.find("=") + if eq <= 0: + where = f" in {context_label}" if context_label else "" + raise ValueError(f"Malformed entry '{field}'{where} (expected key=value).") + pairs.append((field[:eq], field[eq + 1 :])) + return pairs diff --git a/src/confluent_kafka/cimpl.pyi b/src/confluent_kafka/cimpl.pyi index ddcb91cab..a4a206515 100644 --- a/src/confluent_kafka/cimpl.pyi +++ b/src/confluent_kafka/cimpl.pyi @@ -67,12 +67,14 @@ class KafkaError: DELEGATION_TOKEN_REQUEST_NOT_ALLOWED: int DUPLICATE_RESOURCE: int DUPLICATE_SEQUENCE_NUMBER: int + DUPLICATE_VOTER: int ELECTION_NOT_NEEDED: int ELIGIBLE_LEADERS_NOT_AVAILABLE: int FEATURE_UPDATE_FAILED: int FENCED_INSTANCE_ID: int FENCED_LEADER_EPOCH: int FENCED_MEMBER_EPOCH: int + FENCED_STATE_EPOCH: int FETCH_SESSION_ID_NOT_FOUND: int GROUP_AUTHORIZATION_FAILED: int GROUP_ID_NOT_FOUND: int @@ -81,6 +83,7 @@ class KafkaError: ILLEGAL_GENERATION: int ILLEGAL_SASL_STATE: int INCONSISTENT_GROUP_PROTOCOL: int + INCONSISTENT_TOPIC_ID: int INCONSISTENT_VOTER_SET: int INVALID_COMMIT_OFFSET_SIZE: int INVALID_CONFIG: int @@ -93,20 +96,26 @@ class KafkaError: INVALID_PRODUCER_EPOCH: int INVALID_PRODUCER_ID_MAPPING: int INVALID_RECORD: int + INVALID_RECORD_STATE: int + INVALID_REGISTRATION: int + INVALID_REGULAR_EXPRESSION: int INVALID_REPLICATION_FACTOR: int INVALID_REPLICA_ASSIGNMENT: int INVALID_REQUEST: int INVALID_REQUIRED_ACKS: int INVALID_SESSION_TIMEOUT: int + INVALID_SHARE_SESSION_EPOCH: int INVALID_TIMESTAMP: int INVALID_TRANSACTION_TIMEOUT: int INVALID_TXN_STATE: int INVALID_UPDATE_VERSION: int + INVALID_VOTER_KEY: int KAFKA_STORAGE_ERROR: int LEADER_NOT_AVAILABLE: int LISTENER_NOT_FOUND: int LOG_DIR_NOT_FOUND: int MEMBER_ID_REQUIRED: int + MISMATCHED_ENDPOINT_TYPE: int MSG_SIZE_TOO_LARGE: int NETWORK_EXCEPTION: int NON_EMPTY_GROUP: int @@ -135,9 +144,14 @@ class KafkaError: RESOURCE_NOT_FOUND: int SASL_AUTHENTICATION_FAILED: int SECURITY_DISABLED: int + SHARE_SESSION_LIMIT_REACHED: int + SHARE_SESSION_NOT_FOUND: int STALE_BROKER_EPOCH: int STALE_CTRL_EPOCH: int STALE_MEMBER_EPOCH: int + STREAMS_INVALID_TOPOLOGY: int + STREAMS_INVALID_TOPOLOGY_EPOCH: int + STREAMS_TOPOLOGY_FENCED: int TELEMETRY_TOO_LARGE: int THROTTLING_QUOTA_EXCEEDED: int TOPIC_ALREADY_EXISTS: int @@ -145,9 +159,11 @@ class KafkaError: TOPIC_DELETION_DISABLED: int TOPIC_EXCEPTION: int TRANSACTIONAL_ID_AUTHORIZATION_FAILED: int + TRANSACTION_ABORTABLE: int TRANSACTION_COORDINATOR_FENCED: int UNACCEPTABLE_CREDENTIAL: int UNKNOWN: int + UNKNOWN_CONTROLLER_ID: int UNKNOWN_LEADER_EPOCH: int UNKNOWN_MEMBER_ID: int UNKNOWN_PRODUCER_ID: int @@ -158,9 +174,11 @@ class KafkaError: UNSTABLE_OFFSET_COMMIT: int UNSUPPORTED_ASSIGNOR: int UNSUPPORTED_COMPRESSION_TYPE: int + UNSUPPORTED_ENDPOINT_TYPE: int UNSUPPORTED_FOR_MESSAGE_FORMAT: int UNSUPPORTED_SASL_MECHANISM: int UNSUPPORTED_VERSION: int + VOTER_NOT_FOUND: int _ALL_BROKERS_DOWN: int _APPLICATION: int _ASSIGNMENT_LOST: int diff --git a/src/confluent_kafka/oauthbearer/__init__.py b/src/confluent_kafka/oauthbearer/__init__.py new file mode 100644 index 000000000..79258e6e7 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Namespace package for OAUTHBEARER provider integrations.""" diff --git a/src/confluent_kafka/oauthbearer/aws/__init__.py b/src/confluent_kafka/oauthbearer/aws/__init__.py new file mode 100644 index 000000000..48ecf1f87 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AWS IAM OAUTHBEARER autowire subpackage. + +The only publicly importable name in this subpackage is +:func:`confluent_kafka.oauthbearer.aws.aws_autowire.create_handler`, loaded by +core's C extension when the user sets +``sasl.oauthbearer.metadata.authentication.type=aws_iam``. All other modules +are private (underscore-prefixed) and not re-exported here. + +Install with:: + + pip install 'confluent-kafka[oauthbearer-aws]' +""" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py new file mode 100644 index 000000000..7a1cf69ad --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py @@ -0,0 +1,37 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Marker constants identifying the AWS IAM OAUTHBEARER autowire path. + +The config-key/value pair that activates the AWS OAUTHBEARER autowire path. + +The C dispatcher in ``src/confluent_kafka/src/confluent_kafka.c`` keeps its +own literal copies of these values for compile-time use; the drift-guard +test in ``tests/oauthbearer/aws/test_aws_iam_marker.py`` asserts the C-side +literals and these Python constants stay in lock-step. + +These strings are part of the cross-language wire contract — bumping either +is a major version change on ``confluent-kafka``. +""" + +__all__ = ["AWS_IAM_MARKER_KEY", "AWS_IAM_MARKER_VALUE"] + + +#: Config key that activates the AWS IAM autowire path when set +#: to :data:`AWS_IAM_MARKER_VALUE`. +AWS_IAM_MARKER_KEY: str = "sasl.oauthbearer.metadata.authentication.type" + +#: On-wire value of :data:`AWS_IAM_MARKER_KEY` that selects AWS IAM +#: authentication. +AWS_IAM_MARKER_VALUE: str = "aws_iam" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py new file mode 100644 index 000000000..c5779fe58 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py @@ -0,0 +1,229 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: validated ``sasl.oauthbearer.config`` dataclass + parser. + +The full grammar (comma-separated ``key=value`` pairs, librdkafka grammar — +values may backslash-quote a comma, e.g. ``\\,``): + + region= (required) + audience= (required) + duration_seconds=<60..3600> (default: 300) + signing_algorithm=ES384|RS256 (default: ES384) + sts_endpoint= (optional, FIPS / VPC) + aws_debug=none|console (default: none) + tag_= (zero or more JWT custom claims, max 50) + +SASL extensions arrive separately via :data:`sasl_extensions` (parsed from +the typed ``sasl.oauthbearer.extensions`` config property). They are NOT +accepted inside this string under any ``extension_*`` prefix — that key +shape is rejected as an unknown key. +""" + +from dataclasses import dataclass +from typing import Dict, Optional + +from confluent_kafka._util.librdkafka_string_parser import parse_key_values + +__all__ = [ + "CONFIG_KEY", + "DEFAULT_SIGNING_ALGORITHM", + "ALLOWED_SIGNING_ALGORITHMS", + "MIN_DURATION_SECONDS", + "MAX_DURATION_SECONDS", + "DEFAULT_DURATION_SECONDS", + "TAG_KEY_PREFIX", + "MAX_TAGS", + "AWS_DEBUG_NONE", + "AWS_DEBUG_CONSOLE", + "ALLOWED_AWS_DEBUG_VALUES", + "AwsOAuthBearerConfig", +] + + +#: Config key carrying the AWS-path wire-grammar string. +CONFIG_KEY: str = "sasl.oauthbearer.config" + +#: Default JWT signing algorithm. +DEFAULT_SIGNING_ALGORITHM: str = "ES384" + +#: Signing algorithms accepted by AWS STS ``GetWebIdentityToken``. +ALLOWED_SIGNING_ALGORITHMS = ("ES384", "RS256") + +#: Minimum / default / maximum token lifetime AWS STS allows +MIN_DURATION_SECONDS: int = 60 +MAX_DURATION_SECONDS: int = 3600 +DEFAULT_DURATION_SECONDS: int = 300 + +#: Wire-grammar prefix for STS ``Tags`` entries (e.g. ``tag_team=platform``). +TAG_KEY_PREFIX: str = "tag_" + +#: AWS-enforced upper bound on number of tags per ``GetWebIdentityToken`` call. +MAX_TAGS: int = 50 + +#: Sentinel string values for ``aws_debug``. +AWS_DEBUG_NONE: str = "none" +AWS_DEBUG_CONSOLE: str = "console" + +#: ``aws_debug`` values accepted by the Python client: ``none`` and ``console``. +ALLOWED_AWS_DEBUG_VALUES = (AWS_DEBUG_NONE, AWS_DEBUG_CONSOLE) + + +# Recognised non-tag keys for the wire grammar. Anything else (other than +# ``tag_``) raises "Unknown key" during :meth:`AwsOAuthBearerConfig.parse`. +_RECOGNISED_KEYS = frozenset( + { + "region", + "audience", + "duration_seconds", + "signing_algorithm", + "sts_endpoint", + "aws_debug", + } +) + +_NON_EMPTY_KEYS = frozenset( + { + "region", + "audience", + "signing_algorithm", + "sts_endpoint", + "aws_debug", + } +) + + +@dataclass(frozen=True) +class AwsOAuthBearerConfig: + """Immutable view of the AWS path's ``sasl.oauthbearer.config``.""" + + region: str + audience: str + signing_algorithm: str = DEFAULT_SIGNING_ALGORITHM + duration_seconds: int = DEFAULT_DURATION_SECONDS + sts_endpoint: Optional[str] = None + aws_debug: str = AWS_DEBUG_NONE + tags: Optional[Dict[str, str]] = None + sasl_extensions: Optional[Dict[str, str]] = None + + def __post_init__(self) -> None: + """Final-state validation. Raises :class:`ValueError` on bad input.""" + if not isinstance(self.region, str) or self.region == "": + raise ValueError(f"{CONFIG_KEY} 'region' must not be empty.") + if not isinstance(self.audience, str) or self.audience == "": + raise ValueError(f"{CONFIG_KEY} 'audience' must not be empty.") + if self.signing_algorithm not in ALLOWED_SIGNING_ALGORITHMS: + raise ValueError( + f"{CONFIG_KEY} 'signing_algorithm' must be 'ES384' or 'RS256'; " f"got {self.signing_algorithm!r}." + ) + # bool is-a int in Python — reject explicitly so True/False doesn't + # slip through as 1/0. + if not isinstance(self.duration_seconds, int) or isinstance(self.duration_seconds, bool): + raise ValueError(f"{CONFIG_KEY} 'duration_seconds' must be an integer.") + if not (MIN_DURATION_SECONDS <= self.duration_seconds <= MAX_DURATION_SECONDS): + raise ValueError( + f"{CONFIG_KEY} 'duration_seconds' must be between " + f"{MIN_DURATION_SECONDS} and {MAX_DURATION_SECONDS} inclusive; " + f"got {self.duration_seconds}." + ) + if self.sts_endpoint is not None and self.sts_endpoint == "": + raise ValueError(f"{CONFIG_KEY} 'sts_endpoint' must not be empty.") + if self.aws_debug not in ALLOWED_AWS_DEBUG_VALUES: + raise ValueError(f"{CONFIG_KEY} 'aws_debug' must be one of: none, console. Got {self.aws_debug!r}.") + if self.tags is not None: + if not isinstance(self.tags, dict): + raise ValueError(f"{CONFIG_KEY} 'tags' must be a dict.") + if len(self.tags) > MAX_TAGS: + raise ValueError(f"{CONFIG_KEY} has {len(self.tags)} tags; AWS allows at " f"most {MAX_TAGS}.") + if self.sasl_extensions is not None and not isinstance(self.sasl_extensions, dict): + raise ValueError("sasl_extensions, if set, must be a dict.") + + @classmethod + def parse( + cls, + raw: str, + sasl_extensions: Optional[Dict[str, str]] = None, + ) -> "AwsOAuthBearerConfig": + """Parse the verbatim ``sasl.oauthbearer.config`` value. + + Comma-separated ``key=value`` tokens; the union of recognised + keys plus ``tag_`` entries. Anything else raises + :class:`ValueError`. Empty values for required-non-empty keys raise + the same. Duplicate keys → last-wins. + + :param raw: The verbatim ``sasl.oauthbearer.config`` string. + :param sasl_extensions: Pre-parsed dict from the sibling + ``sasl.oauthbearer.extensions`` property (see + :mod:`._sasl_extensions_parser`). Stored on the config + unchanged. + :raises TypeError: ``raw`` is ``None``. + :raises ValueError: grammar, range, or enum violations. + """ + if raw is None: + raise TypeError("raw must not be None") + + region: Optional[str] = None + audience: Optional[str] = None + signing_algorithm: str = DEFAULT_SIGNING_ALGORITHM + duration_seconds: int = DEFAULT_DURATION_SECONDS + sts_endpoint: Optional[str] = None + aws_debug: str = AWS_DEBUG_NONE + tags: Optional[Dict[str, str]] = None + + for key, value in parse_key_values(raw, ",", CONFIG_KEY): + if key in _NON_EMPTY_KEYS and value == "": + raise ValueError(f"{CONFIG_KEY} {key!r} must not be empty.") + + if key == "region": + region = value + elif key == "audience": + audience = value + elif key == "signing_algorithm": + signing_algorithm = value + elif key == "duration_seconds": + try: + duration_seconds = int(value) + except ValueError as exc: + raise ValueError(f"{CONFIG_KEY} 'duration_seconds' must be an integer; " f"got {value!r}.") from exc + elif key == "sts_endpoint": + sts_endpoint = value + elif key == "aws_debug": + # Normalize case so downstream comparisons against + # ALLOWED_AWS_DEBUG_VALUES are straightforward. + aws_debug = value.lower() + elif key.startswith(TAG_KEY_PREFIX): + tag_name = key[len(TAG_KEY_PREFIX) :] + if tag_name == "": + raise ValueError(f"{CONFIG_KEY} tag key {key!r} has empty name.") + if tags is None: + tags = {} + tags[tag_name] = value # last-wins on duplicate tag names + else: + raise ValueError(f"Unknown key {key!r} in {CONFIG_KEY}.") + + if region is None: + raise ValueError(f"'region' is required in {CONFIG_KEY}.") + if audience is None: + raise ValueError(f"'audience' is required in {CONFIG_KEY}.") + + return cls( + region=region, + audience=audience, + signing_algorithm=signing_algorithm, + duration_seconds=duration_seconds, + sts_endpoint=sts_endpoint, + aws_debug=aws_debug, + tags=tags, + sasl_extensions=sasl_extensions, + ) diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py new file mode 100644 index 000000000..1ca0f6176 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py @@ -0,0 +1,172 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: Fetches OAUTHBEARER tokens via AWS STS GetWebIdentityToken.""" + +import logging +from typing import Any, Dict, Optional, Tuple + +import boto3 + +from . import _jwt_extractor +from ._aws_oauthbearer_config import ( + AWS_DEBUG_CONSOLE, + AwsOAuthBearerConfig, +) + +__all__ = ["AwsStsTokenProvider"] + + +# Logger name targeted by ``aws_debug=console``. Routes botocore's HTTP / +# credential-chain / signing diagnostic logs to stderr at DEBUG level. +_BOTOCORE_LOGGER_NAME = "botocore" + +#: Minimum boto3 version required by the AWS IAM path. +#: ``requirements/requirements-oauthbearer-aws.txt``. +MINIMUM_BOTO3_VERSION = "1.42.25" + + +def _version_tuple(version: str) -> Tuple[int, ...]: + """Parse a dotted version string into a tuple of its leading integers. + + Tolerant of pre-release suffixes (``"1.42.0rc1"`` -> ``(1, 42, 0)``) so the + comparison considers only the numeric ``major.minor.micro`` components. + """ + parts = [] + for segment in version.split("."): + digits = "" + for ch in segment: + if not ch.isdigit(): + break + digits += ch + parts.append(int(digits) if digits else 0) + return tuple(parts) + + +def _require_boto3_version() -> None: + """Raise :class:`ImportError` if the installed boto3 predates + :data:`MINIMUM_BOTO3_VERSION`. + + """ + if _version_tuple(boto3.__version__) < _version_tuple(MINIMUM_BOTO3_VERSION): + raise ImportError( + f"The AWS IAM OAUTHBEARER path requires boto3>={MINIMUM_BOTO3_VERSION} " + f"(for the STS GetWebIdentityToken operation), but found boto3 " + f"{boto3.__version__}. Upgrade with: " + f"pip install -U 'confluent-kafka[oauthbearer-aws]'." + ) + + +class AwsStsTokenProvider: + """Mints OAUTHBEARER tokens via AWS STS ``GetWebIdentityToken``.""" + + def __init__( + self, + config: AwsOAuthBearerConfig, + sts_client: Optional[Any] = None, + ) -> None: + """Construct a provider bound to ``config``. + + :param config: Validated :class:`AwsOAuthBearerConfig` instance. + :param sts_client: Test seam — when supplied, the provider uses this + client directly instead of constructing a real boto3 STS client. + Production callers pass ``None``. + :raises TypeError: ``config`` is ``None``. + :raises ImportError: the installed boto3 is older than + :data:`MINIMUM_BOTO3_VERSION` (checked only on the real-client + path, i.e. when ``sts_client`` is ``None``). + """ + if config is None: + raise TypeError("config must not be None") + self._cfg = config + + self._apply_aws_debug(config.aws_debug) + + if sts_client is not None: + self._sts = sts_client + else: + # Fail fast with a clear message if boto3 is present but too old for + # the STS GetWebIdentityToken operation. + _require_boto3_version() + session = boto3.Session(region_name=config.region) + client_kwargs: Dict[str, Any] = {"region_name": config.region} + if config.sts_endpoint: + client_kwargs["endpoint_url"] = config.sts_endpoint + self._sts = session.client("sts", **client_kwargs) + + @staticmethod + def _apply_aws_debug(aws_debug: str) -> None: + """Apply the ``aws_debug`` side-effect to botocore's logger. + + Process-wide effect, intentionally. When the user opts in with + ``aws_debug=console``, every boto3 client in the process gets + DEBUG-level stderr logs. ``aws_debug=none`` is a no-op so any + logging the user has configured elsewhere is preserved. + """ + if aws_debug == AWS_DEBUG_CONSOLE: + boto3.set_stream_logger(_BOTOCORE_LOGGER_NAME, logging.DEBUG) + # AWS_DEBUG_NONE → no-op. Other values are rejected by config validation. + + def token( + self, + oauthbearer_config: str = "", + ) -> Tuple[str, float, str, Dict[str, str]]: + """Mint a fresh JWT and return the ``oauth_cb`` 4-tuple. + + :param oauthbearer_config: The verbatim ``sasl.oauthbearer.config`` + string librdkafka passes back on every refresh. Accepted for + interface completeness but unused — the AWS path's fields are + sourced from the bound :class:`AwsOAuthBearerConfig` at + construction time, not re-parsed per refresh. + + :returns: 4-tuple ``(token, expiry_epoch_seconds, principal, + extensions)`` matching the C ``oauth_cb`` contract. + + :raises botocore.exceptions.ClientError: STS-side error + (``AccessDenied``, ``OutboundWebIdentityFederationDisabled``, + ...). The C ``oauth_cb`` wrapper converts raised exceptions + into ``rd_kafka_oauthbearer_set_token_failure``. + :raises ValueError: STS returned a malformed JWT or missing + ``Expiration``. + """ + request_kwargs: Dict[str, Any] = { + "Audience": [self._cfg.audience], + "SigningAlgorithm": self._cfg.signing_algorithm, + "DurationSeconds": self._cfg.duration_seconds, + } + if self._cfg.tags: + request_kwargs["Tags"] = [{"Key": k, "Value": v} for k, v in self._cfg.tags.items()] + + response = self._sts.get_web_identity_token(**request_kwargs) + + jwt = response.get("WebIdentityToken") + if not isinstance(jwt, str) or not jwt: + raise ValueError("STS response missing WebIdentityToken; cannot mint OAUTHBEARER token.") + + expiration = response.get("Expiration") + if expiration is None: + raise ValueError("STS response missing Expiration; cannot compute token lifetime.") + # boto3 normalises the timestamp to a tz-aware UTC datetime; + # .timestamp() returns epoch seconds as a float. + expiry_epoch_seconds = expiration.timestamp() + + principal = _jwt_extractor.extract_sub(jwt) + + # Always return a dict for the extensions slot — the C oauth_cb + # wrapper's PyArg_ParseTuple uses "O!" with PyDict_Type for that slot, + # which would reject None. Empty dict is the Pythonic equivalent of + # .NET's null-Extensions case. + extensions = dict(self._cfg.sasl_extensions) if self._cfg.sasl_extensions else {} + + return jwt, expiry_epoch_seconds, principal, extensions diff --git a/src/confluent_kafka/oauthbearer/aws/_jwt_extractor.py b/src/confluent_kafka/oauthbearer/aws/_jwt_extractor.py new file mode 100644 index 000000000..5b318f25a --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_jwt_extractor.py @@ -0,0 +1,90 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: extracts the ``sub`` claim from an unverified JWT. + +No signature verification — STS signs, broker validates. +""" + +import base64 +import binascii +import json + +__all__ = ["extract_sub"] + + +_MAX_TOKEN_LENGTH_CHARS: int = 8192 + + +def extract_sub(jwt: str) -> str: + """Return the ``sub`` claim from the JWT payload. + + :raises ValueError: ``jwt`` is null, empty, oversized, has the wrong + segment count, fails base64url decoding, isn't valid JSON, isn't a + JSON object, or has no ``sub`` string claim (or its value is empty). + """ + if jwt is None: + raise ValueError("JWT is null.") + if jwt == "": + raise ValueError("JWT is empty.") + if len(jwt) > _MAX_TOKEN_LENGTH_CHARS: + raise ValueError(f"JWT length {len(jwt)} exceeds maximum allowed " f"({_MAX_TOKEN_LENGTH_CHARS}).") + + parts = jwt.split(".") + if len(parts) != 3: + raise ValueError(f"JWT must have exactly 3 '.'-separated segments; got {len(parts)}.") + + payload_bytes = _decode_base64url_segment(parts[1]) + try: + payload_string = payload_bytes.decode("utf-8") + except UnicodeDecodeError as exc: + raise ValueError(f"JWT payload is not valid UTF-8: {exc}") from exc + + try: + token = json.loads(payload_string) + except json.JSONDecodeError as exc: + raise ValueError(f"JWT payload is not valid JSON: {exc}") from exc + + if not isinstance(token, dict): + raise ValueError("JWT payload is not a JSON object.") + + if "sub" not in token: + raise ValueError("JWT payload is missing a 'sub' string claim.") + sub = token["sub"] + if not isinstance(sub, str): + raise ValueError("JWT payload is missing a 'sub' string claim.") + if sub == "": + raise ValueError("JWT 'sub' claim value is empty.") + return sub + + +def _decode_base64url_segment(segment: str) -> bytes: + if len(segment) == 0: + raise ValueError("JWT payload segment is empty.") + + s = segment.replace("-", "+").replace("_", "/") + remainder = len(s) % 4 + if remainder == 0: + pass + elif remainder == 2: + s += "==" + elif remainder == 3: + s += "=" + else: + raise ValueError("JWT payload segment has invalid base64url length.") + + try: + return base64.b64decode(s.encode("ascii"), validate=True) + except (binascii.Error, UnicodeEncodeError, ValueError) as exc: + raise ValueError(f"JWT payload segment is not valid base64url: {exc}") from exc diff --git a/src/confluent_kafka/oauthbearer/aws/_sasl_extensions_parser.py b/src/confluent_kafka/oauthbearer/aws/_sasl_extensions_parser.py new file mode 100644 index 000000000..4bb43856c --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_sasl_extensions_parser.py @@ -0,0 +1,50 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: parser for the ``sasl.oauthbearer.extensions`` config property. + +The ``sasl.oauthbearer.extensions`` config carries RFC 7628 SASL +extensions as a comma-separated ``key=value`` list. +""" + +from typing import Dict, Optional + +from confluent_kafka._util.librdkafka_string_parser import parse_key_values + +__all__ = ["CONFIG_KEY", "parse"] + + +#: Config key carrying the SASL extensions list. +CONFIG_KEY: str = "sasl.oauthbearer.extensions" + + +def parse(raw: Optional[str]) -> Optional[Dict[str, str]]: + """Parse the verbatim ``sasl.oauthbearer.extensions`` value into a dict. + + The grammar mirrors the cross-language convention for consistency. + + Returns ``None`` for ``None`` / empty input so the autowire layer can + short-circuit without constructing an empty dict. + + :raises ValueError: A token is missing ``=`` or has an empty key. + """ + if raw is None or raw == "": + return None + + result: Dict[str, str] = {} + for key, value in parse_key_values(raw, ",", CONFIG_KEY): + # Last-wins on duplicate keys, mirroring librdkafka. + result[key] = value + + return result if result else None diff --git a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py new file mode 100644 index 000000000..b29ba05d5 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py @@ -0,0 +1,93 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public entry-point for AWS IAM OAUTHBEARER autowire. + +This is the **only publicly importable name** in the optional subpackage. +End users do not call :func:`create_handler` directly — the C dispatcher in +``src/confluent_kafka/src/confluent_kafka.c`` reaches it via:: + + PyImport_ImportModule("confluent_kafka.oauthbearer.aws.aws_autowire") + +and resolves :func:`create_handler` by name. The marker-key check is +performed in core; :func:`create_handler` is invoked only when the C +dispatcher has decided to autowire the AWS path. + +User-facing contract — four config keys:: + + "sasl.oauthbearer.method": "oidc" + "sasl.oauthbearer.metadata.authentication.type": "aws_iam" + "sasl.oauthbearer.config": "region=...,audience=..." + "sasl.oauthbearer.extensions": "key=val,..." # optional + +Frozen cross-module contract of :func:`create_handler`: + +* arity: 2 positional parameters +* names: ``sasl_oauthbearer_config``, ``sasl_oauthbearer_extensions`` +* types: ``str``, ``Optional[str]`` +* return: :data:`OAuthBearerCallback` + +Bumping any of these is a breaking change requiring a major version +increment on the ``confluent-kafka`` distribution. Test-guarded by +``tests/oauthbearer/aws/test_contract.py``. +""" + +from typing import Callable, Dict, Optional, Tuple + +from . import _sasl_extensions_parser +from ._aws_iam_marker import AWS_IAM_MARKER_KEY, AWS_IAM_MARKER_VALUE +from ._aws_oauthbearer_config import CONFIG_KEY, AwsOAuthBearerConfig +from ._aws_sts_token_provider import AwsStsTokenProvider + +__all__ = ["create_handler", "OAuthBearerCallback"] + +OAuthBearerCallback = Callable[[str], Tuple[str, float, str, Dict[str, str]]] + + +def create_handler( + sasl_oauthbearer_config: str, + sasl_oauthbearer_extensions: Optional[str], +) -> OAuthBearerCallback: + """Build an OAUTHBEARER refresh callback from the two OAUTHBEARER config strings. + + :param sasl_oauthbearer_config: The verbatim ``sasl.oauthbearer.config`` + value (whitespace-separated ``key=value`` pairs). Must be non-empty. + :param sasl_oauthbearer_extensions: The verbatim + ``sasl.oauthbearer.extensions`` value (comma-separated ``key=value`` + pairs, RFC 7628 §3.1). May be ``None`` or empty when the user has + no extensions configured. + + :returns: A callable matching :data:`OAuthBearerCallback`. + + :raises ValueError: ``sasl_oauthbearer_config`` is ``None`` or empty; + the wire-grammar parse fails (unknown key, malformed token, missing + required field, range/enum violation, etc.). + :raises ImportError: the installed boto3 predates the minimum required by + the AWS IAM path (boto3 is present but too old for STS + ``GetWebIdentityToken``). + :raises RuntimeError: AWS SDK reachability or initialisation failure + (e.g. unknown region, malformed ``sts_endpoint``). + """ + if not sasl_oauthbearer_config: + raise ValueError( + f"'{AWS_IAM_MARKER_KEY}={AWS_IAM_MARKER_VALUE}' is set but " + f"'{CONFIG_KEY}' is missing or empty. The AWS IAM autowire path " + f"requires region and audience to be supplied via " + f"{CONFIG_KEY} (e.g. \"region=us-east-1,audience=https://...\")." + ) + + sasl_extensions = _sasl_extensions_parser.parse(sasl_oauthbearer_extensions) + config = AwsOAuthBearerConfig.parse(sasl_oauthbearer_config, sasl_extensions) + provider = AwsStsTokenProvider(config) + return provider.token diff --git a/src/confluent_kafka/src/confluent_kafka.c b/src/confluent_kafka/src/confluent_kafka.c index 0a2c58495..7b02e2232 100644 --- a/src/confluent_kafka/src/confluent_kafka.c +++ b/src/confluent_kafka/src/confluent_kafka.c @@ -2580,11 +2580,165 @@ static void common_conf_set_software(rd_kafka_conf_t *conf) { /** - * Common config setup for Kafka client handles. + * @brief Detect the aws_iam OAUTHBEARER autowire marker in the user's config + * dict and, before the librdkafka handle is created, register an + * oauth_cb sourced from the optional confluent_kafka.oauthbearer.aws + * subpackage. The aws_iam marker and method=oidc are passed through to + * librdkafka UNCHANGED. * - * Returns a conf object on success or NULL on failure in which case - * an exception has been raised. + * User contract (all three keys required when the marker is set and no explicit + * oauth_cb is supplied): + * sasl.oauthbearer.method = "oidc" + * sasl.oauthbearer.metadata.authentication.type = "aws_iam" + * sasl.oauthbearer.config = "region=...,audience=..." + * + * Plus optional: + * sasl.oauthbearer.extensions = "key=val,key=val" + * + * + * If the user supplies their own oauth_cb, autowire is skipped entirely and the + * config is left untouched (their callback, the marker, and method all pass + * through to librdkafka). + * + * @returns 0 on success (no-op or autowire complete), -1 on error + * (PyErr_* is set; caller goto outer_err). */ +static int resolve_aws_oauthbearer_marker(PyObject *confdict) { + static const char MARKER_KEY[] = + "sasl.oauthbearer.metadata.authentication.type"; + static const char MARKER_VALUE[] = "aws_iam"; + static const char METHOD_KEY[] = "sasl.oauthbearer.method"; + static const char METHOD_OIDC_VALUE[] = "oidc"; + static const char CONFIG_KEY[] = "sasl.oauthbearer.config"; + static const char EXTENSIONS_KEY[] = "sasl.oauthbearer.extensions"; + static const char OAUTH_CB_KEY[] = "oauth_cb"; + static const char AUTOWIRE_MODULE[] = + "confluent_kafka.oauthbearer.aws.aws_autowire"; + static const char CREATE_HANDLER[] = "create_handler"; + static const char METHOD_REQUIREMENT_ERR[] = + "'sasl.oauthbearer.metadata.authentication.type=aws_iam' requires " + "'sasl.oauthbearer.method=oidc'. Current value: %s. method=oidc is " + "mandatory for the AWS IAM authentication path."; + static const char CONFIG_REQUIREMENT_ERR[] = + "'sasl.oauthbearer.metadata.authentication.type=aws_iam' is set " + "but 'sasl.oauthbearer.config' is missing or empty. The AWS IAM " + "autowire path requires region and audience to be supplied via " + "sasl.oauthbearer.config " + "(e.g. \"region=us-east-1,audience=https://...\")."; + static const char FRIENDLY_IMPORT_ERR[] = + "Config 'sasl.oauthbearer.metadata.authentication.type=aws_iam' " + "requires the optional 'oauthbearer-aws' extra. Install with:\n" + " pip install 'confluent-kafka[oauthbearer-aws]'"; + + PyObject *marker; + PyObject *cb; + PyObject *method; + PyObject *cfg_str; + PyObject *ext_str; + PyObject *mod; + PyObject *func; + PyObject *callback; + const char *marker_c; + const char *method_c; + + /* Explicit oauth_cb wins: nothing to autowire, regardless of the marker. */ + cb = PyDict_GetItemString(confdict, OAUTH_CB_KEY); + if (cb && cb != Py_None) { + return 0; + } + + marker = PyDict_GetItemString(confdict, MARKER_KEY); + if (!marker || !PyUnicode_Check(marker)) { + return 0; + } + marker_c = PyUnicode_AsUTF8(marker); + if (!marker_c) { + /* Non-ASCII / malformed unicode — treat as not-our-value. */ + PyErr_Clear(); + return 0; + } + if (strcmp(marker_c, MARKER_VALUE) != 0) { + return 0; + } + + method = PyDict_GetItemString(confdict, METHOD_KEY); + method_c = (method && PyUnicode_Check(method)) + ? PyUnicode_AsUTF8(method) + : NULL; + if (!method_c || strcmp(method_c, METHOD_OIDC_VALUE) != 0) { + char actual_buf[128]; + const char *actual; + if (!method) { + actual = ""; + } else if (!method_c) { + PyErr_Clear(); + actual = ""; + } else { + snprintf(actual_buf, sizeof(actual_buf), "'%s'", + method_c); + actual = actual_buf; + } + PyErr_Format(PyExc_ValueError, METHOD_REQUIREMENT_ERR, actual); + return -1; + } + + cfg_str = PyDict_GetItemString(confdict, CONFIG_KEY); + if (!cfg_str || !PyUnicode_Check(cfg_str) || + PyUnicode_GET_LENGTH(cfg_str) == 0) { + PyErr_SetString(PyExc_ValueError, CONFIG_REQUIREMENT_ERR); + return -1; + } + + ext_str = PyDict_GetItemString(confdict, EXTENSIONS_KEY); + + mod = PyImport_ImportModule(AUTOWIRE_MODULE); + if (!mod) { + PyObject *cause_type = NULL; + PyObject *cause_value = NULL; + PyObject *cause_tb = NULL; + PyObject *new_exc; + + PyErr_Fetch(&cause_type, &cause_value, &cause_tb); + PyErr_NormalizeException(&cause_type, &cause_value, &cause_tb); + + new_exc = PyObject_CallFunction( + PyExc_ImportError, "s", FRIENDLY_IMPORT_ERR); + if (new_exc) { + if (cause_value) { + Py_INCREF(cause_value); + PyException_SetCause(new_exc, cause_value); + } + PyErr_SetObject(PyExc_ImportError, new_exc); + Py_DECREF(new_exc); + } + Py_XDECREF(cause_type); + Py_XDECREF(cause_value); + Py_XDECREF(cause_tb); + return -1; + } + + func = PyObject_GetAttrString(mod, CREATE_HANDLER); + Py_DECREF(mod); + if (!func) { + return -1; + } + callback = PyObject_CallFunction( + func, "OO", cfg_str, ext_str ? ext_str : Py_None); + Py_DECREF(func); + if (!callback) { + return -1; + } + if (PyDict_SetItemString(confdict, OAUTH_CB_KEY, callback) == -1) { + Py_DECREF(callback); + return -1; + } + Py_DECREF(callback); + + + return 0; +} + + rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype, Handle *h, PyObject *args, @@ -2703,6 +2857,16 @@ rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype, PyDict_DelItemString(confdict, "default.topic.config"); } + /* AWS IAM OAUTHBEARER autowire: when the user sets the marker + * sasl.oauthbearer.metadata.authentication.type=aws_iam, wire an + * oauth_cb sourced from the optional oauthbearer-aws extra; the aws_iam + * marker and method=oidc are passed through to librdkafka unchanged. + * No-op when the marker is absent or an explicit oauth_cb is already + * set. See resolve_aws_oauthbearer_marker above for the full flow. */ + if (resolve_aws_oauthbearer_marker(confdict) == -1) { + goto outer_err; + } + /* Convert config dict to config key-value pairs. */ while (PyDict_Next(confdict, &pos, &ko, &vo)) { PyObject *ks; diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index 0a4123b60..d11d2bda2 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -38,7 +38,7 @@ /** * @brief confluent-kafka-python version, must match that of pyproject.toml. */ -#define CFL_VERSION_STR "2.14.2" +#define CFL_VERSION_STR "2.14.2.dev5" /** * Minimum required librdkafka version. This is checked both during diff --git a/tests/_util/__init__.py b/tests/_util/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/_util/test_librdkafka_string_parser.py b/tests/_util/test_librdkafka_string_parser.py new file mode 100644 index 000000000..1cd75db6b --- /dev/null +++ b/tests/_util/test_librdkafka_string_parser.py @@ -0,0 +1,226 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka._util.librdkafka_string_parser. + +The Split_* cases are ported VERBATIM from librdkafka's own ut_string_split +(src/rdstring.c) to lock byte-for-byte parity with rd_string_split. The +remaining cases cover the documented edge behaviour and the +rd_kafka_conf_kv_split key/value semantics. +""" + +import pytest + +from confluent_kafka._util.librdkafka_string_parser import parse_key_values, split + +# --------------------------------------------------------------------------- +# split() — ported verbatim from librdkafka ut_string_split (src/rdstring.c). +# Python string literals are written with explicit "\\" per literal backslash; +# the decoded form is shown in a comment where non-obvious. +# --------------------------------------------------------------------------- + + +def test_split_single_field(): + assert split("just one field", ",", skip_empty=True) == ["just one field"] + + +def test_split_empty_skip_empty_no_fields(): + assert split("", ",", skip_empty=True) == [] + + +def test_split_empty_no_skip_empty_one_empty_field(): + assert split("", ",", skip_empty=False) == [""] + + +def test_split_whitespace_and_empties_skip_empty(): + assert split(", a,b ,,c, d, e,f,ghijk, lmn,opq , r s t u, v", ",", skip_empty=True) == [ + "a", + "b", + "c", + "d", + "e", + "f", + "ghijk", + "lmn", + "opq", + "r s t u", + "v", + ] + + +def test_split_whitespace_and_empties_no_skip_empty(): + assert split(", a,b ,,c, d, e,f,ghijk, lmn,opq , r s t u, v", ",", skip_empty=False) == [ + "", + "a", + "b", + "", + "c", + "d", + "e", + "f", + "ghijk", + "lmn", + "opq", + "r s t u", + "v", + ] + + +def test_split_escapes_quoted_separators_and_backslashes(): + # Decoded input: « this is an \,escaped comma,\,,\\, and this is an + # unbalanced escape: \\\\\\\» (7 trailing backslashes) + raw = " this is an \\,escaped comma,\\,,\\\\, and this is an unbalanced escape: \\\\\\\\\\\\\\" + assert split(raw, ",", skip_empty=True) == [ + "this is an ,escaped comma", + ",", + "\\", # one literal backslash + "and this is an unbalanced escape: \\\\\\", # three literal backslashes + ] + + +def test_split_alternate_separator_pipe_no_skip_empty(): + # Decoded input: «using|another ||\|d|elimiter» + assert split("using|another ||\\|d|elimiter", "|", skip_empty=False) == [ + "using", + "another", + "", + "|d", + "elimiter", + ] + + +# --------------------------------------------------------------------------- +# split() — documented edge behaviour (cross-checked against .NET's port). +# --------------------------------------------------------------------------- + + +def test_split_escape_substitutions_tab_newline_cr_nul(): + # \t \n \r \0 substitute to TAB/LF/CR/NUL; kept internal so the + # trailing-trim never touches them. + assert split("a\\tb\\nc\\rd\\0e", ",", skip_empty=True) == ["a\tb\nc\rd\0e"] + + +def test_split_unicode_whitespace_not_trimmed(): + # U+00A0 (NBSP) is Unicode whitespace but NOT ASCII isspace, so neither the + # leading nor trailing NBSP is trimmed — proving we match C isspace, not + # str.isspace (which WOULD trim them, diverging from librdkafka). + assert split(" a ", ",", skip_empty=True) == [" a "] + + +def test_split_escaped_whitespace_leading_kept_trailing_trimmed(): + # Asymmetry copied from librdkafka: leading whitespace is stripped only when + # UNescaped, so an escaped \t at the start survives; trailing whitespace is + # stripped UNCONDITIONALLY, so an escaped \t at the end is removed. + assert split("\\tx", ",", skip_empty=True) == ["\tx"] + assert split("x\\t", ",", skip_empty=True) == ["x"] + + +def test_split_whitespace_only_trims_to_empty(): + # An all-whitespace field trims to empty: skipped when skip_empty, else + # returned as a single empty field. + assert split(" ", ",", skip_empty=True) == [] + assert split(" ", ",", skip_empty=False) == [""] + + +def test_split_internal_whitespace_preserved(): + assert split(" r s t u ", ",", skip_empty=True) == ["r s t u"] + + +def test_split_dangling_trailing_backslash_dropped(): + # A trailing unbalanced escape has nothing to escape and is dropped. + assert split("abc\\", ",", skip_empty=True) == ["abc"] + + +def test_split_none_raises_type_error(): + with pytest.raises(TypeError): + split(None, ",", skip_empty=True) + + +# --------------------------------------------------------------------------- +# parse_key_values() — rd_kafka_conf_kv_split semantics. +# --------------------------------------------------------------------------- + + +def test_parse_kv_basic_pairs(): + assert parse_key_values("region=us-east-1,audience=https://x", ",", "cfg") == [ + ("region", "us-east-1"), + ("audience", "https://x"), + ] + + +def test_parse_kv_splits_on_first_equals(): + assert parse_key_values("a=b=c", ",", "cfg") == [("a", "b=c")] + + +def test_parse_kv_empty_value_allowed(): + assert parse_key_values("a=", ",", "cfg") == [("a", "")] + + +def test_parse_kv_escaped_comma_in_value_stays_one_entry(): + # Headline case: identityPoolId as a comma-separated list, commas quoted + # with a backslash so the value survives as a single entry. + assert parse_key_values("identityPoolId=pool-1\\,pool-2", ",", "cfg") == [ + ("identityPoolId", "pool-1,pool-2"), + ] + + +def test_parse_kv_no_equals_raises_with_context_label(): + with pytest.raises(ValueError, match="cfg"): + parse_key_values("abc", ",", "cfg") + + +def test_parse_kv_empty_key_raises(): + with pytest.raises(ValueError, match="Malformed"): + parse_key_values("=value", ",", "cfg") + + +def test_parse_kv_none_context_label_uses_generic_phrase(): + with pytest.raises(ValueError, match="key=value"): + parse_key_values("abc", ",") + + +def test_parse_kv_duplicate_keys_both_returned(): + # Last-wins dedup is the caller's responsibility; the parser returns both. + assert parse_key_values("a=1,a=2", ",", "cfg") == [("a", "1"), ("a", "2")] + + +def test_parse_kv_consecutive_separators_skipped(): + assert parse_key_values("a=1,,b=2,", ",", "cfg") == [("a", "1"), ("b", "2")] + + +def test_parse_kv_full_config_all_pairs_in_order(): + pairs = parse_key_values( + "region=us-east-1,audience=https://a," + "duration_seconds=900,signing_algorithm=RS256," + "sts_endpoint=https://sts.us-east-1.amazonaws.com," + "aws_debug=none," + "tag_team=platform,tag_environment=prod", + ",", + "SaslOauthbearerConfig", + ) + assert pairs == [ + ("region", "us-east-1"), + ("audience", "https://a"), + ("duration_seconds", "900"), + ("signing_algorithm", "RS256"), + ("sts_endpoint", "https://sts.us-east-1.amazonaws.com"), + ("aws_debug", "none"), + ("tag_team", "platform"), + ("tag_environment", "prod"), + ] + + +def test_parse_kv_none_raises_type_error(): + with pytest.raises(TypeError): + parse_key_values(None, ",") diff --git a/tests/oauthbearer/__init__.py b/tests/oauthbearer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/oauthbearer/aws/__init__.py b/tests/oauthbearer/aws/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/oauthbearer/aws/test_aws_autowire.py b/tests/oauthbearer/aws/test_aws_autowire.py new file mode 100644 index 000000000..91ff36fec --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_autowire.py @@ -0,0 +1,261 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws.aws_autowire.create_handler. + +End-to-end orchestration tests. The frozen-signature contract guard lives +in test_contract.py. +""" + +import pytest + +pytest.importorskip("boto3") + +from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler # noqa: E402 + +# ---- Input validation (defensive checks for direct callers) ---- + + +def test_create_handler_null_config_raises(): + with pytest.raises(ValueError, match="missing or empty"): + create_handler(None, None) + + +def test_create_handler_empty_config_raises(): + with pytest.raises(ValueError, match="missing or empty"): + create_handler("", None) + + +def test_create_handler_error_message_names_marker_and_config_key(): + with pytest.raises(ValueError) as exc_info: + create_handler(None, None) + # User should see both: the marker that's set AND the config key that's missing. + assert "sasl.oauthbearer.metadata.authentication.type" in str(exc_info.value) + assert "aws_iam" in str(exc_info.value) + assert "sasl.oauthbearer.config" in str(exc_info.value) + + +# ---- Parse delegation: errors from AwsOAuthBearerConfig.parse propagate ---- + + +def test_create_handler_missing_region_raises(): + with pytest.raises(ValueError, match="region.*required"): + create_handler("audience=https://a", None) + + +def test_create_handler_missing_audience_raises(): + with pytest.raises(ValueError, match="audience.*required"): + create_handler("region=us-east-1", None) + + +def test_create_handler_invalid_signing_algorithm_raises(): + with pytest.raises(ValueError, match="signing_algorithm"): + create_handler( + "region=us-east-1,audience=https://a,signing_algorithm=HS256", + None, + ) + + +def test_create_handler_invalid_duration_raises(): + with pytest.raises(ValueError, match="duration_seconds"): + create_handler( + "region=us-east-1,audience=https://a,duration_seconds=10", + None, + ) + + +def test_create_handler_unknown_key_raises(): + with pytest.raises(ValueError, match="not_a_key"): + create_handler( + "region=us-east-1,audience=https://a,not_a_key=foo", + None, + ) + + +def test_create_handler_invalid_extensions_grammar_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + create_handler( + "region=us-east-1,audience=https://a", + "noEqualsHere", + ) + + +def test_create_handler_aws_debug_invalid_value_raises(): + """An unsupported aws_debug value is rejected — the client accepts none/console only.""" + with pytest.raises(ValueError, match="aws_debug.*none, console"): + create_handler( + "region=us-east-1,audience=https://a,aws_debug=log4net", + None, + ) + + +# ---- Success cases — handler returned, no throw, no STS call yet ---- + + +def test_create_handler_marker_only_minimum_config_returns_handler(): + handler = create_handler( + "region=us-east-1,audience=https://a", + None, + ) + assert handler is not None + assert callable(handler) + + +def test_create_handler_all_optional_fields_returns_handler(): + handler = create_handler( + "region=us-east-1,audience=https://a," + "duration_seconds=900,signing_algorithm=RS256," + "sts_endpoint=https://sts.us-east-1.amazonaws.com," + "aws_debug=none," + "tag_team=platform,tag_environment=prod", + None, + ) + assert handler is not None + assert callable(handler) + + +def test_create_handler_tag_config_handler_ready(): + handler = create_handler( + "region=us-east-1,audience=https://a,tag_team=platform,tag_environment=prod", + None, + ) + assert callable(handler) + + +# ---- Extensions argument handling ---- + + +def test_create_handler_null_extensions_treats_as_absent(): + handler = create_handler( + "region=us-east-1,audience=https://a", + None, + ) + assert callable(handler) + + +def test_create_handler_empty_extensions_treats_as_absent(): + handler = create_handler( + "region=us-east-1,audience=https://a", + "", + ) + assert callable(handler) + + +def test_create_handler_single_extension_handler_ready(): + handler = create_handler( + "region=us-east-1,audience=https://a", + "logicalCluster=lkc-abc", + ) + assert callable(handler) + + +def test_create_handler_multiple_extensions_handler_ready(): + handler = create_handler( + "region=us-east-1,audience=https://a", + "logicalCluster=lkc-abc,identityPoolId=pool-x", + ) + assert callable(handler) + + +# ---- Returned callable invokes correctly with a real STS round-trip stubbed at the boto3 layer ---- +# We can't easily stub the boto3 client AwsStsTokenProvider creates internally without +# patching boto3.Session. End-to-end with-injected-client tests live in +# test_aws_sts_token_provider.py; here we cover the *closure* level (no STS call yet). + + +def test_create_handler_does_not_call_sts_at_construction(): + """create_handler must NOT make an STS call — credential resolution and + network I/O are deferred to the first invocation of the returned callable. + Verify by patching boto3.Session to detect any client.get_web_identity_token call. + """ + from unittest.mock import MagicMock, patch + + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_session.client.return_value = mock_client + create_handler("region=us-east-1,audience=https://a", None) + # Session/client construction OK; get_web_identity_token must not have been called. + mock_client.get_web_identity_token.assert_not_called() + + +def test_create_handler_returned_callable_when_invoked_calls_sts(): + """When the returned callable is invoked, it triggers exactly one STS call.""" + import datetime + from unittest.mock import MagicMock, patch + + canned_response = { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime(2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc), + } + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_client.get_web_identity_token.return_value = canned_response + mock_session.client.return_value = mock_client + + handler = create_handler("region=us-east-1,audience=https://a", None) + result = handler("ignored-passthrough-string") + + mock_client.get_web_identity_token.assert_called_once() + assert isinstance(result, tuple) + assert len(result) == 4 + token, expiry, principal, extensions = result + assert token == canned_response["WebIdentityToken"] + assert isinstance(expiry, float) + assert principal.startswith("arn:") + assert extensions == {} + + +def test_create_handler_returned_callable_round_trips_extensions(): + """Extensions configured via the typed property flow through to the + 4-tuple's extensions slot.""" + import datetime + from unittest.mock import MagicMock, patch + + canned_response = { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime(2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc), + } + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_client.get_web_identity_token.return_value = canned_response + mock_session.client.return_value = mock_client + + handler = create_handler( + "region=us-east-1,audience=https://a", + "logicalCluster=lkc-abc,identityPoolId=pool-x", + ) + _, _, _, extensions = handler("") + + assert extensions == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +# ---- Test helpers ---- + + +def _base64url(data: bytes) -> str: + import base64 + + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _canned_jwt(sub: str = "arn:aws:iam::123:role/R") -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" diff --git a/tests/oauthbearer/aws/test_aws_iam_marker.py b/tests/oauthbearer/aws/test_aws_iam_marker.py new file mode 100644 index 000000000..c421d4c81 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_iam_marker.py @@ -0,0 +1,48 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Drift-guard tests for the AWS IAM marker constants.""" + +from confluent_kafka.oauthbearer.aws._aws_iam_marker import ( + AWS_IAM_MARKER_KEY, + AWS_IAM_MARKER_VALUE, +) + + +def test_marker_key_is_locked_value(): + assert AWS_IAM_MARKER_KEY == "sasl.oauthbearer.metadata.authentication.type" + + +def test_marker_value_is_locked_value(): + assert AWS_IAM_MARKER_VALUE == "aws_iam" + + +def test_c_dispatcher_recognises_python_authoritative_marker(): + import pytest + + from confluent_kafka import Producer + + with pytest.raises(ValueError, match="method=oidc"): + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + AWS_IAM_MARKER_KEY: AWS_IAM_MARKER_VALUE, + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + # Deliberately omitting sasl.oauthbearer.method to trigger the + # dispatcher's precondition check. If the dispatcher's C-side + # marker literals differ from AWS_IAM_MARKER_KEY/VALUE, the + # dispatcher won't fire and this ValueError won't raise. + } + ) diff --git a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py new file mode 100644 index 000000000..b640a6304 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py @@ -0,0 +1,363 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._aws_oauthbearer_config. + +The config grammar is comma-separated ``key=value`` pairs (librdkafka grammar), +matching the Azure IMDS convention and the reviewed .NET implementation. +""" + +import pytest + +from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import ( + AWS_DEBUG_CONSOLE, + AWS_DEBUG_NONE, + AwsOAuthBearerConfig, +) + +# ---- Required fields ---- + + +def test_parse_minimal_required_populates_region_and_audience(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + + +def test_parse_missing_region_raises(): + with pytest.raises(ValueError, match="region.*required"): + AwsOAuthBearerConfig.parse("audience=https://a") + + +def test_parse_missing_audience_raises(): + with pytest.raises(ValueError, match="audience.*required"): + AwsOAuthBearerConfig.parse("region=us-east-1") + + +def test_parse_null_input_raises(): + with pytest.raises(TypeError): + AwsOAuthBearerConfig.parse(None) + + +def test_parse_empty_input_raises_for_missing_region(): + with pytest.raises(ValueError, match="region.*required"): + AwsOAuthBearerConfig.parse("") + + +def test_parse_empty_value_on_required_key_raises(): + with pytest.raises(ValueError, match="region.*must not be empty"): + AwsOAuthBearerConfig.parse("region=,audience=https://a") + + +@pytest.mark.parametrize("key", ["signing_algorithm", "sts_endpoint"]) +def test_parse_empty_value_on_optional_key_raises(key): + with pytest.raises(ValueError, match=key): + AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,{key}=") + + +# ---- Defaults applied during parse ---- + + +def test_parse_no_signing_algorithm_defaults_to_es384(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + assert cfg.signing_algorithm == "ES384" + + +def test_parse_no_duration_defaults_to_300_seconds(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + assert cfg.duration_seconds == 300 + + +def test_parse_optional_fields_default_to_none(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + assert cfg.sts_endpoint is None + assert cfg.sasl_extensions is None + assert cfg.tags is None + + +# ---- duration_seconds ---- + + +@pytest.mark.parametrize("seconds", [60, 300, 3600]) +def test_parse_duration_in_range_accepted(seconds): + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,duration_seconds={seconds}") + assert cfg.duration_seconds == seconds + + +@pytest.mark.parametrize("seconds", [0, 59, 3601, -10]) +def test_parse_duration_out_of_range_raises(seconds): + with pytest.raises(ValueError, match="duration_seconds.*must be between"): + AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,duration_seconds={seconds}") + + +def test_parse_duration_not_integer_raises(): + with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): + AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,duration_seconds=abc") + + +# ---- signing_algorithm ---- + + +@pytest.mark.parametrize("alg", ["ES384", "RS256"]) +def test_parse_allowed_signing_algorithm_accepted(alg): + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,signing_algorithm={alg}") + assert cfg.signing_algorithm == alg + + +@pytest.mark.parametrize("alg", ["HS256", "es384", "RS512"]) +def test_parse_disallowed_signing_algorithm_raises(alg): + with pytest.raises(ValueError, match="signing_algorithm.*ES384.*RS256"): + AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,signing_algorithm={alg}") + + +# ---- aws_debug (none/console only) ---- + + +def test_parse_no_aws_debug_defaults_to_none(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + assert cfg.aws_debug == AWS_DEBUG_NONE + + +@pytest.mark.parametrize( + "value,expected", + [ + ("none", AWS_DEBUG_NONE), + ("console", AWS_DEBUG_CONSOLE), + ], +) +def test_parse_aws_debug_values_accepted(value, expected): + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,aws_debug={value}") + assert cfg.aws_debug == expected + + +@pytest.mark.parametrize( + "value,expected", + [ + ("Console", AWS_DEBUG_CONSOLE), + ("CONSOLE", AWS_DEBUG_CONSOLE), + ("NONE", AWS_DEBUG_NONE), + ], +) +def test_parse_aws_debug_case_insensitive_accepted(value, expected): + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,aws_debug={value}") + assert cfg.aws_debug == expected + + +# Every unsupported aws_debug value is rejected with the same uniform error. +@pytest.mark.parametrize( + "value", + ["verbose", "etw", "debug", "true", "foo", "log4net", "systemdiagnostics", "Log4Net", "SystemDiagnostics"], +) +def test_parse_aws_debug_invalid_value_raises(value): + with pytest.raises(ValueError, match="aws_debug.*none, console"): + AwsOAuthBearerConfig.parse(f"region=us-east-1,audience=https://a,aws_debug={value}") + + +def test_parse_aws_debug_empty_value_raises(): + with pytest.raises(ValueError, match="aws_debug.*must not be empty"): + AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,aws_debug=") + + +# ---- sts_endpoint ---- + + +def test_parse_sts_endpoint_stored_verbatim(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1,audience=https://a,sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" + ) + assert cfg.sts_endpoint == "https://sts-fips.us-east-1.amazonaws.com" + + +def test_parse_principal_name_rejected_as_unknown_key(): + with pytest.raises(ValueError, match="Unknown key.*principal_name"): + AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,principal_name=x") + + +# ---- sasl_extensions argument (typed property pass-through) ---- + + +def test_parse_extension_prefix_in_config_rejected_as_unknown_key(): + """sasl extensions are accepted via typed sasl.oauthbearer.extensions + property only; embedded extension_* keys are rejected.""" + with pytest.raises(ValueError, match="extension_logicalCluster"): + AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,extension_logicalCluster=lkc-abc") + + +def test_parse_sasl_extensions_arg_stored_on_config(): + ext = {"logicalCluster": "lkc-abc", "identityPoolId": "pool-xyz"} + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a", ext) + assert cfg.sasl_extensions is ext + + +def test_parse_sasl_extensions_arg_null_keeps_sasl_extensions_null(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a", None) + assert cfg.sasl_extensions is None + + +# ---- tag_ ---- + + +def test_parse_single_tag_collected_into_tags(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,tag_team=platform") + assert cfg.tags == {"team": "platform"} + + +def test_parse_multiple_tags_all_collected(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,tag_team=platform,tag_environment=prod") + assert cfg.tags == {"team": "platform", "environment": "prod"} + + +def test_parse_empty_tag_name_raises(): + with pytest.raises(ValueError, match="empty name"): + AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,tag_=value") + + +def test_parse_empty_tag_value_accepted(): + # AWS allows tag values of 0 chars; mirror that. + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,tag_team=") + assert cfg.tags == {"team": ""} + + +def test_parse_duplicate_tag_name_last_wins(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,tag_team=infra,tag_team=platform") + assert cfg.tags == {"team": "platform"} + + +def test_parse_exactly_max_tags_accepted(): + parts = ["region=us-east-1", "audience=https://a"] + parts.extend(f"tag_k{i}=v{i}" for i in range(50)) + cfg = AwsOAuthBearerConfig.parse(",".join(parts)) + assert len(cfg.tags) == 50 + + +def test_parse_over_max_tags_raises(): + parts = ["region=us-east-1", "audience=https://a"] + parts.extend(f"tag_k{i}=v{i}" for i in range(51)) + with pytest.raises(ValueError, match="50"): + AwsOAuthBearerConfig.parse(",".join(parts)) + + +# ---- Unknown keys ---- + + +def test_parse_unknown_key_raises(): + with pytest.raises(ValueError, match="Unknown key.*not_a_key"): + AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,not_a_key=foo") + + +# ---- Comma grammar: whitespace handling, quoting, ordering ---- + + +def test_parse_whitespace_around_commas_tolerated(): + # librdkafka strips leading (unescaped) and trailing whitespace per field, + # so spaces around the comma separators are tolerated. + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 , audience=https://a , duration_seconds=600") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + assert cfg.duration_seconds == 600 + + +def test_parse_leading_and_trailing_whitespace_tolerated(): + cfg = AwsOAuthBearerConfig.parse(" region=us-east-1,audience=https://a ") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + + +def test_parse_escaped_comma_in_value_kept(): + # A backslash-quoted comma stays in the value rather than splitting the + # field (e.g. a tag value that is itself a comma-separated list). + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,tag_list=a\\,b\\,c") + assert cfg.tags == {"list": "a,b,c"} + + +def test_parse_order_invariant(): + cfg = AwsOAuthBearerConfig.parse("duration_seconds=600,audience=https://a,region=us-east-1,signing_algorithm=RS256") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + assert cfg.duration_seconds == 600 + assert cfg.signing_algorithm == "RS256" + + +def test_parse_duplicate_key_last_wins(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,region=us-west-2") + assert cfg.region == "us-west-2" + + +# ---- Malformed entries ---- + + +def test_parse_no_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + AwsOAuthBearerConfig.parse("region-us-east-1,audience=https://a") + + +def test_parse_leading_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + AwsOAuthBearerConfig.parse("=value,audience=https://a,region=us-east-1") + + +# ---- Integration ---- + + +def test_parse_all_fields_together_all_populated_correctly(): + sasl_extensions = {"logicalCluster": "lkc-abc"} + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1," + "audience=https://confluent.cloud/oidc," + "duration_seconds=1800," + "signing_algorithm=RS256," + "sts_endpoint=https://sts.us-east-1.amazonaws.com," + "aws_debug=console," + "tag_team=platform", + sasl_extensions, + ) + + assert cfg.region == "us-east-1" + assert cfg.audience == "https://confluent.cloud/oidc" + assert cfg.duration_seconds == 1800 + assert cfg.signing_algorithm == "RS256" + assert cfg.sts_endpoint == "https://sts.us-east-1.amazonaws.com" + assert cfg.aws_debug == AWS_DEBUG_CONSOLE + assert cfg.sasl_extensions == {"logicalCluster": "lkc-abc"} + assert cfg.tags == {"team": "platform"} + + +# ---- Direct construction (bypassing parse) still validated ---- + + +def test_direct_construction_with_empty_region_raises(): + with pytest.raises(ValueError, match="region.*must not be empty"): + AwsOAuthBearerConfig(region="", audience="https://a") + + +def test_direct_construction_with_out_of_range_duration_raises(): + with pytest.raises(ValueError, match="duration_seconds.*must be between"): + AwsOAuthBearerConfig(region="us-east-1", audience="https://a", duration_seconds=10) + + +def test_direct_construction_with_bool_duration_rejected(): + """bool is-a int in Python; reject explicitly so True/False doesn't pass.""" + with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): + AwsOAuthBearerConfig(region="us-east-1", audience="https://a", duration_seconds=True) + + +# ---- Surface invariants ---- + + +def test_aws_oauth_bearer_config_not_exposed_via_subpackage_init(): + """Public surface stays minimal — AwsOAuthBearerConfig is private to the + autowire layer (only `create_handler` is publicly importable).""" + import confluent_kafka.oauthbearer.aws as aws_pkg + + assert not hasattr(aws_pkg, "AwsOAuthBearerConfig") diff --git a/tests/oauthbearer/aws/test_aws_sts_token_provider.py b/tests/oauthbearer/aws/test_aws_sts_token_provider.py new file mode 100644 index 000000000..a07846688 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_sts_token_provider.py @@ -0,0 +1,461 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._aws_sts_token_provider. + +- Python uses a tuple return (not a typed record) so the "no extensions → + null" assertion becomes "no extensions → empty dict" (the C oauth_cb + contract requires a dict, not None, at that slot). +- Python returns ``expiry_epoch_seconds`` (float seconds) not + ``LifetimeMs`` (long milliseconds) — the C wrapper multiplies by 1000. +""" + +import base64 +import datetime +import logging +from pathlib import Path +from typing import Any, Dict, Optional +from unittest.mock import patch + +import pytest + +# botocore is a transitive of boto3; available in opt-in venv. +pytest.importorskip("boto3") +pytest.importorskip("botocore") + +from botocore.exceptions import ClientError # noqa: E402 + +from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import AwsOAuthBearerConfig # noqa: E402 +from confluent_kafka.oauthbearer.aws._aws_sts_token_provider import ( # noqa: E402 + MINIMUM_BOTO3_VERSION, + AwsStsTokenProvider, + _require_boto3_version, + _version_tuple, +) + +# ---- Test helpers ---- + + +_ROLE_ARN = "arn:aws:iam::123:role/R" + + +def _base64url(data: bytes) -> str: + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _canned_jwt(sub: str = _ROLE_ARN) -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" + + +_CANNED_JWT = _canned_jwt() +_CANNED_EXPIRY = datetime.datetime( + 2099, + 4, + 21, + 6, + 6, + 47, + 641_000, + tzinfo=datetime.timezone.utc, +) + + +class FakeStsClient: + """Test double mirroring .NET's FakeStsClient. + + Records the last request kwargs and returns a canned response (or raises + a canned exception). Avoids the ceremony of ``botocore.stub.Stubber`` for + the simple cases here. + """ + + def __init__(self, responder=None, raises: Optional[BaseException] = None) -> None: + self.last_request: Optional[Dict[str, Any]] = None + self._responder = responder + self._raises = raises + + def get_web_identity_token(self, **kwargs: Any) -> Dict[str, Any]: + self.last_request = kwargs + if self._raises is not None: + raise self._raises + if self._responder is not None: + return self._responder(kwargs) + return _ok_response() + + +def _ok_response( + jwt: str = _CANNED_JWT, + expiration: datetime.datetime = _CANNED_EXPIRY, +) -> Dict[str, Any]: + return {"WebIdentityToken": jwt, "Expiration": expiration} + + +# ---- Constructor: checks ---- + + +def test_ctor_null_config_raises(): + with pytest.raises(TypeError): + AwsStsTokenProvider(None) + + +def test_ctor_valid_parsed_config_succeeds(): + """No AWS / no HTTP call at construction time (lazy credential chain).""" + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + # Does not throw; does not call AWS (lazy credential chain). + assert provider is not None + + +# ---- Constructor: aws_debug ---- + + +def test_ctor_no_aws_debug_does_not_mutate_botocore_logger(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + # Capture botocore logger's level before & after construction; without + # an explicit aws_debug=console, we must not touch boto3.set_stream_logger. + with patch("boto3.set_stream_logger") as mock_setter: + AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + mock_setter.assert_not_called() + + +def test_ctor_aws_debug_none_does_not_mutate_botocore_logger(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,aws_debug=none") + with patch("boto3.set_stream_logger") as mock_setter: + AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + mock_setter.assert_not_called() + + +def test_ctor_aws_debug_console_routes_botocore_logger_to_stream(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,aws_debug=console") + with patch("boto3.set_stream_logger") as mock_setter: + AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + mock_setter.assert_called_once_with("botocore", logging.DEBUG) + + +# ---- token(): request shape ---- + + +def test_token_audience_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://my.audience") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["Audience"] == ["https://my.audience"] + + +def test_token_signing_algorithm_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,signing_algorithm=RS256") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["SigningAlgorithm"] == "RS256" + + +def test_token_duration_seconds_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,duration_seconds=900") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["DurationSeconds"] == 900 + + +def test_token_default_duration_sends_300_seconds(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["DurationSeconds"] == 300 + + +def test_token_default_signing_algorithm_sends_es384(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["SigningAlgorithm"] == "ES384" + + +def test_token_tags_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a,tag_team=platform,tag_environment=prod") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + + tags = fake.last_request["Tags"] + assert len(tags) == 2 + assert {"Key": "team", "Value": "platform"} in tags + assert {"Key": "environment", "Value": "prod"} in tags + + +def test_token_no_tags_omits_tags_field_from_request(): + """When no tags configured, omit the Tags field rather than send empty list.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert "Tags" not in fake.last_request + + +# ---- token(): response mapping ---- + + +def test_token_returns_mapped_fields(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + jwt, expiry, principal, extensions = provider.token() + + assert jwt == _CANNED_JWT + assert principal == _ROLE_ARN + assert expiry == _CANNED_EXPIRY.timestamp() + assert extensions == {} + + +def test_token_sasl_extensions_passthrough(): + fake = FakeStsClient() + sasl_extensions = {"logicalCluster": "lkc-123", "identityPoolId": "pool-x"} + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1,audience=https://a", + sasl_extensions, + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, _, _, extensions = provider.token() + assert extensions == {"logicalCluster": "lkc-123", "identityPoolId": "pool-x"} + + +def test_token_no_extensions_configured_returns_empty_dict(): + """Python deviation from .NET: empty dict, not None, because the C + oauth_cb contract uses PyArg_ParseTuple "O!" with PyDict_Type at that + slot (rejects None). Empty dict is the Pythonic equivalent of the + .NET null-extensions case.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, _, _, extensions = provider.token() + assert extensions == {} + assert isinstance(extensions, dict) + + +def test_token_expiry_is_epoch_seconds_float(): + """The C wrapper expects expiry as epoch seconds (float). It multiplies + by 1000 internally to get milliseconds for librdkafka.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, expiry, _, _ = provider.token() + assert isinstance(expiry, float) + assert expiry == _CANNED_EXPIRY.timestamp() + + +# ---- token(): error propagation ---- + + +def test_token_missing_expiration_raises(): + """STS response without Expiration → ValueError surfaced as + rd_kafka_oauthbearer_set_token_failure by the C wrapper.""" + fake = FakeStsClient(responder=lambda kwargs: {"WebIdentityToken": _CANNED_JWT}) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ValueError, match="Expiration"): + provider.token() + + +def test_token_missing_token_value_raises(): + fake = FakeStsClient(responder=lambda kwargs: {"Expiration": _CANNED_EXPIRY}) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ValueError, match="WebIdentityToken"): + provider.token() + + +def test_token_malformed_jwt_raises_value_error(): + fake = FakeStsClient( + responder=lambda kwargs: { + "WebIdentityToken": "not-a-jwt", + "Expiration": _CANNED_EXPIRY, + } + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ValueError): + provider.token() + + +def test_token_sts_access_denied_propagates(): + fake = FakeStsClient( + raises=ClientError( + error_response={ + "Error": { + "Code": "AccessDenied", + "Message": "User is not authorized to perform: sts:GetWebIdentityToken", + } + }, + operation_name="GetWebIdentityToken", + ), + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ClientError) as exc_info: + provider.token() + assert exc_info.value.response["Error"]["Code"] == "AccessDenied" + + +def test_token_outbound_federation_disabled_propagates(): + fake = FakeStsClient( + raises=ClientError( + error_response={ + "Error": { + "Code": "OutboundWebIdentityFederationDisabledException", + "Message": "Outbound web identity federation is not enabled on this account.", + } + }, + operation_name="GetWebIdentityToken", + ), + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ClientError) as exc_info: + provider.token() + assert exc_info.value.response["Error"]["Code"] == "OutboundWebIdentityFederationDisabledException" + + +# ---- Surface invariants ---- + + +def test_aws_sts_token_provider_not_exposed_via_subpackage_init(): + """Public surface stays minimal — AwsStsTokenProvider is private.""" + import confluent_kafka.oauthbearer.aws as aws_pkg + + assert not hasattr(aws_pkg, "AwsStsTokenProvider") + + +# ---- Lazy credential resolution ---- + + +def test_ctor_does_not_invoke_sts(): + """Construction must not make any STS call (lazy credential chain). + Verify by checking the fake client received no requests after __init__ + but before any .token() invocation.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + AwsStsTokenProvider(cfg, sts_client=fake) + assert fake.last_request is None + + +# ---- sts_endpoint plumbed to boto3.client ---- + + +def test_ctor_sts_endpoint_plumbed_to_boto3_client(): + """When sts_endpoint is set on config, boto3.Session().client('sts', ...) + receives an endpoint_url kwarg.""" + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1,audience=https://a,sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" + ) + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + AwsStsTokenProvider(cfg) + mock_session.client.assert_called_once_with( + "sts", + region_name="us-east-1", + endpoint_url="https://sts-fips.us-east-1.amazonaws.com", + ) + + +def test_ctor_no_sts_endpoint_omits_endpoint_url_kwarg(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + AwsStsTokenProvider(cfg) + mock_session.client.assert_called_once_with( + "sts", + region_name="us-east-1", + ) + + +# ---- boto3 version floor check ---- + + +def test_version_tuple_parses_numeric_components(): + assert _version_tuple("1.42.25") == (1, 42, 25) + assert _version_tuple("1.42.97") == (1, 42, 97) + assert _version_tuple("2.0.0") == (2, 0, 0) + + +def test_version_tuple_tolerates_prerelease_suffix(): + assert _version_tuple("1.42.0rc1") == (1, 42, 0) + assert _version_tuple("1.42.25.dev0") == (1, 42, 25, 0) + + +def test_require_boto3_version_passes_on_installed_version(): + # The opt-in test venv installs boto3 >= the floor, so this must not raise. + _require_boto3_version() + + +def test_require_boto3_version_below_floor_raises(monkeypatch): + import boto3 + + monkeypatch.setattr(boto3, "__version__", "1.40.0") + with pytest.raises(ImportError, match="requires boto3>=") as exc_info: + _require_boto3_version() + # message names both the required floor and the offending installed version + assert MINIMUM_BOTO3_VERSION in str(exc_info.value) + assert "1.40.0" in str(exc_info.value) + + +def test_require_boto3_version_at_floor_ok(monkeypatch): + import boto3 + + monkeypatch.setattr(boto3, "__version__", MINIMUM_BOTO3_VERSION) + _require_boto3_version() # exactly the floor satisfies ">=" + + +def test_ctor_old_boto3_raises_before_building_client(monkeypatch): + """The floor check is wired into __init__ on the real-client path and fails + fast — before any boto3.Session/client is constructed.""" + import boto3 + + monkeypatch.setattr(boto3, "__version__", "1.40.0") + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + with patch("boto3.Session") as mock_session_cls: + with pytest.raises(ImportError, match="GetWebIdentityToken"): + AwsStsTokenProvider(cfg) # no sts_client → real path → check runs + mock_session_cls.assert_not_called() # raised before building a client + + +def test_ctor_injected_client_skips_version_check(monkeypatch): + """Injecting an sts_client (test seam) bypasses the boto3 floor check — the + real boto3 isn't used for the STS call in that case.""" + import boto3 + + monkeypatch.setattr(boto3, "__version__", "1.0.0") # ancient, but skipped + cfg = AwsOAuthBearerConfig.parse("region=us-east-1,audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + assert provider is not None + + +def test_minimum_boto3_version_matches_requirements_file(): + """Guard against MINIMUM_BOTO3_VERSION drifting from the pinned floor in + requirements/requirements-oauthbearer-aws.txt.""" + req = Path(__file__).resolve().parents[3] / "requirements" / "requirements-oauthbearer-aws.txt" + floor = None + for line in req.read_text().splitlines(): + stripped = line.strip() + if stripped.startswith("boto3") and ">=" in stripped: + floor = stripped.split(">=", 1)[1].strip() + break + assert floor == MINIMUM_BOTO3_VERSION, ( + f"requirements-oauthbearer-aws.txt pins boto3>={floor} but " + f"MINIMUM_BOTO3_VERSION is {MINIMUM_BOTO3_VERSION}; keep them in sync." + ) diff --git a/tests/oauthbearer/aws/test_contract.py b/tests/oauthbearer/aws/test_contract.py new file mode 100644 index 000000000..9d5a02472 --- /dev/null +++ b/tests/oauthbearer/aws/test_contract.py @@ -0,0 +1,135 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Frozen cross-module contract guard for the autowire entry-point. + +These tests guard against accidental drift: parameter names, type +annotations, return annotation, arity, and absence of defaults. +""" + +import inspect +from typing import Callable, Dict, Optional, Tuple + +import pytest + +pytest.importorskip("boto3") + +from confluent_kafka.oauthbearer.aws import aws_autowire # noqa: E402 +from confluent_kafka.oauthbearer.aws.aws_autowire import OAuthBearerCallback, create_handler # noqa: E402 + +# ---- Module surface ---- + + +def test_module_importable_at_canonical_path(): + """The C dispatcher does PyImport_ImportModule(...) with this exact path.""" + import importlib + + mod = importlib.import_module("confluent_kafka.oauthbearer.aws.aws_autowire") + assert mod is aws_autowire + + +def test_create_handler_is_module_level_callable(): + assert callable(aws_autowire.create_handler) + assert aws_autowire.create_handler is create_handler + + +def test_oauthbearer_callback_type_alias_exported(): + assert hasattr(aws_autowire, "OAuthBearerCallback") + assert aws_autowire.OAuthBearerCallback is OAuthBearerCallback + + +def test_module_all_lists_public_surface(): + """__all__ should advertise only the two public names.""" + assert set(aws_autowire.__all__) == {"create_handler", "OAuthBearerCallback"} + + +# ---- create_handler frozen signature ---- + + +def test_create_handler_arity_is_two(): + sig = inspect.signature(create_handler) + assert len(sig.parameters) == 2 + + +def test_create_handler_parameter_names_are_frozen(): + sig = inspect.signature(create_handler) + names = list(sig.parameters.keys()) + assert names == ["sasl_oauthbearer_config", "sasl_oauthbearer_extensions"] + + +def test_create_handler_parameter_annotations_are_frozen(): + sig = inspect.signature(create_handler) + assert sig.parameters["sasl_oauthbearer_config"].annotation is str + assert sig.parameters["sasl_oauthbearer_extensions"].annotation == Optional[str] + + +def test_create_handler_no_default_values(): + """Both parameters must be required positional — the C dispatcher always + passes two arguments (extensions may be the empty string or None, but + is always supplied).""" + sig = inspect.signature(create_handler) + assert sig.parameters["sasl_oauthbearer_config"].default is inspect.Parameter.empty + assert sig.parameters["sasl_oauthbearer_extensions"].default is inspect.Parameter.empty + + +def test_create_handler_return_annotation_is_oauthbearer_callback(): + sig = inspect.signature(create_handler) + assert sig.return_annotation is OAuthBearerCallback + + +def test_create_handler_parameters_accept_positional_arguments(): + """The C dispatcher calls via PyObject_CallFunction with positional args.""" + sig = inspect.signature(create_handler) + for name in ["sasl_oauthbearer_config", "sasl_oauthbearer_extensions"]: + kind = sig.parameters[name].kind + # POSITIONAL_OR_KEYWORD covers what PyObject_CallFunction provides. + assert kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + + +# ---- OAuthBearerCallback shape (the type returned by create_handler) ---- + + +def test_oauthbearer_callback_matches_c_oauth_cb_contract(): + """The C oauth_cb wrapper at confluent_kafka.c:2291 does: + + PyArg_ParseTuple(result, "sd|sO!", &token, &expiry, &principal, &PyDict_Type, &extensions) + + So the callable returned by create_handler must accept one string + argument (the sasl.oauthbearer.config pass-through) and return a tuple + of (str, float, str, Dict[str, str]). + """ + expected = Callable[[str], Tuple[str, float, str, Dict[str, str]]] + assert OAuthBearerCallback == expected + + +# ---- create_handler docstring sanity ---- + + +def test_create_handler_docstring_present(): + """The frozen cross-module contract is documented in the function's + docstring so it's visible to anyone reading the source. Block a future + contributor from accidentally stripping the docstring.""" + assert create_handler.__doc__ is not None + assert "frozen" in aws_autowire.__doc__.lower() + + +def test_create_handler_docstring_references_key_contract_terms(): + doc = create_handler.__doc__ or "" + # Names of the two arguments must appear so users grepping for them + # find the documentation. + assert "sasl.oauthbearer.config" in doc + assert "sasl.oauthbearer.extensions" in doc diff --git a/tests/oauthbearer/aws/test_dispatch.py b/tests/oauthbearer/aws/test_dispatch.py new file mode 100644 index 000000000..04e346883 --- /dev/null +++ b/tests/oauthbearer/aws/test_dispatch.py @@ -0,0 +1,469 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the C-extension dispatcher resolve_aws_oauthbearer_marker(). + +The dispatcher fires inside common_conf_setup() (src/confluent_kafka.c) on +every Producer / Consumer / AdminClient construction (and transitively +AIOProducer / AIOConsumer, which wrap the sync clients). + +These tests exercise the C dispatcher indirectly via client construction: +no direct path to the helper exists from Python — that's the design. +""" + +import base64 +import datetime +import sys +from typing import Any, Dict, Optional +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip("boto3") + +import confluent_kafka # noqa: E402 +from confluent_kafka.admin import AdminClient # noqa: E402 + +# ---- Test helpers ---- + + +def _base64url(data: bytes) -> str: + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _canned_jwt(sub: str = "arn:aws:iam::123:role/R") -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" + + +def _canned_response() -> Dict[str, Any]: + return { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime( + 2099, + 4, + 21, + 6, + 6, + 47, + tzinfo=datetime.timezone.utc, + ), + } + + +@pytest.fixture +def mocked_boto3(): + """Patch boto3.Session so the autowired provider's STS client is a mock. + The actual STS call doesn't happen at client construction (lazy), but the + Session/client construction does — without this, boto3 would try to + resolve real credentials.""" + with patch("boto3.Session") as session_cls: + session = session_cls.return_value + client = MagicMock() + client.get_web_identity_token.return_value = _canned_response() + session.client.return_value = client + yield client + + +def _minimal_aws_iam_config(extra: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + cfg = { + "bootstrap.servers": "broker.invalid:9092", + # security.protocol=SASL_SSL is required for librdkafka to actually + # engage the OAUTHBEARER refresh path. Without it, sasl.mechanisms is + # inert (librdkafka logs CONFWARN and the refresh_cb never fires). + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + } + if extra: + cfg.update(extra) + return cfg + + +# ---- 1. Marker absent → dispatcher is a no-op (regression check) ---- + + +def test_marker_absent_producer_constructs_unchanged(): + p = confluent_kafka.Producer({"bootstrap.servers": "localhost:9092"}) + p.flush(timeout=0.1) + + +def test_marker_absent_consumer_constructs_unchanged(): + c = confluent_kafka.Consumer( + { + "bootstrap.servers": "localhost:9092", + "group.id": "test-group", + } + ) + c.close() + + +def test_marker_absent_admin_client_constructs_unchanged(): + a = AdminClient({"bootstrap.servers": "localhost:9092"}) + assert a is not None + + +def test_marker_absent_does_not_import_aws_modules(): + """If the marker is absent the dispatcher must NOT touch the optional + subpackage — boto3 import must not happen.""" + # Clear any prior imports of the AWS subpackage so we can detect a fresh import. + for name in list(sys.modules): + if name.startswith("confluent_kafka.oauthbearer.aws"): + del sys.modules[name] + confluent_kafka.Producer({"bootstrap.servers": "localhost:9092"}) + for name in [ + "confluent_kafka.oauthbearer.aws.aws_autowire", + "confluent_kafka.oauthbearer.aws._aws_sts_token_provider", + ]: + assert name not in sys.modules, ( + f"Dispatcher imported {name} when no marker was set — " "this means the no-op short-circuit broke." + ) + + +# ---- 2. Other marker values pass through verbatim ---- + + +def test_other_marker_value_passes_through_unchanged(): + """azure_imds / unknown values are not our concern — librdkafka handles them.""" + # azure_imds is a librdkafka-recognised value; we expect Producer construction + # to NOT raise our ValueError. (librdkafka itself may complain about the + # value, but our dispatcher must not.) + with pytest.raises(Exception) as exc_info: + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "azure_imds", + } + ) + # Whatever librdkafka raises, it must NOT be our method-requirement + # error or our config-requirement error. + msg = str(exc_info.value) + assert "AWS IAM" not in msg + assert "aws_iam" not in msg + + +# ---- 3. method=oidc requirement ---- + + +def test_marker_without_method_raises(): + with pytest.raises(ValueError, match="method=oidc"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + } + ) + + +def test_marker_with_method_default_raises(): + with pytest.raises(ValueError, match="method=oidc"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "default", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + } + ) + + +def test_marker_with_method_oidc_uppercase_raises(): + """Strict matching — librdkafka's method values are lowercase canonical.""" + with pytest.raises(ValueError, match="method=oidc"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "OIDC", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + } + ) + + +# ---- 4. sasl.oauthbearer.config requirement ---- + + +def test_marker_with_method_oidc_but_missing_config_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.config.*missing or empty"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + } + ) + + +def test_marker_with_method_oidc_but_empty_config_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.config.*missing or empty"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "", + } + ) + + +# ---- 5. Happy path: marker + method=oidc + valid config + boto3 stubbed ---- + + +def test_happy_path_producer_constructs_with_marker(mocked_boto3): + p = confluent_kafka.Producer(_minimal_aws_iam_config()) + assert p is not None + p.flush(timeout=0.1) + + +def test_happy_path_consumer_constructs_with_marker(mocked_boto3): + c = confluent_kafka.Consumer(_minimal_aws_iam_config({"group.id": "test-group"})) + assert c is not None + c.close() + + +def test_happy_path_admin_client_constructs_with_marker(mocked_boto3): + a = AdminClient(_minimal_aws_iam_config()) + assert a is not None + + +async def test_happy_path_aio_producer_constructs_with_marker(mocked_boto3): + """AIO wraps sync — same dispatcher fires via the inner cimpl.Producer. + + AIOProducer.__init__ calls asyncio.get_running_loop(), so this test + must run inside an async context (pyproject.toml's asyncio_mode=auto + handles that for async def test functions).""" + from confluent_kafka.aio import AIOProducer + + p = AIOProducer(_minimal_aws_iam_config()) + assert p is not None + + +async def test_happy_path_aio_consumer_constructs_with_marker(mocked_boto3): + """Same async-context requirement as AIOProducer.""" + from confluent_kafka.aio import AIOConsumer + + c = AIOConsumer(_minimal_aws_iam_config({"group.id": "test-group"})) + assert c is not None + + +# ---- 6. Explicit oauth_cb wins (precedence rule) ---- + + +def test_explicit_oauth_cb_wins_over_marker(): + """When the user supplies their own oauth_cb, the dispatcher skips autowire + entirely and leaves the config untouched (their callback + the marker + + method all pass through). boto3 is NOT touched.""" + sentinel_called = [] + + def user_oauth_cb(config_str): + sentinel_called.append(config_str) + return ("user-token", 9999999999.0, "user-principal", {}) + + # Clear sys.modules so we can verify aws_autowire was NOT imported. + for name in list(sys.modules): + if name.startswith("confluent_kafka.oauthbearer.aws"): + del sys.modules[name] + + p = confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + "oauth_cb": user_oauth_cb, + } + ) + assert p is not None + p.flush(timeout=0.1) + # The autowire module must not have been imported. + assert "confluent_kafka.oauthbearer.aws.aws_autowire" not in sys.modules + + +# ---- 7. Parser errors surface from the autowire path ---- + + +def test_marker_with_invalid_config_grammar_raises(mocked_boto3): + """Config parser ValueError surfaces through the dispatcher.""" + with pytest.raises(ValueError, match="Unknown key.*not_a_key"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a,not_a_key=foo", + } + ) + + +def test_marker_with_invalid_signing_algorithm_raises(mocked_boto3): + with pytest.raises(ValueError, match="signing_algorithm"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a,signing_algorithm=HS256", + } + ) + + +def test_marker_with_invalid_extensions_grammar_raises(mocked_boto3): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + "sasl.oauthbearer.extensions": "malformed-no-equals", + } + ) + + +# ---- 8. Friendly ImportError when the optional extra is missing ---- + + +@pytest.fixture +def boto3_absent(monkeypatch): + """Simulates an opt-out environment: boto3 import fails + relevant + aws.* submodules cleared from sys.modules so the C dispatcher's + PyImport_ImportModule re-executes the module body and hits the + boto3=None gate.""" + # Force any subsequent `import boto3` to raise ImportError. + monkeypatch.setitem(sys.modules, "boto3", None) + # Clear the cached aws submodules so PyImport_ImportModule re-executes + # their top-level statements (including `import boto3`). + for name in list(sys.modules): + if name.startswith("confluent_kafka.oauthbearer.aws"): + monkeypatch.delitem(sys.modules, name, raising=False) + yield + # monkeypatch reverts on teardown. + + +def test_marker_with_missing_extra_raises_friendly_import_error(boto3_absent): + """When boto3 isn't available, the dispatcher catches the + ModuleNotFoundError from the import chain and rewrites it into a + friendly install hint. __cause__ chain preserves the original.""" + with pytest.raises(ImportError) as exc_info: + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + } + ) + msg = str(exc_info.value) + assert "oauthbearer-aws" in msg + assert "pip install" in msg + assert "aws_iam" in msg + # __cause__ preserves the original failure for diagnostic tools. + assert exc_info.value.__cause__ is not None + + +def test_friendly_import_error_on_consumer_too(boto3_absent): + with pytest.raises(ImportError, match="oauthbearer-aws"): + confluent_kafka.Consumer( + { + "bootstrap.servers": "broker.invalid:9092", + "group.id": "g", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + } + ) + + +def test_friendly_import_error_on_admin_client_too(boto3_absent): + with pytest.raises(ImportError, match="oauthbearer-aws"): + AdminClient( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1,audience=https://a", + } + ) + + +# ---- 9. Marker is RETAINED and passed through to the AWS-IAM librdkafka ---- + + +def test_marker_retained_and_token_refresh_fires(mocked_boto3): + """Keep-marker contract: the dispatcher leaves + sasl.oauthbearer.metadata.authentication.type=aws_iam AND + sasl.oauthbearer.method=oidc in the config and only registers the autowired + oauth_cb — no marker strip, no method rewrite.""" + import time + + p = confluent_kafka.Producer(_minimal_aws_iam_config()) + assert p is not None + deadline = time.time() + 10 + while mocked_boto3.get_web_identity_token.call_count == 0 and time.time() < deadline: + p.poll(0.5) + assert mocked_boto3.get_web_identity_token.call_count >= 1, ( + "librdkafka never invoked the autowired oauth_cb — the keep-marker " + "refresh path did not engage (handle creation or marker pass-through failed)" + ) + p.flush(timeout=0.1) + + +# ---- 10. Extensions plumbing through the dispatcher ---- + + +def test_marker_with_extensions_plumbs_through(mocked_boto3): + """The dispatcher reads sasl.oauthbearer.extensions and passes it as the + second arg to create_handler. Verify by checking the returned callable + yields the extensions in its 4-tuple result.""" + p = confluent_kafka.Producer( + _minimal_aws_iam_config( + { + "sasl.oauthbearer.extensions": "logicalCluster=lkc-123", + } + ) + ) + assert p is not None + p.flush(timeout=0.1) + + +def test_marker_with_empty_extensions_treated_as_absent(mocked_boto3): + """Empty extensions string → autowire treats as None (no extensions).""" + p = confluent_kafka.Producer( + _minimal_aws_iam_config( + { + "sasl.oauthbearer.extensions": "", + } + ) + ) + assert p is not None + p.flush(timeout=0.1) diff --git a/tests/oauthbearer/aws/test_jwt_extractor.py b/tests/oauthbearer/aws/test_jwt_extractor.py new file mode 100644 index 000000000..8d52c9f36 --- /dev/null +++ b/tests/oauthbearer/aws/test_jwt_extractor.py @@ -0,0 +1,238 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._jwt_extractor.""" + +import base64 + +import pytest + +from confluent_kafka.oauthbearer.aws._jwt_extractor import extract_sub + +# ---- Test helpers ---- + + +def _base64url_encode(data: bytes) -> str: + """Base64url-encodes a byte array: standard base64, trim '=' padding, + swap '+' → '-' and '/' → '_'. + """ + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _make_jwt(payload_json: str, alg: str = "ES384") -> str: + """Build a 3-segment JWT with the given payload JSON. + + Header and signature segments are placeholders — ``extract_sub`` only + decodes the payload (``parts[1]``). + """ + header = _base64url_encode(f'{{"alg":"{alg}","typ":"JWT"}}'.encode("utf-8")) + payload = _base64url_encode(payload_json.encode("utf-8")) + return f"{header}.{payload}.sig" + + +# ---- Real STS payloads (from .NET test suite — signature stripped, only payload matters) ---- + + +_EXPECTED_ROLE_ARN = "arn:aws:iam::708975691912:role/prashah-iam-sts-test-role" + +_REAL_ES384_JWT = ( + "eyJraWQiOiJFQzM4NF8wIiwidHlwIjoiSldUIiwiYWxnIjoiRVMzODQifQ" + ".eyJhdWQiOiJodHRwczovL2FwaS5leGFtcGxlLmNvbSIsInN1YiI6ImFybjphd3M6aWFtOjo3MDg5N" + "zU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtdGVzdC1yb2xlIiwiaHR0cHM6Ly9zdHMuYW1hem9u" + "YXdzLmNvbS8iOnsiZWMyX2luc3RhbmNlX3NvdXJjZV92cGMiOiJ2cGMtYWQ4NzMzYzQiLCJlYzJfcm9" + "sZV9kZWxpdmVyeSI6IjIuMCIsIm9yZ19pZCI6Im8tMHgzdDh1bW9seiIsImF3c19hY2NvdW50IjoiNz" + "A4OTc1NjkxOTEyIiwib3VfcGF0aCI6WyJvLTB4M3Q4dW1vbHovci16YzVqL291LXpjNWotNXJwd3pqb" + "HIvIl0sIm9yaWdpbmFsX3Nlc3Npb25fZXhwIjoiMjAyNi0wNS0xNFQyMzoxNjozMloiLCJzb3VyY2Vf" + "cmVnaW9uIjoiZXUtbm9ydGgtMSIsImVjMl9zb3VyY2VfaW5zdGFuY2VfYXJuIjoiYXJuOmF3czplYzI" + "6ZXUtbm9ydGgtMTo3MDg5NzU2OTE5MTI6aW5zdGFuY2UvaS0wOTc5MGY5OTY4YzExYTFjNyIsInByaW" + "5jaXBhbF9pZCI6ImFybjphd3M6aWFtOjo3MDg5NzU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtd" + "GVzdC1yb2xlIiwicHJpbmNpcGFsX3RhZ3MiOnsiZGl2dnlfb3duZXIiOiJwcmFzaGFoQGNvbmZsdWVu" + "dC5pbyIsImRpdnZ5X2xhc3RfbW9kaWZpZWRfYnkiOiJwcmFzaGFoQGNvbmZsdWVudC5pbyJ9LCJlYzJ" + "faW5zdGFuY2Vfc291cmNlX3ByaXZhdGVfaXB2NCI6IjE3Mi4zMS4zLjEwOSJ9LCJpc3MiOiJodHRwcz" + "ovL2ExZWJjNzA1LWNkNGQtNDJiNC05M2I1LTk2ZTkzYWNmYjQzMS50b2tlbnMuc3RzLmdsb2JhbC5hc" + "GkuYXdzIiwiZXhwIjoxNzc4Nzc4NzY2LCJpYXQiOjE3Nzg3Nzg0NjYsImp0aSI6IjFkM2QzZmMyLTBl" + "NzktNDI2OS05NjcxLTJmODQ4NDYxOWZiNyJ9" + ".sig" +) + +_REAL_RS256_JWT = ( + "eyJraWQiOiJSU0FfMCIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0" + ".eyJhdWQiOiJodHRwczovL2FwaS5leGFtcGxlLmNvbSIsInN1YiI6ImFybjphd3M6aWFtOjo3MDg5N" + "zU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtdGVzdC1yb2xlIiwiaHR0cHM6Ly9zdHMuYW1hem9u" + "YXdzLmNvbS8iOnsiZWMyX2luc3RhbmNlX3NvdXJjZV92cGMiOiJ2cGMtYWQ4NzMzYzQiLCJlYzJfcm9" + "sZV9kZWxpdmVyeSI6IjIuMCIsIm9yZ19pZCI6Im8tMHgzdDh1bW9seiIsImF3c19hY2NvdW50IjoiNz" + "A4OTc1NjkxOTEyIiwib3VfcGF0aCI6WyJvLTB4M3Q4dW1vbHovci16YzVqL291LXpjNWotNXJwd3pqb" + "HIvIl0sIm9yaWdpbmFsX3Nlc3Npb25fZXhwIjoiMjAyNi0wNS0xNFQyMzoxNjozMloiLCJzb3VyY2Vf" + "cmVnaW9uIjoiZXUtbm9ydGgtMSIsImVjMl9zb3VyY2VfaW5zdGFuY2VfYXJuIjoiYXJuOmF3czplYzI" + "6ZXUtbm9ydGgtMTo3MDg5NzU2OTE5MTI6aW5zdGFuY2UvaS0wOTc5MGY5OTY4YzExYTFjNyIsInByaW" + "5jaXBhbF9pZCI6ImFybjphd3M6aWFtOjo3MDg5NzU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtd" + "GVzdC1yb2xlIiwicHJpbmNpcGFsX3RhZ3MiOnsiZGl2dnlfb3duZXIiOiJwcmFzaGFoQGNvbmZsdWVu" + "dC5pbyIsImRpdnZ5X2xhc3RfbW9kaWZpZWRfYnkiOiJwcmFzaGFoQGNvbmZsdWVudC5pbyJ9LCJlYzJ" + "faW5zdGFuY2Vfc291cmNlX3ByaXZhdGVfaXB2NCI6IjE3Mi4zMS4zLjEwOSJ9LCJpc3MiOiJodHRwcz" + "ovL2ExZWJjNzA1LWNkNGQtNDJiNC05M2I1LTk2ZTkzYWNmYjQzMS50b2tlbnMuc3RzLmdsb2JhbC5hc" + "GkuYXdzIiwiZXhwIjoxNzc4Nzc4NzY2LCJpYXQiOjE3Nzg3Nzg0NjYsImp0aSI6IjU5OTliZDc1LTBm" + "NTctNDMzMS1iOGExLWJkYzYwMjgxN2UyZiJ9" + ".sig" +) + + +# ---- Happy paths ---- + + +def test_extract_sub_role_arn_returned(): + jwt = _make_jwt('{"sub":"arn:aws:iam::123456789012:role/MyRole","iat":1}') + assert extract_sub(jwt) == "arn:aws:iam::123456789012:role/MyRole" + + +def test_extract_sub_assumed_role_arn_returned(): + jwt = _make_jwt('{"sub":"arn:aws:sts::123456789012:assumed-role/MyRole/session-name"}') + assert extract_sub(jwt) == "arn:aws:sts::123456789012:assumed-role/MyRole/session-name" + + +def test_extract_sub_other_claims_ignored(): + jwt = _make_jwt('{"iss":"https://x","sub":"arn:aws:iam::1:role/R",' '"aud":"a","exp":1,"iat":0,"jti":"j"}') + assert extract_sub(jwt) == "arn:aws:iam::1:role/R" + + +@pytest.mark.parametrize("jwt", [_REAL_ES384_JWT, _REAL_RS256_JWT]) +def test_extract_sub_real_sts_jwt_returns_expected_arn(jwt): + assert extract_sub(jwt) == _EXPECTED_ROLE_ARN + + +# ---- Base64url padding branches (% 4 ∈ {0, 2, 3}) ---- + + +def test_extract_sub_unpadded_base64url_works(): + unpadded = _make_jwt('{"sub":"a"}') + assert extract_sub(unpadded) == "a" + + +def test_extract_sub_padded_base64url_also_works(): + """If user pre-pads with '=' (non-canonical but tolerated), still decodes.""" + header = _base64url_encode(b'{"alg":"none"}') + # Encode WITHOUT stripping padding. + payload = base64.b64encode(b'{"sub":"abc"}').decode("ascii").replace("+", "-").replace("/", "_") + jwt_3seg = f"{header}.{payload}.sig" + assert extract_sub(jwt_3seg) == "abc" + + +def test_extract_sub_url_safe_chars_handled(): + bytes_ = b'{"sub":"x"}' + normal = base64.b64encode(bytes_).decode("ascii") + url_safe = normal.replace("+", "-").replace("/", "_").rstrip("=") + jwt = "aGVhZGVy." + url_safe + ".c2ln" + assert extract_sub(jwt) == "x" + + +@pytest.mark.parametrize( + "payload_json,expected_sub", + [ + ('{"sub":"a"}', "a"), # → encoded length % 4 = 3 (one '=' padded) + ('{"sub":"ab"}', "ab"), # → encoded length % 4 = 0 (no padding) + ('{"sub":"abc"}', "abc"), # → encoded length % 4 = 2 (two '==' padded) + ], +) +def test_extract_sub_padding_branches_all_hit_decodes_correctly(payload_json, expected_sub): + """Explicit per-branch coverage of the % 4 padding switch.""" + assert extract_sub(_make_jwt(payload_json)) == expected_sub + + +# ---- Top-level shape errors ---- + + +def test_extract_sub_null_raises(): + with pytest.raises(ValueError, match="null"): + extract_sub(None) + + +def test_extract_sub_empty_raises(): + with pytest.raises(ValueError, match="empty"): + extract_sub("") + + +def test_extract_sub_one_segment_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("onlyonepart") + + +def test_extract_sub_two_segments_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("a.b") + + +def test_extract_sub_four_segments_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("a.b.c.d") + + +def test_extract_sub_oversized_input_raises(): + oversized = "a" * 8193 + with pytest.raises(ValueError, match="exceeds maximum"): + extract_sub(oversized) + + +def test_extract_sub_at_ceiling_reaches_parser(): + """8192-char input passes the size gate and fails downstream on segment count.""" + at_ceiling = "a" * 8192 + with pytest.raises(ValueError, match="3"): + extract_sub(at_ceiling) + + +def test_extract_sub_empty_payload_segment_raises(): + with pytest.raises(ValueError, match="empty"): + extract_sub("header..sig") + + +# ---- Payload-content errors ---- + + +def test_extract_sub_malformed_base64_in_payload_raises(): + with pytest.raises(ValueError, match="base64url"): + extract_sub("aGVhZGVy.not!base64.c2ln") + + +def test_extract_sub_malformed_json_in_payload_raises(): + bad_payload = _base64url_encode(b"not json") + with pytest.raises(ValueError, match="not valid JSON"): + extract_sub(f"aGVhZGVy.{bad_payload}.c2ln") + + +def test_extract_sub_payload_is_json_array_raises(): + array_payload = _base64url_encode(b'["not","an","object"]') + with pytest.raises(ValueError, match="not a JSON object"): + extract_sub(f"aGVhZGVy.{array_payload}.c2ln") + + +def test_extract_sub_missing_sub_claim_raises(): + jwt = _make_jwt('{"iss":"https://x","aud":"a"}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_number_raises(): + jwt = _make_jwt('{"sub":12345}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_null_raises(): + jwt = _make_jwt('{"sub":null}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_empty_string_raises(): + jwt = _make_jwt('{"sub":""}') + with pytest.raises(ValueError, match="empty"): + extract_sub(jwt) diff --git a/tests/oauthbearer/aws/test_real_sts.py b/tests/oauthbearer/aws/test_real_sts.py new file mode 100644 index 000000000..d47cea3e9 --- /dev/null +++ b/tests/oauthbearer/aws/test_real_sts.py @@ -0,0 +1,376 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Phase 6 — real-AWS integration tests for the autowire path. + +Gated on ``RUN_AWS_STS_REAL=1``. Default ``pytest tests/oauthbearer/aws`` +runs every other Phase 1-5 case but skips this file's tests so CI doesn't +hit the AWS API. + +Why this file lives under ``tests/oauthbearer/aws/`` rather than +``tests/integration/``: ``tests/integration/conftest.py`` eagerly imports +``trivup.clusters.KafkaCluster``, which is a heavyweight broker-only +dependency. The STS provider has no broker dependency — the env-var gate +is sufficient isolation. Same gate-placement call as the 29-April POC. + +Cross-language invariant: on the shared EC2 test box +(``ktrue-iam-sts-test-role`` in ``eu-north-1``, account ``708975691912``, +audience ``https://api.example.com``) Go / .NET / JS / librdkafka all +mint a **1256-byte JWT** with principal +``arn:aws:iam::708975691912:role/ktrue-iam-sts-test-role``. The Python +port preserves that invariant. + +Run instructions:: + + # On EC2 with role attached, audience-trust-policy enabled: + export RUN_AWS_STS_REAL=1 + export AWS_STS_TEST_REGION=eu-north-1 # AWS_REGION also accepted (fallback) + export AWS_STS_TEST_AUDIENCE=https://api.example.com + /tmp/ckp-optin/bin/pytest -v -s tests/oauthbearer/aws/test_real_sts.py + +The ``-s`` flag preserves the diagnostic ``print(...)`` lines that capture +the JWT length, principal, and timing — important for cross-language +parity evidence in PR descriptions. +""" + +import base64 +import json +import os +import re +import time + +import pytest + +# boto3 is loaded transitively via the autowire path. Skip the whole file +# in opt-out venvs so collection doesn't blow up. +pytest.importorskip("boto3") + +# Gate the entire module behind RUN_AWS_STS_REAL=1 so CI doesn't accidentally +# hit STS. Individual @pytest.mark.skipif decorators would work too but the +# module-level pytestmark is cleaner and less error-prone. +pytestmark = pytest.mark.skipif( + os.environ.get("RUN_AWS_STS_REAL") != "1", + reason="Set RUN_AWS_STS_REAL=1 and provide AWS credentials to run.", +) + +from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler # noqa: E402 + +# ============================================================================= +# Configuration sourced from env vars so the test can be re-pointed at a +# different role / audience without code changes. +# ============================================================================= + +_REGION = os.environ.get("AWS_STS_TEST_REGION") or os.environ.get("AWS_REGION", "eu-north-1") +_AUDIENCE = os.environ.get("AWS_STS_TEST_AUDIENCE", "https://api.example.com") +_DURATION = os.environ.get("DURATION_SECONDS", "300") + +# Cross-language invariant: 1256 bytes on the shared test role + audience. +# Allow a tolerance band for variation across audiences (URL length affects +# payload size). Catches order-of-magnitude bugs without flaking on tiny +# audience-string changes. +_JWT_LENGTH_MIN = int(os.environ.get("JWT_LENGTH_MIN", "1100")) +_JWT_LENGTH_MAX = int(os.environ.get("JWT_LENGTH_MAX", "1500")) + +_JWT_PATTERN = re.compile(r"^[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+$") +_ARN_PATTERN = re.compile(r"^arn:aws:(?:iam|sts)::\d+:[^/]+/.+$") + + +def _config_string(**parts) -> str: + """Render a sasl.oauthbearer.config wire string from key=value parts.""" + return ",".join(f"{k}={v}" for k, v in parts.items()) + + +def _default_config(**extra) -> str: + """Build a sasl.oauthbearer.config string with the defaults plus extras.""" + parts = {"region": _REGION, "audience": _AUDIENCE, "duration_seconds": _DURATION} + parts.update(extra) + return _config_string(**parts) + + +def _decode_jwt_payload(jwt: str) -> dict: + """Decode the JWT payload (middle segment) into a Python dict. + + Helper for tests that need to inspect specific claims (tags, audience, + algorithm) end-to-end. + """ + payload_b64url = jwt.split(".")[1] + # base64url → base64 with padding + padding = "=" * (-len(payload_b64url) % 4) + standard_b64 = (payload_b64url + padding).replace("-", "+").replace("_", "/") + return json.loads(base64.b64decode(standard_b64).decode("utf-8")) + + +# ============================================================================= +# create_handler() — the cross-module entry-point contract +# ============================================================================= + + +def test_create_handler_mints_valid_jwt(): + handler = create_handler(_default_config(), None) + token, expiry, principal, extensions = handler("") + + # 3-segment base64url JWT. + assert _JWT_PATTERN.match(token), f"not a 3-segment JWT: {token[:60]}..." + + # Length within cross-language tolerance. + assert _JWT_LENGTH_MIN <= len(token) <= _JWT_LENGTH_MAX, ( + f"JWT length {len(token)} outside [{_JWT_LENGTH_MIN}, " + f"{_JWT_LENGTH_MAX}] (Go/.NET/JS/librdkafka observe 1256 on the shared " + f"test box). Set JWT_LENGTH_MIN/MAX env vars if running on a different " + f"role/audience." + ) + + # Expiry within the requested window. + now = time.time() + assert expiry > now, "expiry must be in the future" + assert expiry < now + 10 * 60, "expiry must be within 10 min for DurationSeconds <= 300" + + # Principal is a bare role ARN (AWS STS GetWebIdentityToken convention). + assert _ARN_PATTERN.match(principal), f"unexpected principal: {principal!r}" + + # Default config has no extensions. + assert extensions == {} + + # Diagnostic for the run log so EC2 evidence is captured in CI / PR. + print(f"\n[autowire-real] jwt_length={len(token)} " f"principal={principal} expires_in={int(expiry - now)}s") + + +def test_create_handler_jwt_length_matches_cross_language(): + """Cross-language byte-exact JWT length parity check. + + Different roles produce different JWT lengths because AWS STS bakes + role-attached ``principal_tags`` into the payload (e.g. provisioning- + system tags like ``divvy_owner`` add ~70 bytes). The 1256-byte length + seen on Go/.NET/JS/librdkafka in earlier validation was specific to + one tag-less role. + + This test only runs when ``EXPECTED_TOKEN_LEN`` is explicitly set — + its purpose is to confirm byte-for-byte parity against a baseline you've + measured by running any sibling client (Go / .NET / JS) against the + same role + audience. + """ + expected_str = os.environ.get("EXPECTED_TOKEN_LEN") + if not expected_str: + pytest.skip( + "EXPECTED_TOKEN_LEN not set — cross-language byte-exact parity " + "check only runs when you supply a known-good baseline. Run a " + "sibling client (Go/.NET/JS) against the same role+audience, " + "note the JWT length, and set EXPECTED_TOKEN_LEN to that value." + ) + + expected = int(expected_str) + handler = create_handler(_default_config(), None) + token, *_ = handler("") + assert len(token) == expected, f"expected {expected} bytes (cross-language match), got {len(token)}" + + +def test_create_handler_principal_matches_role_arn(): + """JWT 'sub' claim is the bare role ARN. AWS STS GetWebIdentityToken + convention (no session prefix on bare role tokens).""" + handler = create_handler(_default_config(), None) + _, _, principal, _ = handler("") + assert principal.startswith("arn:aws:iam::"), f"not an IAM role ARN: {principal!r}" + assert ":role/" in principal, f"not a role ARN: {principal!r}" + + +def test_create_handler_repeats_mint_distinct_tokens(): + """Provider does not cache; STS mints a fresh JWT each invocation.""" + handler = create_handler(_default_config(), None) + _, exp1, _, _ = handler("") + time.sleep(1.0) + _, exp2, _, _ = handler("") + assert exp2 > exp1, "second token's expiry must be later than the first" + + +def test_create_handler_honours_duration_seconds(): + """Custom duration_seconds=60 produces an expiry ~1 min out, not the 300s + default. + + The shared test role's IAM policy caps DurationSeconds at 300 (per .NET + PR validation §2g and the JS hybrid plan). So we go *below* the default + rather than above to prove the parameter is honored end-to-end. Same + value JS uses for the same reason. + """ + handler = create_handler(_default_config(duration_seconds="60"), None) + _, expiry, _, _ = handler("") + seconds_remaining = expiry - time.time() + assert 50 <= seconds_remaining <= 80, f"duration_seconds=60 → expected ~60s window, got {seconds_remaining:.0f}s" + + +def test_create_handler_honours_signing_algorithm_rs256(): + """signing_algorithm=RS256 produces a JWT with alg=RS256 in the header. + + Cross-language parity check: the .NET test asserts the same header.alg + round-trip. Catches any SDK-side algorithm-shaping surprises. + """ + handler = create_handler(_default_config(signing_algorithm="RS256"), None) + token, *_ = handler("") + # Decode the JWT header (first segment). + header_b64url = token.split(".")[0] + padding = "=" * (-len(header_b64url) % 4) + header_json = json.loads( + base64.b64decode((header_b64url + padding).replace("-", "+").replace("_", "/")).decode("utf-8") + ) + assert header_json.get("alg") == "RS256", f"header.alg expected RS256, got {header_json.get('alg')!r}" + + +def test_create_handler_round_trips_sasl_extensions(): + """The typed sasl.oauthbearer.extensions property is parsed separately + (comma-separated) and surfaces in the returned extensions dict. + + Note: extensions are forwarded to the broker — not to STS — so they + don't appear in the JWT. This test asserts the dispatcher-to-tuple + round-trip, not the JWT payload.""" + handler = create_handler( + _default_config(), + "logicalCluster=lkc-abc,identityPoolId=pool-xyz", + ) + _, _, _, extensions = handler("") + assert extensions == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-xyz", + } + + +def test_create_handler_tag_claims_flow_to_sts(): + """tag_= in the wire grammar flows into the STS Tags + parameter without breaking the request. + + The primary evidence is **token minting succeeding** — if our Tags + parameter caused STS to reject the request (e.g. malformed shape or + over the 50-tag cap), we'd see a ClientError exception. Token returned + cleanly means dispatcher → provider → boto3 → AWS STS handoff worked. + + Whether AWS STS surfaces our custom tags in the JWT payload's + ``principal_tags`` claim is a SEPARATE question depending on the role's + IAM trust policy — specifically whether ``sts:TagSession`` is allowed. + Without that permission, STS silently drops user tags from the JWT. + The diagnostic prints below show what STS actually embedded so users + can spot the case where tags are silently dropped. + + Set ``ASSERT_TAGS_IN_JWT=1`` to require that our tags appear in the + JWT — only meaningful on roles with ``sts:TagSession`` allowed. + """ + handler = create_handler( + _default_config(tag_team="platform", tag_environment="prod"), + None, + ) + # If our Tags param breaks the STS call, this raises. + token, *_ = handler("") + payload = _decode_jwt_payload(token) + + # AWS STS surfaces tags either under "principal_tags" at the JWT root + # OR nested under "https://sts.amazonaws.com/" → "principal_tags". + found_tags = {} + if "principal_tags" in payload: + found_tags.update(payload["principal_tags"]) + for nested_value in payload.values(): + if isinstance(nested_value, dict) and "principal_tags" in nested_value: + found_tags.update(nested_value["principal_tags"]) + + our_tags_present = "team" in found_tags and "environment" in found_tags + print(f"\n[autowire-real] JWT payload keys: {list(payload.keys())}") + print(f"[autowire-real] tags surfaced in JWT: {found_tags}") + print( + f"[autowire-real] our injected tags appear in JWT: {our_tags_present}" + + ("" if our_tags_present else " (role likely missing sts:TagSession permission)") + ) + + # Token minted successfully — Tags param didn't cause STS rejection. + assert len(token) > 100, "Token should be a valid JWT" + + # Strict assertion only when the user explicitly opts in via env var. + if os.environ.get("ASSERT_TAGS_IN_JWT") == "1": + assert our_tags_present, ( + f"ASSERT_TAGS_IN_JWT=1 set but our tags aren't in the JWT. " + f"Role likely missing sts:TagSession permission. " + f"Found tags: {found_tags}" + ) + + +def test_create_handler_jwt_audience_matches_request(): + """The JWT payload's 'aud' claim matches the audience we asked AWS for. + + Defensive — catches any audience-shaping surprises (AWS should not munge + the audience string). Mirrors the .NET integration test's assertion. + """ + handler = create_handler(_default_config(), None) + token, *_ = handler("") + payload = _decode_jwt_payload(token) + assert payload.get("aud") == _AUDIENCE, f"JWT aud {payload.get('aud')!r} doesn't match requested {_AUDIENCE!r}" + + +# ============================================================================= +# Full dispatcher flow — Producer with marker +# ============================================================================= + + +def test_producer_with_marker_succeeds_against_real_sts(): + """End-to-end: construct a Producer with the marker set; the C dispatcher + in common_conf_setup() invokes create_handler() which builds a real + provider; librdkafka's background thread invokes the autowired callback; + STS mints a real token; rd_kafka_oauthbearer_set_token marks the token + as set; wait_for_oauth_token_set returns successfully and the Producer + constructor returns. + + The bootstrap.servers points at a non-resolvable host so we don't waste + a real broker connection — the OAUTHBEARER refresh path doesn't need + one. It fires immediately on background-thread startup. + """ + from confluent_kafka import Producer + + t0 = time.time() + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": _default_config(), + } + ) + elapsed = time.time() - t0 + + # Producer construction returns once the autowire callback has set a + # real token (typically <2s on EC2). wait_for_oauth_token_set timeout + # is 10s — anything over that means the autowire chain didn't fire. + assert elapsed < 5.0, ( + f"Producer construction took {elapsed:.1f}s — autowire path slow " + f"or broken (10s timeout = OAuth refresh never set the token)" + ) + print(f"\n[autowire-real] Producer constructed via dispatcher in {elapsed:.2f}s") + + +def test_producer_with_marker_and_extensions_succeeds(): + """Same end-to-end path with sasl.oauthbearer.extensions set on the + typed property. Proves the C dispatcher reads the extensions string + and passes it as the second arg to create_handler.""" + from confluent_kafka import Producer + + t0 = time.time() + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": _default_config(), + "sasl.oauthbearer.extensions": "logicalCluster=lkc-abc", + } + ) + elapsed = time.time() - t0 + assert elapsed < 5.0, f"Producer with extensions took {elapsed:.1f}s — autowire path " f"slow or broken" + print(f"\n[autowire-real] Producer (with extensions) constructed via " f"dispatcher in {elapsed:.2f}s") diff --git a/tests/oauthbearer/aws/test_sasl_extensions_parser.py b/tests/oauthbearer/aws/test_sasl_extensions_parser.py new file mode 100644 index 000000000..923463c5e --- /dev/null +++ b/tests/oauthbearer/aws/test_sasl_extensions_parser.py @@ -0,0 +1,86 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._sasl_extensions_parser.""" + +import pytest + +from confluent_kafka.oauthbearer.aws._sasl_extensions_parser import parse + +# ---- Null / empty input → None ---- + + +def test_null_raw_returns_none(): + assert parse(None) is None + + +def test_empty_string_returns_none(): + assert parse("") is None + + +# ---- Happy paths ---- + + +def test_single_entry_returns_one_item(): + result = parse("logicalCluster=lkc-abc") + assert result == {"logicalCluster": "lkc-abc"} + + +def test_multiple_entries_returns_all(): + result = parse("logicalCluster=lkc-abc,identityPoolId=pool-x") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_whitespace_around_commas_trimmed_and_parsed(): + # Each comma-delimited entry is trimmed before parsing. + result = parse(" logicalCluster=lkc-abc , identityPoolId=pool-x ") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_empty_entries_tolerated(): + result = parse("logicalCluster=lkc-abc,,identityPoolId=pool-x,") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_empty_value_accepted(): + # RFC 7628 SASL extensions allow empty values; mirror that. + result = parse("logicalCluster=") + assert result == {"logicalCluster": ""} + + +def test_duplicate_key_last_wins(): + result = parse("k=a,k=b") + assert result == {"k": "b"} + + +# ---- Malformed input → throws ---- + + +def test_missing_equals_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + parse("noEqualsHere") + + +def test_empty_key_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + parse("=value") diff --git a/tests/soak/setup_all_versions.py b/tests/soak/setup_all_versions.py index 61c4dcc11..d4d0da533 100755 --- a/tests/soak/setup_all_versions.py +++ b/tests/soak/setup_all_versions.py @@ -5,7 +5,7 @@ PYTHON_SOAK_TEST_BRANCH = 'master' LIBRDKAFKA_VERSIONS = [ - '2.14.2', + '2.14.2-aws-iam-dev', '2.13.2', '2.13.0', '2.12.1', @@ -22,7 +22,7 @@ ] PYTHON_VERSIONS = [ - '2.14.2', + '2.14.2.dev5', '2.13.2', '2.13.0', '2.12.1', diff --git a/tools/source-package-verification.sh b/tools/source-package-verification.sh index 713153f5b..68304550d 100755 --- a/tools/source-package-verification.sh +++ b/tools/source-package-verification.sh @@ -77,7 +77,7 @@ if [[ $OS_NAME == linux && $ARCH == x64 ]]; then python3 -c " import importlib.metadata, sys extras = set(importlib.metadata.metadata('confluent-kafka').get_all('Provides-Extra') or []) -required = {'schema-registry', 'schemaregistry', 'avro', 'json', 'protobuf', 'rules'} +required = {'schema-registry', 'schemaregistry', 'avro', 'json', 'protobuf', 'rules', 'oauthbearer-aws'} missing = required - extras if missing: print(f'Failing: package does not provide extras: {missing}', file=sys.stderr)