Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions diracx-core/src/diracx/core/config/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import asyncio
import logging
import os
from abc import ABCMeta, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
Expand Down Expand Up @@ -162,7 +161,12 @@ def __init_subclass__(cls) -> None:

@classmethod
def create(cls):
return cls.create_from_url(backend_url=os.environ["DIRACX_CONFIG_BACKEND_URL"])
# Avoid circular import
from diracx.core.settings import FactorySettings

return cls.create_from_url(
backend_url=FactorySettings().diracx_config_backend_url
)

@classmethod
def create_from_url(
Expand Down
105 changes: 105 additions & 0 deletions diracx-core/src/diracx/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

import contextlib
import json
import os
from collections.abc import AsyncIterator
from pathlib import Path
from typing import Annotated, Any, Self, TypeVar, cast

import dotenv
from cryptography.fernet import Fernet
from joserfc.jwk import KeySet, KeySetSerialization
from pydantic import (
Expand All @@ -29,14 +31,18 @@
SecretStr,
TypeAdapter,
UrlConstraints,
field_validator,
model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
from signurlarity.aio.client import AsyncClient
from signurlarity.exceptions import SignurlarityError

from .config.sources import ConfigSourceUrl
from .extensions import DiracEntryPoint, select_from_extension
from .properties import SecurityProperty
from .s3 import s3_bucket_exists
from .utils import dotenv_files_from_environment

T = TypeVar("T")

Expand Down Expand Up @@ -350,3 +356,102 @@ def s3_client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError("S3 client accessed before lifetime function")
return self._client


def _parse_env_bool(value: str) -> bool:
"""Parse a boolean environment variable value."""
return TypeAdapter(bool).validate_python(value)


class FactorySettings(ServiceSettingsBase):
"""Factory settings.

Settings which do not fit into dedicated classes,
or are dynamically generated.
"""

model_config = SettingsConfigDict(use_attribute_docstrings=True)

diracx_config_backend_url: ConfigSourceUrl | None = None
"""The URL of the configuration backend.
"""

diracx_legacy_exchange_hashed_api_key: str = ""
"""The hashed API key for the legacy exchange endpoint.
"""

diracx_tasks_redis_url: str = "redis://localhost"
"""The url for the redis server to manage tasks"""

enabled_services: dict[str, bool] = Field(default_factory=dict)
"""The following environment variables dictates which routers are enabled."""

opensearch_dbs: dict[str, str] = Field(default_factory=dict)
"""The following environment variables configure the OpenSearch database connections."""

sql_dbs: dict[str, str] = Field(default_factory=dict)
"""The following environment variables configure the SQL database connections."""

@model_validator(mode="before")
@classmethod
def load_dotenv_files(cls, data: Any) -> Any:
"""Load dotenv files before reading settings from environment."""
for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"):
if not dotenv.load_dotenv(env_file):
raise NotImplementedError(f"Could not load dotenv file {env_file}")
return data

@field_validator("enabled_services", mode="before")
@classmethod
def build_enabled_services(cls, value: Any) -> dict[str, bool]:
"""Build enabled services from the installed service entry points."""
enabled_services: dict[str, bool] = {
entry_point.name: True
for entry_point in select_from_extension(group=DiracEntryPoint.SERVICES)
if "well-known" not in entry_point.name
}

for service_name in enabled_services:
env_name = f"DIRACX_SERVICE_{service_name.upper()}_ENABLED"
if env_value := os.environ.get(env_name):
enabled_services[service_name] = _parse_env_bool(env_value)

if isinstance(value, dict):
enabled_services.update(value)
return enabled_services

@field_validator("opensearch_dbs", mode="before")
@classmethod
def build_opensearch_dbs(cls, value: Any) -> dict[str, str]:
"""Build OpenSearch database URLs from the installed entry points."""
opensearch_dbs: dict[str, str] = {
entry_point.name: ""
for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB)
}

for db_name in opensearch_dbs:
env_name = f"DIRACX_OS_DB_{db_name.upper()}"
if env_value := os.environ.get(env_name):
opensearch_dbs[db_name] = env_value

if isinstance(value, dict):
opensearch_dbs.update(value)
return opensearch_dbs

@field_validator("sql_dbs", mode="before")
@classmethod
def build_sql_dbs(cls, value: Any) -> dict[str, str]:
"""Build SQL database URLs from the installed entry points."""
sql_dbs: dict[str, str] = {
entry_point.name: ""
for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB)
}

for db_name in sql_dbs:
env_name = f"DIRACX_DB_URL_{db_name.upper()}"
if env_value := os.environ.get(env_name):
sql_dbs[db_name] = env_value

if isinstance(value, dict):
sql_dbs.update(value)
return sql_dbs
28 changes: 18 additions & 10 deletions diracx-db/src/diracx/db/os/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import json
import logging
import os
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncIterator
from contextvars import ContextVar
Expand All @@ -14,6 +13,7 @@

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.core.settings import FactorySettings
from diracx.db.exceptions import DBUnavailableError

logger = logging.getLogger(__name__)
Expand All @@ -38,7 +38,8 @@ class BaseOSDB(metaclass=ABCMeta):
This method returns a dictionary of database names to connection parameters.
The available databases are determined by the `diracx.dbs.os` entrypoint in
the `pyproject.toml` file and the connection parameters are taken from the
environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
`opensearch_dbs` field in FactorySettings, which reads from environment variables
prefixed with `DIRACX_OS_DB_{DB_NAME}`.

If extensions to DiracX are being used, there can be multiple implementations
of the same database. To list the available implementations use
Expand Down Expand Up @@ -104,19 +105,26 @@ def available_implementations(cls, db_name: str) -> list[type[BaseOSDB]]:
def available_urls(cls) -> dict[str, dict[str, Any]]:
"""Return a dict of available OpenSearch database urls.

The list of available URLs is determined by environment variables
The list of available URLs is determined by the opensearch_dbs field
in FactorySettings, which reads from environment variables
prefixed with ``DIRACX_OS_DB_{DB_NAME}``.
"""
factory_settings = FactorySettings()
opensearch_dbs = factory_settings.opensearch_dbs

conn_kwargs: dict[str, dict[str, Any]] = {}
for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB):
db_name = entry_point.name
var_name = f"DIRACX_OS_DB_{entry_point.name.upper()}"
if var_name in os.environ:
try:
conn_kwargs[db_name] = json.loads(os.environ[var_name])
except Exception:
logger.error("Error loading connection parameters for %s", db_name)
raise
# Get the field value from the OpenSearchDBSettings model
if field_value := opensearch_dbs.get(db_name):
if field_value:
try:
conn_kwargs[db_name] = json.loads(field_value)
except Exception:
logger.error(
"Error loading connection parameters for %s", db_name
)
raise
return conn_kwargs

@classmethod
Expand Down
19 changes: 12 additions & 7 deletions diracx-db/src/diracx/db/sql/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import contextlib
import logging
import os
import re
from abc import ABCMeta
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -53,8 +52,9 @@ class BaseSQLDB(metaclass=ABCMeta):
The available databases are discovered by calling `BaseSQLDB.available_urls`.
This method returns a mapping of database names to connection URLs. The
available databases are determined by the `diracx.dbs.sql` entrypoint in the
`pyproject.toml` file and the connection URLs are taken from the environment
variables of the form `DIRACX_DB_URL_<db-name>`.
`pyproject.toml` file and the connection URLs are taken from the
`sql_dbs` field in FactorySettings, which reads from environment variables
of the form `DIRACX_DB_URL_<db-name>`.

If extensions to DiracX are being used, there can be multiple implementations
of the same database. To list the available implementations use
Expand Down Expand Up @@ -125,16 +125,21 @@ def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
def available_urls(cls) -> dict[str, str]:
"""Return a dict of available database urls.

The list of available URLs is determined by environment variables
The list of available URLs is determined by the sql_dbs field
in FactorySettings, which reads from environment variables
prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
"""
from diracx.core.settings import FactorySettings

factory_settings = FactorySettings()
sql_dbs = factory_settings.sql_dbs

db_urls: dict[str, str] = {}
for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB):
db_name = entry_point.name
var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
if var_name in os.environ:
# Get the field value from the SqlDBSettings model
if db_url := sql_dbs.get(db_name):
try:
db_url = os.environ[var_name]
if db_url == "sqlite+aiosqlite:///:memory:":
db_urls[db_name] = db_url
# pydantic does not allow for underscore in scheme
Expand Down
7 changes: 4 additions & 3 deletions diracx-logic/src/diracx/logic/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ async def delete_jwk(args):
async def cleanup_authdb(args):
"""Maintain AuthDB partitions and remove expired flows."""
logger.info("Maintaining AuthDB partitions and removing expired flows")
import os

from diracx.core.settings import AuthSettings
from diracx.core.settings import AuthSettings, FactorySettings
from diracx.db.sql import AuthDB
from diracx.logic.auth.management import cleanup_expired_data

settings = AuthSettings()
db_url = os.environ["DIRACX_DB_URL_AUTHDB"]
factory_settings = FactorySettings()
db_url = factory_settings.sql_dbs.AuthDB

db = AuthDB(db_url)
async with db.engine_context():
async with db:
Expand Down
8 changes: 3 additions & 5 deletions diracx-routers/src/diracx/routers/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import logging
import os
from http import HTTPStatus
from typing import Annotated, Literal

Expand All @@ -21,7 +20,7 @@
RefreshTokenPayload,
TokenResponse,
)
from diracx.core.settings import AuthSettings
from diracx.core.settings import AuthSettings, FactorySettings
from diracx.db.sql import AuthDB
from diracx.logic.auth import create_token
from diracx.logic.auth import get_oidc_token as get_oidc_token_bl
Expand Down Expand Up @@ -182,6 +181,7 @@ async def perform_legacy_exchange(
auth_db: AuthDB,
available_properties: AvailableSecurityProperties,
settings: AuthSettings,
factory_settings: FactorySettings,
config: Config,
all_access_policies: Annotated[
dict[str, BaseAccessPolicy], Depends(BaseAccessPolicy.all_used_access_policies)
Expand All @@ -193,9 +193,7 @@ async def perform_legacy_exchange(
This route is disabled if DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY is not set
in the environment.
"""
if not (
expected_api_key := os.environ.get("DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY")
):
if not (expected_api_key := factory_settings.diracx_legacy_exchange_hashed_api_key):
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail="Legacy exchange is not enabled",
Expand Down
17 changes: 4 additions & 13 deletions diracx-routers/src/diracx/routers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@

import inspect
import logging
import os
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
from functools import partial
from http import HTTPStatus
from importlib.metadata import EntryPoint, EntryPoints, entry_points
from logging import Formatter, StreamHandler
from typing import Any, TypeVar, cast

import dotenv
from cachetools import TTLCache
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.dependencies.models import Dependant
Expand All @@ -24,15 +22,13 @@
from fastapi.responses import JSONResponse, Response
from fastapi.routing import APIRoute
from packaging.version import InvalidVersion, parse
from pydantic import TypeAdapter
from starlette.middleware.base import BaseHTTPMiddleware
from uvicorn.logging import AccessFormatter, DefaultFormatter

from diracx.core.config import ConfigSource
from diracx.core.exceptions import DiracError, DiracHttpResponseError, NotReadyError
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.core.settings import ServiceSettingsBase
from diracx.core.utils import dotenv_files_from_environment
from diracx.core.settings import FactorySettings, ServiceSettingsBase
from diracx.db.exceptions import DBUnavailableError
from diracx.db.os.utils import BaseOSDB
from diracx.db.sql.utils import BaseSQLDB
Expand Down Expand Up @@ -143,7 +139,6 @@ def create_app_inner(
# Please see ServiceSettingsBase for more details

available_settings_classes: set[type[ServiceSettingsBase]] = set()

for service_settings in all_service_settings:
cls = type(service_settings)
assert cls not in available_settings_classes
Expand Down Expand Up @@ -346,17 +341,12 @@ def create_app() -> DiracFastAPI:
We attempt to load each setting classes to make sure that the
settings are correctly defined.
"""
for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"):
logger.debug("Loading dotenv file: %s", env_file)
if not dotenv.load_dotenv(env_file):
raise NotImplementedError(f"Could not load dotenv file {env_file}")

# Load all available routers
enabled_systems = set()
settings_classes = set()
factory_settings = FactorySettings()
for entry_point in select_from_extension(group=DiracEntryPoint.SERVICES):
env_var = f"DIRACX_SERVICE_{entry_point.name.upper()}_ENABLED"
enabled = TypeAdapter(bool).validate_json(os.environ.get(env_var, "true"))
enabled = factory_settings.enabled_services.get(entry_point.name, True)
logger.debug("Found service %r: enabled=%s", entry_point, enabled)
if not enabled:
continue
Expand Down Expand Up @@ -443,6 +433,7 @@ async def validation_error_handler(request: Request, exc: RequestValidationError
def find_dependents(
obj: APIRouter | Iterable[Dependant], cls: type[T]
) -> Iterable[type[T]]:

if isinstance(obj, APIRouter):
# TODO: Support dependencies of the router itself
# yield from find_dependents(obj.dependencies, cls)
Expand Down
Loading