diff --git a/docs/api.rst b/docs/api.rst index 23118ee4..5a0abfa4 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -13,6 +13,8 @@ Core Protecting Views ---------------- +.. autofunction:: flask_security.decorators.auth_required + .. autofunction:: flask_security.decorators.roles_required .. autofunction:: flask_security.decorators.roles_accepted diff --git a/flask_security/core.py b/flask_security/core.py index 7a25f2e3..370115d5 100644 --- a/flask_security/core.py +++ b/flask_security/core.py @@ -227,7 +227,7 @@ def _on_identity_loaded(sender, identity): identity.provides.add(UserNeed(current_user.id)) for role in getattr(current_user, 'roles', []): - identity.provides.add(RoleNeed(role.name)) + identity.provides.add(RoleNeed(role.id)) identity.user = current_user diff --git a/flask_security/decorators.py b/flask_security/decorators.py index 70c9f1ee..21a6d191 100644 --- a/flask_security/decorators.py +++ b/flask_security/decorators.py @@ -60,7 +60,7 @@ def _get_unauthorized_view(): def auth_required(*auth_methods): """ - Decorator that protects enpoints through multiple mechanisms + Decorator that protects endpoints through multiple mechanisms Example:: @app.route('/dashboard') diff --git a/run-tests.sh b/run-tests.sh index e9a86954..14353245 100755 --- a/run-tests.sh +++ b/run-tests.sh @@ -13,8 +13,18 @@ set -o errexit # Quit on unbound symbols set -o nounset +# Check for arguments +pytest_args=() +for arg in $@; do + case ${arg} in + *) + pytest_args+=( ${arg} ) + ;; + esac +done + python -m check_manifest --ignore ".*-requirements.txt" python -m sphinx.cmd.build -qnN docs docs/_build/html -python -m pytest +python -m pytest ${pytest_args[@]+"${pytest_args[@]}"} tests_exit_code=$? exit "$tests_exit_code" diff --git a/tests/conftest.py b/tests/conftest.py index e266930c..a3ca4042 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ import os import tempfile +import uuid from json import JSONEncoder as BaseEncoder import pytest @@ -161,14 +162,20 @@ def sqlalchemy_datastore(request, app, tmpdir): roles_users = db.Table( "roles_users", db.Column("user_id", db.Integer(), db.ForeignKey("user.id")), - db.Column("role_id", db.Integer(), db.ForeignKey("role.id")), + db.Column("role_id", db.String(80), db.ForeignKey("role.id")), ) class Role(db.Model, RoleMixin): - id = db.Column(db.Integer(), primary_key=True) + id = db.Column(db.String(80), primary_key=True, default=lambda x: str(uuid.uuid4())) name = db.Column(db.String(80), unique=True) description = db.Column(db.String(255)) + def __init__(self, **kwargs): + if kwargs.get("name"): + kwargs.setdefault("id", kwargs["name"]) + super().__init__(**kwargs) + + class User(db.Model, UserMixin): id = db.Column(db.Integer, primary_key=True) email = db.Column(db.String(255), unique=True) @@ -219,11 +226,11 @@ class RolesUsers(Base): __tablename__ = "roles_users" id = Column(Integer(), primary_key=True) user_id = Column("user_id", Integer(), ForeignKey("user.id")) - role_id = Column("role_id", Integer(), ForeignKey("role.id")) + role_id = Column("role_id", String(80), ForeignKey("role.id")) class Role(Base, RoleMixin): __tablename__ = "role" - id = Column(Integer(), primary_key=True) + id = Column(String(80), primary_key=True, default=lambda x: str(uuid.uuid4())) name = Column(String(80), unique=True) description = Column(String(255))