Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@
"host": "0.0.0.0",
"port": 6185,
"disable_access_log": True,
"totp": {
"enable": False,
"secret": "",
"recovery_code_hash": "",
},
"ssl": {
"enable": False,
"cert_file": "",
Expand Down Expand Up @@ -4180,6 +4185,12 @@
"type": "bool",
"hint": "启用后,WebUI 将直接使用 HTTPS 提供服务。",
},
"dashboard.totp.enable": {
"description": "启用 WebUI TOTP 双因素认证",
"type": "bool",
"hint": "启用后,登录 WebUI 需要额外输入验证码。",
"_special": "dashboard_totp_manager",
},
"dashboard.ssl.cert_file": {
"description": "SSL 证书文件路径",
"type": "string",
Expand Down
15 changes: 15 additions & 0 deletions astrbot/core/db/po.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,21 @@ class ApiKey(TimestampMixin, SQLModel, table=True):
)


class DashboardTrustedDevice(TimestampMixin, SQLModel, table=True):
"""Trusted dashboard device token used to skip TOTP for a limited time."""

__tablename__: str = "dashboard_trusted_devices"

id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
token_hash: str = Field(max_length=64, nullable=False, unique=True, index=True)
totp_secret_hash: str = Field(max_length=64, nullable=False, index=True)
expires_at: datetime = Field(nullable=False, index=True)


class ChatUIProject(TimestampMixin, SQLModel, table=True):
"""This class represents projects for organizing ChatUI conversations.

Expand Down
210 changes: 210 additions & 0 deletions astrbot/core/utils/totp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from __future__ import annotations

import asyncio
import base64
import datetime
import hashlib
import hmac
import secrets

import pyotp
from sqlmodel import col, delete, select

from astrbot.core.db.po import DashboardTrustedDevice

TOTP_TRUSTED_DEVICE_COOKIE_NAME = "astrbot_totp_trusted_device"
TOTP_TRUSTED_DEVICE_MAX_AGE = 30 * 24 * 60 * 60
RECOVERY_CODE_GROUP_COUNT = 4
RECOVERY_CODE_GROUP_LENGTH = 8
RECOVERY_CODE_LENGTH = RECOVERY_CODE_GROUP_COUNT * RECOVERY_CODE_GROUP_LENGTH
_RECOVERY_CODE_KDF_ITERATIONS = 600_000
_RECOVERY_CODE_KDF_SALT_BYTES = 16
_RECOVERY_CODE_KDF_ALGORITHM = "pbkdf2_sha256"

_last_totp_timecode: dict[str, int] = {}
_totp_replay_lock = asyncio.Lock()


def _get_totp_config(config) -> dict:
totp_config = config.get("dashboard", {}).get("totp", {})
return totp_config if isinstance(totp_config, dict) else {}


def is_totp_enabled(config) -> bool:
Comment on lines +24 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Replay protection map _last_totp_timecode can grow unbounded over long uptime or many secrets.

Because _last_totp_timecode is keyed by the raw secret and never cleaned up, it can grow indefinitely in long-lived processes or with frequent secret rotation. Consider adding a pruning strategy (e.g., when TOTP is disabled/rotated, or by keeping only the last N secrets) or keying by a stable hash and cleaning up when it changes, to prevent unbounded memory growth while preserving replay protection.

Suggested implementation:

# Limit how many secrets we keep in memory for TOTP replay protection.
# This prevents unbounded growth over long uptimes or frequent secret rotation.
MAX_TOTP_REPLAY_ENTRIES = 10_000

# Map of hashed TOTP secrets -> last accepted timecode.
# We key by a stable hash of the secret so rotation to a new secret naturally
# evicts old entries as the map reaches MAX_TOTP_REPLAY_ENTRIES.
_last_totp_timecode: "OrderedDict[str, int]" = OrderedDict()
_totp_replay_lock = asyncio.Lock()


def _totp_replay_key(secret: str) -> str:
    """Return a stable hash key for a TOTP secret suitable for in-memory indexing."""
    # Using SHA-256 avoids keeping the raw secret as the dictionary key and gives
    # us a fixed-size identifier that works well with pruning / rotation.
    return hashlib.sha256(secret.encode("utf-8")).hexdigest()


async def _get_last_totp_timecode(secret: str) -> int | None:
    """Fetch the last accepted TOTP timecode for a secret, if any."""
    async with _totp_replay_lock:
        key = _totp_replay_key(secret)
        return _last_totp_timecode.get(key)


async def _update_last_totp_timecode(secret: str, timecode: int) -> None:
    """Update the last accepted TOTP timecode for a secret, pruning old entries."""
    async with _totp_replay_lock:
        key = _totp_replay_key(secret)

        # Update/move to the end for a simple LRU-like eviction strategy.
        _last_totp_timecode[key] = timecode
        _last_totp_timecode.move_to_end(key)

        # Prune oldest secrets if we exceed the maximum size.
        while len(_last_totp_timecode) > MAX_TOTP_REPLAY_ENTRIES:
            _last_totp_timecode.popitem(last=False)

To fully implement this change and avoid unbounded growth:

  1. Imports

    • At the top of astrbot/core/utils/totp.py, add:
      • from collections import OrderedDict
      • import hashlib
    • If typing.OrderedDict is used instead of the runtime class anywhere else, adjust the annotation accordingly:
      • from collections.abc import MutableMapping or from typing import OrderedDict depending on your typing style.
  2. Replace direct uses of _last_totp_timecode

    • Anywhere in this file where _last_totp_timecode is accessed directly with the raw secret as a key, change those usages to call the helpers:
      • Read access:
        • Replace last = _last_totp_timecode.get(secret) (or similar) with:
          • last = await _get_last_totp_timecode(secret)
      • Write/update access:
        • Replace _last_totp_timecode[secret] = timecode or mutations inside an async with _totp_replay_lock block with:
          • await _update_last_totp_timecode(secret, timecode)
    • If there is existing manual locking around _last_totp_timecode using _totp_replay_lock, remove the redundant lock usage around call sites of _get_last_totp_timecode / _update_last_totp_timecode, since the helpers already handle locking.
  3. Type annotations

    • If your project targets Python versions that do not support int | None, replace it with Optional[int] and add from typing import Optional.

These changes ensure the replay protection map is bounded in size, keyed by a stable hash instead of the raw secret, and automatically prunes old entries while preserving replay protection guarantees for recent/active secrets.

"""TOTP is fully configured and operational (enable + secret + recovery hash all present)."""
totp_config = _get_totp_config(config)
if not totp_config.get("enable", False):
return False
secret = totp_config.get("secret", "")
if not isinstance(secret, str) or not secret.strip():
return False
recovery_code_hash = totp_config.get("recovery_code_hash", "")
if not isinstance(recovery_code_hash, str) or not recovery_code_hash.strip():
return False
return True


def _get_verified_totp_timecode(secret: str, code: str) -> int | None:
code = code.strip()
try:
totp = pyotp.TOTP(secret.strip())
now = datetime.datetime.now()
Comment thread
Raven95676 marked this conversation as resolved.
Outdated
for offset in (-1, 0, 1):
candidate_time = now + datetime.timedelta(seconds=offset * totp.interval)
if hmac.compare_digest(str(totp.at(candidate_time)), code):
return int(totp.timecode(candidate_time))
except Exception:
return None
return None


async def consume_totp_code(secret: str, code: str) -> bool:
global _last_totp_timecode
timecode = _get_verified_totp_timecode(secret, code)
if timecode is None:
return False
secret = secret.strip()
async with _totp_replay_lock:
if _last_totp_timecode.get(secret, -1) >= timecode:
return False
_last_totp_timecode[secret] = timecode
return True


async def consume_configured_totp_code(config, code: str) -> bool:
if not is_totp_enabled(config):
return False
secret = _get_totp_config(config).get("secret", "")
return await consume_totp_code(secret, code)


def _hash_totp_trusted_device_token(config, token: str) -> str:
jwt_secret = config["dashboard"].get("jwt_secret", "")
if not isinstance(jwt_secret, str) or not jwt_secret:
return ""
return hmac.new(
jwt_secret.encode("utf-8"),
token.encode("utf-8"),
hashlib.sha256,
).hexdigest()


def _hash_totp_secret(config) -> str:
secret = _get_totp_config(config).get("secret", "")
if not isinstance(secret, str) or not secret.strip():
return ""
return hashlib.sha256(secret.strip().encode("utf-8")).hexdigest()


async def is_totp_trusted_device_valid(config, db, cookie_token: str) -> bool:
if not cookie_token:
return False
token_hash = _hash_totp_trusted_device_token(config, cookie_token)
totp_secret_hash = _hash_totp_secret(config)
if not token_hash or not totp_secret_hash:
return False

await _cleanup_expired_totp_trusted_devices(db)
async with db.get_db() as session:
result = await session.execute(
select(DashboardTrustedDevice).where(
col(DashboardTrustedDevice.token_hash) == token_hash,
col(DashboardTrustedDevice.totp_secret_hash) == totp_secret_hash,
col(DashboardTrustedDevice.expires_at)
> datetime.datetime.now(datetime.timezone.utc),
)
)
return result.scalar_one_or_none() is not None


async def issue_totp_trusted_device(config, db) -> str | None:
"""Issue a trusted device token, save to DB, and return the raw token for cookie."""
raw_token = secrets.token_urlsafe(48)
token_hash = _hash_totp_trusted_device_token(config, raw_token)
totp_secret_hash = _hash_totp_secret(config)
if not token_hash or not totp_secret_hash:
return None

expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
seconds=TOTP_TRUSTED_DEVICE_MAX_AGE
)
async with db.get_db() as session:
async with session.begin():
await session.execute(
delete(DashboardTrustedDevice).where(
col(DashboardTrustedDevice.token_hash) == token_hash
)
)
trusted_device = DashboardTrustedDevice.model_validate(
{
"token_hash": token_hash,
"totp_secret_hash": totp_secret_hash,
"expires_at": expires_at,
}
)
session.add(trusted_device)
return raw_token


async def _cleanup_expired_totp_trusted_devices(db) -> None:
async with db.get_db() as session:
async with session.begin():
await session.execute(
delete(DashboardTrustedDevice).where(
col(DashboardTrustedDevice.expires_at)
<= datetime.datetime.now(datetime.timezone.utc)
)
)


async def revoke_user_trusted_devices(db) -> None:
async with db.get_db() as session:
async with session.begin():
await session.execute(delete(DashboardTrustedDevice))


def generate_recovery_code() -> tuple[str, str]:
raw = secrets.token_bytes(20)
recovery_code = base64.b32encode(raw).decode("ascii").rstrip("=")
salt = secrets.token_hex(_RECOVERY_CODE_KDF_SALT_BYTES)
digest = hashlib.pbkdf2_hmac(
"sha256",
recovery_code.encode("utf-8"),
bytes.fromhex(salt),
_RECOVERY_CODE_KDF_ITERATIONS,
).hex()
kdf_hash = f"{_RECOVERY_CODE_KDF_ALGORITHM}${_RECOVERY_CODE_KDF_ITERATIONS}${salt}${digest}"
parts = [
recovery_code[i : i + RECOVERY_CODE_GROUP_LENGTH]
for i in range(0, len(recovery_code), RECOVERY_CODE_GROUP_LENGTH)
]
return "-".join(parts), kdf_hash


def verify_recovery_code(config, code: str) -> bool:
"""Verify a recovery code against configured recovery_code_hash (PBKDF2)."""
cleaned = "".join(char for char in code.upper() if char.isalnum())
if len(cleaned) != RECOVERY_CODE_LENGTH:
return False
totp_config = _get_totp_config(config)
stored_hash = totp_config.get("recovery_code_hash", "")
if not isinstance(stored_hash, str) or not stored_hash:
return False

parts = stored_hash.split("$")
if len(parts) != 4 or parts[0] != _RECOVERY_CODE_KDF_ALGORITHM:
return False
try:
iterations = int(parts[1])
salt = parts[2]
expected_digest = parts[3]
except (ValueError, IndexError):
return False

candidate = hashlib.pbkdf2_hmac(
"sha256",
cleaned.encode("utf-8"),
bytes.fromhex(salt),
iterations,
).hex()
return hmac.compare_digest(candidate, expected_digest)
Loading