Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 194 additions & 5 deletions packages/asgardeo-ai/src/asgardeo_ai/agent_auth_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
Copyright (c) 2025-2026, WSO2 LLC. (https://www.wso2.com).
WSO2 LLC. licenses this file to you under the Apache License,
Version 2.0 (the "License"); you may not use this file except
in compliance with the License.
Expand All @@ -20,7 +20,7 @@
import base64
import os
import time
from typing import Callable, Dict, List, Optional, Tuple, Any
from typing import Callable, Dict, List, Literal, Optional, Tuple, Any
from urllib.parse import urlencode
from dataclasses import dataclass

Expand All @@ -43,6 +43,8 @@

logger = logging.getLogger(__name__)

OrgDiscoveryType = Literal["orgID", "orgHandle", "org", "emailDomain"]


@dataclass
class AgentConfig:
Expand Down Expand Up @@ -144,6 +146,35 @@ async def get_agent_token(self, scopes: Optional[List[str]] = None) -> OAuthToke
logger.error(f"Agent authentication failed: {e}")
raise AuthenticationError(f"Agent authentication failed: {e}")

async def get_organization_agent_token(
self,
switching_organization: str,
agent_scopes: Optional[List[str]] = None,
org_scopes: Optional[List[str]] = None
) -> OAuthToken:
"""Get access token for the AI agent and switch it to a sub-organization.

:param switching_organization: The ID or UUID of the target organization.
:param agent_scopes: Optional list of OAuth2 scopes to request for the initial agent token.
:param org_scopes: Optional list of OAuth2 scopes to request for the switched token.
:return: OAuth2 token for the switched organization.
"""
if not switching_organization:
raise ValidationError("switching_organization is required.")

# 1. Get agent token.
agent_token = await self.get_agent_token(scopes=agent_scopes)

if not agent_token or not agent_token.access_token:
raise TokenError("Failed to obtain a valid agent access token.")

# 2. Switch token to organization.
return await self.switch_token_to_organization(
token=agent_token.access_token,
switching_organization=switching_organization,
scopes=org_scopes
)

def get_authorization_url(
self,
scopes: List[str],
Expand Down Expand Up @@ -228,6 +259,129 @@ def get_authorization_url_with_pkce(
)
return auth_url, state, code_verifier

def _build_org_discovery_params(self, org_discovery_type: OrgDiscoveryType, discovery_value: str) -> dict:
match org_discovery_type:
case "orgID":
return {"orgId": discovery_value}
case "orgHandle":
return {"orgHandle": discovery_value}
case "org":
return {"org": discovery_value}
case "emailDomain":
return {"login_hint": discovery_value, "orgDiscoveryType": "emailDomain"}
case _:
raise ValidationError(f"Unsupported org_discovery_type: {org_discovery_type}")
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

def get_org_authorization_url(
self,
scopes: List[str],
org_discovery_type: OrgDiscoveryType,
discovery_value: str,
Comment thread
HasiniSama marked this conversation as resolved.
Outdated
state: Optional[str] = None,
resource: Optional[str] = None,
isEnhancedOrgAuth: Optional[bool] = False,
**kwargs: Any,
) -> Tuple[str, str]:
"""Generate authorization URL for organization-specific user authentication.

:param scopes: List of OAuth2 scopes to request
:param org_discovery_type: The type of organization discovery ('orgID', 'orgHandle', 'org', 'emailDomain')
:param discovery_value: The identifier whose meaning depends on ``org_discovery_type``:
``"orgID"`` → organization UUID, ``"orgHandle"`` → org handle slug,
``"org"`` → org name, ``"emailDomain"`` → user email address used as login hint.
:param state: Optional state parameter (generated if not provided)
:param resource: Optional resource parameter
:param isEnhancedOrgAuth: If true, omits the fidp=OrganizationSSO parameter
:param kwargs: Additional parameters for the authorization URL
:return: Tuple of (authorization_url, state)
"""
if not state:
state = generate_state()

auth_params = {
"client_id": self.config.client_id,
"redirect_uri": self.config.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"response_type": "code",
}

if not isEnhancedOrgAuth:
auth_params["fidp"] = "OrganizationSSO"

auth_params.update(self._build_org_discovery_params(org_discovery_type, discovery_value))

if resource:
auth_params["resource"] = resource

if self.agent_config:
auth_params["requested_actor"] = self.agent_config.agent_id

auth_params.update(kwargs)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

auth_url = build_authorization_url(
f"{self.config.base_url}/oauth2/authorize",
auth_params
)
return auth_url, state

def get_org_authorization_url_with_pkce(
self,
scopes: List[str],
org_discovery_type: OrgDiscoveryType,
discovery_value: str,
state: Optional[str] = None,
resource: Optional[str] = None,
isEnhancedOrgAuth: Optional[bool] = False,
**kwargs: Any,
) -> Tuple[str, str, str]:
"""Generate authorization URL for organization-specific user authentication with PKCE.

:param scopes: List of OAuth2 scopes to request
:param org_discovery_type: The type of organization discovery ('orgID', 'orgHandle', 'org', 'emailDomain')
:param discovery_value: The identifier whose meaning depends on ``org_discovery_type``:
``"orgID"`` → organization UUID, ``"orgHandle"`` → org handle slug,
``"org"`` → org name, ``"emailDomain"`` → user email address used as login hint.
:param state: Optional state parameter (generated if not provided)
:param resource: Optional resource parameter
:param isEnhancedOrgAuth: If true, omits the fidp=OrganizationSSO parameter
:param kwargs: Additional parameters for the authorization URL
:return: Tuple of (authorization_url, state, code_verifier)
"""
if not state:
state = generate_state()

code_verifier, code_challenge = generate_pkce_pair()

auth_params = {
"client_id": self.config.client_id,
"redirect_uri": self.config.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"response_type": "code",
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}

if not isEnhancedOrgAuth:
auth_params["fidp"] = "OrganizationSSO"

auth_params.update(self._build_org_discovery_params(org_discovery_type, discovery_value))

if resource:
auth_params["resource"] = resource

if self.agent_config:
auth_params["requested_actor"] = self.agent_config.agent_id

auth_params.update(kwargs)

auth_url = build_authorization_url(
f"{self.config.base_url}/oauth2/authorize",
auth_params
)
return auth_url, state, code_verifier

async def get_obo_token(
self,
auth_code: str,
Expand Down Expand Up @@ -339,7 +493,7 @@ async def get_obo_token_with_ciba(

:param login_hint: Username or identifier of the user to authenticate
:param agent_token: The agent's OAuthToken (used as actor_token for delegation)
:param scopes: List of OAuth scopes to request
:param scopes: List of OAuth2 scopes to request
:param binding_message: Message displayed to the user during authentication
:param notification_channel: Notification channel (email, sms, external)
:param timeout: Maximum time to wait for authentication in seconds
Expand Down Expand Up @@ -385,8 +539,43 @@ async def get_obo_token_with_ciba(
except (CIBAAuthenticationError, ValidationError):
raise
except Exception as e:
logger.error(f"CIBA OBO token exchange failed: {e}")
raise TokenError(f"CIBA OBO token exchange failed: {e}")
logger.error(f"CIBA OBO token exchange failed: {e}", exc_info=True)
raise TokenError(f"CIBA OBO token exchange failed: {e}") from e

async def switch_token_to_organization(
self,
token: str,
switching_organization: str,
scopes: Optional[List[str]] = None
) -> OAuthToken:
"""Switch token to a sub-organization.

:param token: The current access token to be switched.
:param switching_organization: The ID or UUID of the target organization.
:param scopes: Optional list of scopes to request.
:return: OAuth2 token for the switched organization.
"""
if not token:
raise ValidationError("Token is required for organization switch.")
if not switching_organization:
raise ValidationError("switching_organization is required.")

scope_str = ' '.join(scopes) if scopes else None

try:
switched_token = await self.token_client.get_token(
'organization_switch',
token=token,
switching_organization=switching_organization,
scope=scope_str
)
return switched_token

except (TokenError, ValidationError):
raise
except Exception as e:
logger.error(f"Organization switch failed: {e}", exc_info=True)
raise TokenError(f"Organization switch failed: {e}") from e

async def revoke_token(
self,
Expand Down
14 changes: 13 additions & 1 deletion packages/asgardeo/src/asgardeo/auth/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

"""
Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
Copyright (c) 2025-2026, WSO2 LLC. (https://www.wso2.com).
WSO2 LLC. licenses this file to you under the Apache License,
Version 2.0 (the "License"); you may not use this file except
in compliance with the License.
Expand Down Expand Up @@ -327,6 +327,18 @@ async def get_token(self, grant_type: str, **kwargs: Any) -> OAuthToken:
scope = kwargs.get("scope")
if scope:
data["scope"] = scope
elif grant_type == "organization_switch":
token = kwargs.get("token")
switching_organization = kwargs.get("switching_organization")
if not token or not switching_organization:
raise ValidationError(
"token and switching_organization are required for 'organization_switch' grant type.",
)
data["token"] = token
data["switching_organization"] = switching_organization
scope = kwargs.get("scope")
if scope:
data["scope"] = scope
else:
raise ValidationError(f"Unsupported grant type: {grant_type}")

Expand Down
Loading