diff --git a/src/api/filters/__init__.py b/src/api/filters/__init__.py new file mode 100644 index 00000000..fdcb1d7a --- /dev/null +++ b/src/api/filters/__init__.py @@ -0,0 +1,6 @@ +""" +© Ocado Group +Created on 20/01/2025 at 16:25:55(+00:00). +""" + +from .auth_factor import AuthFactorFilterSet diff --git a/src/api/filters/auth_factor.py b/src/api/filters/auth_factor.py new file mode 100644 index 00000000..5d5b2577 --- /dev/null +++ b/src/api/filters/auth_factor.py @@ -0,0 +1,20 @@ +""" +© Ocado Group +Created on 20/01/2025 at 16:26:47(+00:00). +""" + +from codeforlife.filters import FilterSet # isort: skip +from codeforlife.user.models import AuthFactor # isort: skip +from django_filters import ( # type: ignore[import-untyped] # isort: skip + rest_framework as filters, +) + + +# pylint: disable-next=missing-class-docstring +class AuthFactorFilterSet(FilterSet): + user = filters.NumberFilter("user") + type = filters.ChoiceFilter(choices=AuthFactor.Type.choices) + + class Meta: + model = AuthFactor + fields = ["user", "type"] diff --git a/src/api/permissions/__init__.py b/src/api/permissions/__init__.py index bd8dd228..aa64ed61 100644 --- a/src/api/permissions/__init__.py +++ b/src/api/permissions/__init__.py @@ -3,4 +3,5 @@ Created on 24/04/2024 at 11:57:02(+01:00). """ +from .has_auth_factor import HasAuthFactor from .is_invited_school_teacher import IsInvitedSchoolTeacher diff --git a/src/api/permissions/has_auth_factor.py b/src/api/permissions/has_auth_factor.py new file mode 100644 index 00000000..3a8b26dc --- /dev/null +++ b/src/api/permissions/has_auth_factor.py @@ -0,0 +1,27 @@ +""" +© Ocado Group +Created on 20/01/2025 at 18:52:16(+00:00). +""" + +from codeforlife.permissions import IsAuthenticated +from codeforlife.user.models import AuthFactor, User + + +class HasAuthFactor(IsAuthenticated): + """Request's user must have a auth factor enabled.""" + + def __init__(self, t: AuthFactor.Type): + super().__init__() + + self.t = t + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.t == other.t + + def has_permission(self, request, view): + user = request.user + return ( + super().has_permission(request, view) + and isinstance(user, User) + and user.auth_factors.filter(type=self.t).exists() + ) diff --git a/src/api/serializers/auth_factor.py b/src/api/serializers/auth_factor.py index ec075055..2c369650 100644 --- a/src/api/serializers/auth_factor.py +++ b/src/api/serializers/auth_factor.py @@ -9,8 +9,10 @@ from codeforlife.user.models import AuthFactor, User from rest_framework import serializers +# pylint: disable=missing-class-docstring +# pylint: disable=too-many-ancestors + -# pylint: disable-next=missing-class-docstring,too-many-ancestors class AuthFactorSerializer(ModelSerializer[User, AuthFactor]): otp = serializers.CharField(required=False, write_only=True) diff --git a/src/api/views/auth_factor.py b/src/api/views/auth_factor.py index 5b106455..af6c52f7 100644 --- a/src/api/views/auth_factor.py +++ b/src/api/views/auth_factor.py @@ -4,13 +4,15 @@ """ import pyotp -from codeforlife.permissions import AllowNone +from codeforlife.permissions import NOT, AllowNone from codeforlife.request import Request from codeforlife.response import Response from codeforlife.user.models import AuthFactor, User from codeforlife.user.permissions import IsTeacher from codeforlife.views import ModelViewSet, action +from ..filters import AuthFactorFilterSet +from ..permissions import HasAuthFactor from ..serializers import AuthFactorSerializer @@ -19,6 +21,7 @@ class AuthFactorViewSet(ModelViewSet[User, AuthFactor]): request_user_class = User model_class = AuthFactor http_method_names = ["get", "post", "delete"] + filterset_class = AuthFactorFilterSet serializer_class = AuthFactorSerializer # pylint: disable-next=missing-function-docstring @@ -41,16 +44,23 @@ def get_queryset(self): def get_permissions(self): if self.action in ["retrieve", "bulk"]: return [AllowNone()] + if self.action == "get_otp_secret": + return [IsTeacher(), NOT(HasAuthFactor(AuthFactor.Type.OTP))] return [IsTeacher()] - @action(detail=False, methods=["post"]) - def generate_otp_provisioning_uri(self, request: Request[User]): - """Generate a time-based one-time-password provisioning URI.""" + @action(detail=False, methods=["get"]) + def get_otp_secret(self, request: Request[User]): + """Get the secret for the user's one-time-password.""" # TODO: make otp_secret non-nullable and delete code block user = request.auth_user if not user.userprofile.otp_secret: user.userprofile.otp_secret = pyotp.random_base32() user.userprofile.save(update_fields=["otp_secret"]) - return Response(user.totp_provisioning_uri, content_type="text/plain") + return Response( + { + "secret": user.totp.secret, + "provisioning_uri": user.totp_provisioning_uri, + } + ) diff --git a/src/api/views/auth_factor_test.py b/src/api/views/auth_factor_test.py index 16f9bf6c..cb0d957a 100644 --- a/src/api/views/auth_factor_test.py +++ b/src/api/views/auth_factor_test.py @@ -3,25 +3,26 @@ Created on 23/01/2024 at 11:22:16(+00:00). """ -from unittest.mock import patch - import pyotp -from codeforlife.permissions import AllowNone +from codeforlife.permissions import NOT, AllowNone from codeforlife.tests import ModelViewSetTestCase from codeforlife.user.models import ( AdminSchoolTeacherUser, AuthFactor, NonAdminSchoolTeacherUser, + School, TeacherUser, User, ) from codeforlife.user.permissions import IsTeacher -from pyotp import TOTP +from django.db.models import Count, Q +from ..permissions import HasAuthFactor from .auth_factor import AuthFactorViewSet # pylint: disable=missing-class-docstring # pylint: disable=too-many-ancestors +# pylint: disable=too-many-public-methods class TestAuthFactorViewSet(ModelViewSetTestCase[User, AuthFactor]): @@ -101,12 +102,47 @@ def test_get_queryset__destroy__non_admin(self): request=self.client.request_factory.get(user=user), ) - def test_get_queryset__generate_otp_provisioning_uri(self): - """Can only generate an OTP provisioning URI yourself.""" + def test_get_queryset__get_otp_secret(self): + """Can only get your own OTP secret.""" + user = self.mfa_non_admin_school_teacher_user + + self.assert_get_queryset( + action="get_otp_secret", + values=list(user.auth_factors.all()), + request=self.client.request_factory.get(user=user), + ) + + def test_get_queryset__check_if_exists__admin(self): + """ + Can check if a author factor exists for all teachers in your school if + you are an admin. + """ + user = self.mfa_non_admin_school_teacher_user + admin_school_teacher_user = AdminSchoolTeacherUser.objects.filter( + new_teacher__school=user.teacher.school + ).first() + assert admin_school_teacher_user + + self.assert_get_queryset( + action="list", + values=list( + user.auth_factors.all() + | admin_school_teacher_user.auth_factors.all() + ), + request=self.client.request_factory.get( + user=admin_school_teacher_user + ), + ) + + def test_get_queryset__check_if_exists__non_admin(self): + """ + Can check if a author factor exists for only yourself if you are not an + admin. + """ user = self.mfa_non_admin_school_teacher_user self.assert_get_queryset( - action="generate_otp_provisioning_uri", + action="list", values=list(user.auth_factors.all()), request=self.client.request_factory.get(user=user), ) @@ -133,12 +169,17 @@ def test_get_permissions__destroy(self): """Only a teacher-user can disable an auth factor.""" self.assert_get_permissions([IsTeacher()], action="destroy") - def test_get_permissions__generate_otp_provisioning_uri(self): - """Only a teacher-user can generate a OTP provisioning URI.""" + def test_get_permissions__get_otp_secret(self): + """Only a teacher-user can get an OTP secret.""" self.assert_get_permissions( - [IsTeacher()], action="generate_otp_provisioning_uri" + [IsTeacher(), NOT(HasAuthFactor(AuthFactor.Type.OTP))], + action="get_otp_secret", ) + def test_get_permissions__check_if_exists(self): + """Only a teacher-user can check if an auth factor exists.""" + self.assert_get_permissions([IsTeacher()], action="check_if_exists") + # test: actions def test_list(self): @@ -148,6 +189,52 @@ def test_list(self): self.client.login_as(user) self.client.list(user.auth_factors.all()) + def test_list__user(self): + """Can list enabled auth-factors, filtered by a user's ID.""" + # Get a school that has at least: + # - one admin teacher; + # - two teachers with auth factors enabled. + school = ( + School.objects.annotate( + admin_teacher_count=Count( + "teacher_school", + filter=Q(teacher_school__is_admin=True), + ), + mfa_teacher_count=Count( + "teacher_school", + filter=Q( + teacher_school__new_user__auth_factors__isnull=False + ), + ), + ) + .filter(admin_teacher_count__gte=1, mfa_teacher_count__gte=2) + .first() + ) + assert school + + user = AdminSchoolTeacherUser.objects.filter( + new_teacher__school=school + ).first() + assert user + + self.client.login_as(user) + self.client.list( + user.auth_factors.all(), + filters={"user": str(user.pk)}, + ) + + def test_list__type(self): + """Can list enabled auth-factors, filtered by type.""" + user = self.mfa_non_admin_school_teacher_user + auth_factor = user.auth_factors.first() + assert auth_factor + + self.client.login_as(user) + self.client.list( + [auth_factor], + filters={"type": auth_factor.type}, + ) + def test_create__otp(self): """Can enable OTP.""" teacher_user = TeacherUser.objects.exclude( @@ -177,7 +264,7 @@ def test_destroy(self): self.client.login_as(user) self.client.destroy(auth_factor) - def test_generate_otp_provisioning_uri(self): + def test_get_otp_secret(self): """Can successfully generate a OTP provisioning URI.""" user = TeacherUser.objects.exclude( auth_factors__type__in=[AuthFactor.Type.OTP] @@ -187,17 +274,12 @@ def test_generate_otp_provisioning_uri(self): # TODO: normalize password to "password" self.client.login_as(user, password="abc123") - with patch.object( - TOTP, "provisioning_uri", return_value=user.totp_provisioning_uri - ) as provisioning_uri: - response = self.client.post( - self.reverse_action("generate_otp_provisioning_uri") - ) - - provisioning_uri.assert_called_once_with( - name=user.email, - issuer_name="Code for Life", - ) + response = self.client.get(self.reverse_action("get_otp_secret")) - assert response.data == provisioning_uri.return_value - assert response.content_type == "text/plain" + self.assertDictEqual( + response.json(), + { + "secret": user.totp.secret, + "provisioning_uri": user.totp_provisioning_uri, + }, + ) diff --git a/src/api/views/klass.py b/src/api/views/klass.py index e7f25c61..5a5058f8 100644 --- a/src/api/views/klass.py +++ b/src/api/views/klass.py @@ -16,6 +16,7 @@ class ClassViewSet(_ClassViewSet): http_method_names = ["get", "post", "patch", "delete"] + # pylint: disable-next=missing-function-docstring def get_permissions(self): # Only bulk-partial-update allowed for classes. if self.action == "bulk": @@ -31,12 +32,14 @@ def get_permissions(self): return super().get_permissions() + # pylint: disable-next=missing-function-docstring def get_serializer_class(self): if self.action in ["create", "partial_update", "bulk"]: return WriteClassSerializer return ReadClassSerializer + # pylint: disable-next=missing-function-docstring def get_queryset(self): if self.action in ["retrieve", "list"]: return super().get_queryset() @@ -48,6 +51,7 @@ def get_queryset(self): else teacher.classes.filter(teacher=teacher) ) + # pylint: disable-next=missing-function-docstring def destroy(self, request, *args, **kwargs): klass = self.get_object()