diff --git a/ayon_api/_api.py b/ayon_api/_api.py index 42eac9827..4f0bf7f59 100644 --- a/ayon_api/_api.py +++ b/ayon_api/_api.py @@ -130,10 +130,10 @@ def login(self, username: str, password: str): login is skipped. """ - previous_token = self._access_token + previous_token = self._token_info.token super().login(username, password) - if self.has_valid_token and previous_token != self._access_token: - os.environ[SERVER_API_ENV_KEY] = self._access_token + if self.has_valid_token and previous_token != self._token_info.token: + os.environ[SERVER_API_ENV_KEY] = self._token_info.token @staticmethod def get_url(): diff --git a/ayon_api/server_api.py b/ayon_api/server_api.py index 7b60d68a7..ee16d8e86 100644 --- a/ayon_api/server_api.py +++ b/ayon_api/server_api.py @@ -6,6 +6,7 @@ from __future__ import annotations import copy +from dataclasses import dataclass import os import re import io @@ -67,6 +68,7 @@ get_media_mime_type_for_stream, get_machine_name, fill_own_attribs, + get_user_info_by_token, ) from ._api_helpers import ( InstallersAPI, @@ -187,9 +189,8 @@ def as_user(self, username: Optional[str]) -> Generator[None, None, None]: user_id = uuid.uuid4().hex self._user_ids.append(user_id) self._users_by_id[user_id] = username - try: - yield - finally: + + def _cleanup(): self._users_by_id.pop(user_id, None) if not self._user_ids: return @@ -208,6 +209,19 @@ def as_user(self, username: Optional[str]) -> Generator[None, None, None]: new_last_user = self._users_by_id.get(self._user_ids[-1]) self._last_user = new_last_user + try: + yield + finally: + _cleanup() + + +@dataclass +class TokenInfo: + token: str | None = None + is_valid: bool | None = None + is_service: bool | None = None + unauthorized_response: requests.Response | None = None + class ServerAPI( InstallersAPI, @@ -290,7 +304,7 @@ def __init__( self._rest_url: str = f"{base_url}/api" self._graphql_url: str = f"{base_url}/graphql" self._log: logging.Logger = logging.getLogger(self.__class__.__name__) - self._access_token: Optional[str] = token + # Allow to have 'site_id' to 'None' if site_id is NOT_SET: site_id = get_default_site_id() @@ -322,9 +336,8 @@ def __init__( self._ssl_verify = ssl_verify self._cert = cert - self._access_token_is_service = None - self._token_is_valid = None - self._token_validation_started = False + self._token_info = TokenInfo(token=token) + self._server_available = None self._server_version = None self._server_version_tuple = None @@ -351,7 +364,7 @@ def __init__( self._as_user_stack = _AsUserStack() # Create session - if self._access_token and create_session: + if self._token_info.token and create_session: self.validate_server_availability() self.create_session() @@ -498,7 +511,7 @@ def access_token(self) -> Optional[str]: Optional[str]: Token string or None if not authorized yet. """ - return self._access_token + return self._token_info.token def is_service_user(self) -> bool: """Check if connection is using service API key. @@ -509,7 +522,7 @@ def is_service_user(self) -> bool: """ if not self.has_valid_token: raise ValueError("User is not logged in.") - return bool(self._access_token_is_service) + return bool(self._token_info.is_service) def get_site_id(self) -> Optional[str]: """Site id used for connection. @@ -678,7 +691,7 @@ def set_default_service_username(self, username: Optional[str] = None): "Authentication of connection did not happen yet." ) - if not self._access_token_is_service: + if not self._token_info.is_service: raise ValueError( "Can't set service username. API key is not a service token." ) @@ -712,7 +725,7 @@ def as_username( "Authentication of connection did not happen yet." ) - if not self._access_token_is_service: + if not self._token_info.is_service: if ignore_service_error: yield None return @@ -733,19 +746,20 @@ def is_server_available(self) -> bool: response = requests.get( self._base_url, cert=self._cert, - verify=self._ssl_verify + verify=self._ssl_verify, + timeout=self.timeout, ) self._server_available = response.status_code == 200 return self._server_available @property def has_valid_token(self) -> bool: - if self._access_token is None: + if self._token_info.token is None: return False - if self._token_is_valid is None: + if self._token_info.is_valid is None: self.validate_token() - return self._token_is_valid + return self._token_info.is_valid def validate_server_availability(self): if not self.is_server_available: @@ -754,30 +768,45 @@ def validate_server_availability(self): ) def validate_token(self) -> bool: - try: - self._token_validation_started = True - # TODO add other possible validations - # - existence of 'user' key in info - # - validate that 'site_id' is in 'sites' in info - self.get_info() - self.get_user() - self._token_is_valid = True + if self._token_info.token is None: + self._token_info.is_valid = False + self._token_info.unauthorized_response = None + self.close_session() + return False - except UnauthorizedError: - self._token_is_valid = False + # TODO add other possible validations + # - existence of 'user' key in info + # - validate that 'site_id' is in 'sites' in info - finally: - self._token_validation_started = False - return self._token_is_valid + # Check server url + self._get_server_info() + + user_info = get_user_info_by_token( + self.base_url, + self._token_info.token, + verify=self._ssl_verify, + cert=self._cert, + timeout=self.timeout, + ) + self._token_info.is_valid = user_info.is_valid + self._token_info.unauthorized_response = user_info.response + is_service = None + if user_info.is_valid: + is_service = user_info.is_service + self._token_info.is_service = is_service + + return self._token_info.is_valid def set_token(self, token: Optional[str]): self.reset_token() - self._access_token = token - self.get_user() + self._token_info.token = token + self.validate_token() def reset_token(self): - self._access_token = None - self._token_is_valid = None + self._token_info.token = None + self._token_info.is_service = None + self._token_info.is_valid = None + self._token_info.unauthorized_response = None self.close_session() def create_session( @@ -861,7 +890,15 @@ def get_info(self) -> dict[str, Any]: dict[str, Any]: Information from server. """ - response = self.get("info") + handle_invalid_token = ( + self._token_info.token + and self._token_info.is_valid + ) + + response = self.raw_get( + "info", + handle_invalid_token=handle_invalid_token, + ) response.raise_for_status() return response.data @@ -934,29 +971,6 @@ def links_graphql_support_data(self) -> bool: ) return self._links_graphql_support_data - def _get_user_info(self) -> Optional[dict[str, Any]]: - if self._access_token is None: - return None - - if self._access_token_is_service is not None: - response = self.get("users/me") - if response.status == 200: - return response.data - return None - - self._access_token_is_service = False - response = self.get("users/me") - if response.status == 200: - return response.data - - self._access_token_is_service = True - response = self.get("users/me") - if response.status == 200: - return response.data - - self._access_token_is_service = None - return None - def get_users( self, project_name: Optional[str] = None, @@ -1129,14 +1143,14 @@ def get_headers( if self._sender is not None: headers["x-sender"] = self._sender - if self._access_token: - if self._access_token_is_service: - headers["X-Api-Key"] = self._access_token + if self._token_info.token and self._token_info.is_valid is not False: + if self._token_info.is_service: + headers["X-Api-Key"] = self._token_info.token username = self._as_user_stack.username if username: headers["X-as-user"] = username else: - headers["Authorization"] = f"Bearer {self._access_token}" + headers["Authorization"] = f"Bearer {self._token_info.token}" return headers def login( @@ -1147,7 +1161,7 @@ def login( Args: username (str): Username. password (str): Password. - create_session (Optional[bool]): Create session after login. + create_session (bool): Create session after login. Default: True. Raises: @@ -1171,26 +1185,27 @@ def login( self.validate_server_availability() - self._token_validation_started = True - - try: - response = self.post( - "auth/login", + response = self.raw_post( + "auth/login", + json=dict( name=username, - password=password - ) - if response.status_code != 200: - _detail = response.data.get("detail") - details = "" - if _detail: - details = f" {_detail}" + password=password, + ), + handle_invalid_token=False, + ) + if response.status_code != 200: + _detail = response.data.get("detail") + details = "" + if _detail: + details = f" {_detail}" - raise AuthenticationError(f"Login failed {details}") + raise AuthenticationError(f"Login failed {details}") - finally: - self._token_validation_started = False - - self._access_token = response["token"] + self._token_info.token = response["token"] + # Should be valid if was just logged in + self._token_info.is_valid = True + # Service token can't be obtained by login, so it is not service token + self._token_info.is_service = False if not self.has_valid_token: raise AuthenticationError("Invalid credentials") @@ -1199,108 +1214,11 @@ def login( self.create_session() def logout(self, soft: bool = False): - if self._access_token: + if self._token_info.token: if not soft: self._logout() self.reset_token() - def _logout(self): - logout_from_server(self._base_url, self._access_token) - - def _do_rest_request(self, function, url, **kwargs): - kwargs.setdefault("timeout", self.timeout) - max_retries = kwargs.get("max_retries", self.max_retries) - if max_retries < 1: - max_retries = 1 - if self._session is None: - # Validate token if was not yet validated - # - ignore validation if we're in middle of - # validation - if ( - self._token_is_valid is None - and not self._token_validation_started - ): - self.validate_token() - - if "headers" not in kwargs: - kwargs["headers"] = self.get_headers() - - if isinstance(function, RequestType): - function = self._base_functions_mapping[function] - - elif isinstance(function, RequestType): - function = self._session_functions_mapping[function] - - response = None - new_response = None - for retry_idx in reversed(range(max_retries)): - try: - response = function(url, **kwargs) - - # Usually these mean, try later. - # 502: returned by the proxy: nginx - # 503: returned by the server: if no capacity - if response.status_code in {502, 503}: - new_response = RestApiResponse(response) - self.log.warning( - "Server returned %s status code." - " Retrying with longer delay...", - response.status_code - ) - if retry_idx != 0: - time.sleep(2) - continue - break - - except ConnectionRefusedError: - if retry_idx == 0: - self.log.warning( - "Connection error happened.", exc_info=True - ) - - # Server may be restarting - new_response = RestApiResponse( - None, - { - "detail": ( - "Unable to connect the server. Connection refused" - ) - } - ) - - except requests.exceptions.Timeout: - # Connection timed out - new_response = RestApiResponse( - None, - {"detail": "Connection timed out."} - ) - - except requests.exceptions.ConnectionError: - # Log warning only on last attempt - if retry_idx == 0: - self.log.warning( - "Connection error happened.", exc_info=True - ) - - new_response = RestApiResponse( - None, - { - "detail": ( - "Unable to connect the server. Connection error" - ) - } - ) - - if retry_idx != 0: - time.sleep(0.1) - - if new_response is not None: - return new_response - - new_response = RestApiResponse(response) - self.log.debug(f"Response {str(new_response)}") - return new_response - def raw_post(self, entrypoint: str, **kwargs): url = self._endpoint_to_url(entrypoint) self.log.debug(f"Executing [POST] {url}") @@ -1555,6 +1473,164 @@ def _endpoint_to_url( base_url = self._rest_url if use_rest else self._base_url return f"{base_url}/{endpoint}" + def _logout(self): + if self._token_info.is_valid: + logout_from_server(self._base_url, self._token_info.token) + + def _get_server_info(self) -> dict[str, Any]: + """Get server info without a session.""" + response = self.raw_get( + "info", + handle_invalid_token=False, + ) + response.raise_for_status() + return response.data + + def _get_user_info(self) -> Optional[dict[str, Any]]: + if ( + self._token_info.token is None + or self._token_info.is_valid is False + ): + return None + + if self._token_info.is_service is None: + self.validate_token() + if self._token_info.is_valid is False: + return None + + response = self.get("users/me") + if response.status == 200: + return response.data + return None + + def _do_rest_request( + self, + function: Any, + url: str, + *, + handle_invalid_token: bool = True, + **kwargs + ): + kwargs.setdefault("timeout", self.timeout) + max_retries = kwargs.get("max_retries", self.max_retries) + if max_retries < 1: + max_retries = 1 + + if handle_invalid_token and self._token_info.is_valid is False: + # Return a fake error response if the token is known to be invalid. + # Added to prevent DDOS attack on server when many requests + # with invalid token are send. It is better to return error + # immediately without trying to send a request to server. + if self._token_info.unauthorized_response is not None: + return RestApiResponse(self._token_info.unauthorized_response) + + new_response = RestApiResponse( + None, + {"code": 401, "detail": "AYON api error: Invalid API key"} + ) + new_response.status = 401 + return new_response + + if self._session is None: + # Validate token if was not yet validated + if ( + handle_invalid_token + and self._token_info.is_valid is None + ): + self.validate_token() + + if "headers" not in kwargs: + kwargs["headers"] = self.get_headers() + + if isinstance(function, RequestType): + function = self._base_functions_mapping[function] + + elif isinstance(function, RequestType): + function = self._session_functions_mapping[function] + + response = None + new_response = None + for retry_idx in reversed(range(max_retries)): + try: + response = function(url, **kwargs) + + # Usually these mean, try later. + # 502: returned by the proxy: nginx + # 503: returned by the server: if no capacity + if response.status_code in {502, 503}: + new_response = RestApiResponse(response) + self.log.warning( + "Server returned %s status code." + " Retrying with longer delay...", + response.status_code + ) + if retry_idx != 0: + time.sleep(2) + continue + break + + except ConnectionRefusedError: + if retry_idx == 0: + self.log.warning( + "AYON api error: Connection error happened.", + exc_info=True, + ) + + # Server may be restarting + new_response = RestApiResponse( + None, + { + "detail": ( + "AYON api error: Unable to connect the server." + " Connection refused" + ) + } + ) + + except requests.exceptions.Timeout: + # Connection timed out + new_response = RestApiResponse( + None, + {"detail": "AYON api error: Connection timed out."} + ) + + except requests.exceptions.ConnectionError: + # Log warning only on last attempt + if retry_idx == 0: + self.log.warning( + "AYON api error: Connection error happened.", + exc_info=True + ) + + new_response = RestApiResponse( + None, + { + "detail": ( + "AYON api error: Unable to connect the server." + " Connection error." + ) + } + ) + + if retry_idx != 0: + time.sleep(0.1) + + if new_response is not None: + return new_response + + new_response = RestApiResponse(response) + if ( + handle_invalid_token + and new_response.status_code == 401 + and self._token_info.is_valid + ): + self._token_info.is_valid = False + self._token_info.unauthorized_response = response + self.close_session() + + self.log.debug(f"Response {str(new_response)}") + return new_response + def _download_file_to_stream( self, endpoint: str, diff --git a/ayon_api/utils.py b/ayon_api/utils.py index d2cb88570..ab2e4c7dd 100644 --- a/ayon_api/utils.py +++ b/ayon_api/utils.py @@ -1,8 +1,10 @@ from __future__ import annotations +import functools import os import re import datetime +from dataclasses import dataclass import copy import logging import json @@ -15,6 +17,7 @@ from urllib.parse import urlparse, urlencode, ParseResult import typing from typing import Any, Iterable +import warnings from enum import IntEnum import requests @@ -67,6 +70,38 @@ ) ) +@dataclass +class _TimeoutWrapInfo: + func = None + args_pos = 2 + + +def _timeout_kwarg_deprecation(arg): + """Decorator to add timeout kwarg to function.""" + # TODO remove this deprecation + wrap_info = _TimeoutWrapInfo() + + def wrapper(*args, **kwargs): + if len(args) > wrap_info.args_pos: + warnings.warn( + "Timeout was passed as a positional argument please" + " use timeout=... keyword argument instead. This will stop" + " working in future versions on ayon-api.", + category=FutureWarning, + stacklevel=2, + ) + return wrap_info.func(*args, **kwargs) + + if not isinstance(arg, int): + wrap_info.func = arg + return functools.wraps(arg)(wrapper) + + wrap_info.args_pos = arg + def main_wrapper(func): + wrap_info.func = func + return functools.wraps(func)(wrapper) + return main_wrapper + class SortOrder(IntEnum): """Sort order for GraphQl requests.""" @@ -172,11 +207,11 @@ def ok(self) -> bool: def raise_for_status(self, message=None): if self._response is None: if self._data and self._data.get("detail"): + if self.status_code == 401: + raise UnauthorizedError(self._data["detail"]) raise ServerError(self._data["detail"]) raise ValueError("Response is not available.") - if self.status_code == 401: - raise UnauthorizedError("Missing or invalid authentication token") try: self._response.raise_for_status() except requests.exceptions.HTTPError as exc: @@ -197,6 +232,8 @@ def raise_for_status(self, message=None): detail = self.data.get("detail") if detail: message = f"{message} ({detail})" + if self.status_code == 401: + raise UnauthorizedError(message, exc.response) raise HTTPRequestError(message, exc.response) def __enter__(self, *args, **kwargs): @@ -587,6 +624,7 @@ def _try_connect_to_server( return None +@_timeout_kwarg_deprecation(3) def login_to_server( url: str, username: str, @@ -628,6 +666,7 @@ def login_to_server( return token +@_timeout_kwarg_deprecation def logout_from_server( url: str, token: str, @@ -655,32 +694,57 @@ def logout_from_server( ) -def get_user_by_token( +@dataclass +class UserInfo: + """User information.""" + is_valid: bool = False + is_service: bool = False + response: requests.Response | None = None + + +def get_user_info_by_token( url: str, token: str, + *, + verify: str | bool | None = None, + cert: str | None = None, timeout: float | None = None, -) -> dict[str, Any] | None: +) -> UserInfo: """Get user information by url and token. Args: url (str): Server url. token (str): User's token. + verify (str | bool | None): SSL verification for request. Value from + 'AYON_CA_FILE' environment variable is used if not specified. + cert (str | None): SSL certificate for request. Value from + 'AYON_CERT_FILE' environment variable is used if not specified. timeout (float | None): Timeout for request. Value from 'get_default_timeout' is used if not specified. Returns: - dict[str, Any] | None: User information if url and token are valid. + UserInfo: User information if url and token are valid. """ + output = UserInfo() + if not token: + return output + if timeout is None: timeout = get_default_timeout() + if verify is None: + verify = os.environ.get("AYON_CA_FILE") or True + + if cert is None: + cert = os.environ.get("AYON_CERT_FILE") or None + base_headers = { "Content-Type": "application/json", } - for header_value in ( - {"Authorization": f"Bearer {token}"}, - {"X-Api-Key": token}, + for header_value, is_service in ( + ({"Authorization": f"Bearer {token}"}, False), + ({"X-Api-Key": token}, True), ): headers = base_headers.copy() headers.update(header_value) @@ -688,16 +752,61 @@ def get_user_by_token( f"{url}/api/users/me", headers=headers, timeout=timeout, + verify=verify, + cert=cert, + ) + + output = UserInfo( + is_valid=response.status_code == 200, + is_service=is_service, + response=response, ) - if response.status_code == 200: - return response.json() + if output.is_valid: + break + return output + + +@_timeout_kwarg_deprecation +def get_user_by_token( + url: str, + token: str, + timeout: float | None = None, + *, + verify: str | bool | None = None, + cert: str | None = None, +) -> dict[str, Any] | None: + """Get user information by url and token. + + Args: + url (str): Server url. + token (str): User's token. + timeout (float | None): Timeout for request. Value from + 'get_default_timeout' is used if not specified. + verify (str | bool | None): SSL verification for request. Value from + 'AYON_CA_FILE' environment variable is used if not specified. + cert (str | None): SSL certificate for request. Value from + 'AYON_CERT_FILE' environment variable is used if not specified. + + Returns: + dict[str, Any] | None: User information if url and token are valid. + + """ + user_info = get_user_info_by_token( + url, token, timeout=timeout, verify=verify, cert=cert, + ) + if user_info.is_valid: + return user_info.data return None +@_timeout_kwarg_deprecation def is_token_valid( url: str, token: str, timeout: float | None = None, + *, + verify: str | bool | None = None, + cert: str | None = None, ) -> bool: """Check if token is valid. @@ -708,16 +817,22 @@ def is_token_valid( token (str): User's token. timeout (float | None): Timeout for request. Value from 'get_default_timeout' is used if not specified. + verify (str | bool | None): SSL verification for request. Value from + 'AYON_CA_FILE' environment variable is used if not specified. + cert (str | None): SSL certificate for request. Value from + 'AYON_CERT_FILE' environment variable is used if not specified. Returns: bool: True if token is valid. """ - if get_user_by_token(url, token, timeout=timeout): - return True - return False + user_info = get_user_info_by_token( + url, token, timeout=timeout, verify=verify, cert=cert + ) + return user_info.is_valid +@_timeout_kwarg_deprecation(1) def validate_url( url: str, timeout: int | None = None,