Skip to content

Commit

Permalink
fix: list filters
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Feb 4, 2025
1 parent ca7f6b0 commit e08db77
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 109 deletions.
3 changes: 2 additions & 1 deletion src/api/filters/auth_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# pylint: disable-next=missing-class-docstring
class AuthFactorFilterSet(FilterSet):
user = filters.NumberFilter("user")
type = filters.ChoiceFilter(choices=AuthFactor.Type.choices)

class Meta:
model = AuthFactor
fields = ["user"]
fields = ["user", "type"]
2 changes: 1 addition & 1 deletion src/api/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Created on 23/01/2024 at 16:13:13(+00:00).
"""

from .auth_factor import AuthFactorSerializer, CheckIfAuthFactorExistsSerializer
from .auth_factor import AuthFactorSerializer
from .klass import ReadClassSerializer, WriteClassSerializer
from .school import SchoolSerializer
from .school_teacher_invitation import (
Expand Down
17 changes: 0 additions & 17 deletions src/api/serializers/auth_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from codeforlife.serializers import ModelSerializer
from codeforlife.user.models import AuthFactor, User
from rest_framework import serializers
from rest_framework.validators import UniqueTogetherValidator

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors
Expand Down Expand Up @@ -64,19 +63,3 @@ def create(self, validated_data):
validated_data["user_id"] = self.request.auth_user.id
validated_data.pop("otp", None)
return super().create(validated_data)


class CheckIfAuthFactorExistsSerializer(ModelSerializer[User, AuthFactor]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Allow duplicate pairs [to the DB] since we're checking if they exist.
self.validators = [
validator
for validator in self.validators
if not isinstance(validator, UniqueTogetherValidator)
]

class Meta:
model = AuthFactor
fields = ["user", "type"]
27 changes: 1 addition & 26 deletions src/api/serializers/auth_factor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

from codeforlife.tests import ModelSerializerTestCase
from codeforlife.user.models import AuthFactor, TeacherUser, User
from rest_framework.validators import UniqueTogetherValidator

from .auth_factor import AuthFactorSerializer, CheckIfAuthFactorExistsSerializer
from .auth_factor import AuthFactorSerializer

# pylint: disable=missing-class-docstring
# pylint: disable=too-many-ancestors
Expand Down Expand Up @@ -100,27 +99,3 @@ def test_create__otp(self):
new_data={"user": user.id},
context={"request": self.request_factory.post(user=user)},
)


class TestCheckIfAuthFactorExistsSerializer(
ModelSerializerTestCase[User, AuthFactor]
):
model_serializer_class = CheckIfAuthFactorExistsSerializer
fixtures = ["school_2"]

def setUp(self):
auth_factor = AuthFactor.objects.first()
assert auth_factor
self.auth_factor = auth_factor

def test_init(self):
"""Initializing the serializer removes unique-together validators."""
model_serializer = self.model_serializer_class(
data={
"user": self.auth_factor.user.pk,
"type": self.auth_factor.type,
}
)

for validator in model_serializer.validators:
assert not isinstance(validator, UniqueTogetherValidator)
29 changes: 3 additions & 26 deletions src/api/views/auth_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@

from ..filters import AuthFactorFilterSet
from ..permissions import HasAuthFactor
from ..serializers import (
AuthFactorSerializer,
CheckIfAuthFactorExistsSerializer,
)
from ..serializers import AuthFactorSerializer


# pylint: disable-next=missing-class-docstring,too-many-ancestors
Expand All @@ -25,21 +22,15 @@ class AuthFactorViewSet(ModelViewSet[User, AuthFactor]):
model_class = AuthFactor
http_method_names = ["get", "post", "delete"]
filterset_class = AuthFactorFilterSet

# pylint: disable-next=missing-function-docstring
def get_serializer_class(self):
if self.action == "check_if_exists":
return CheckIfAuthFactorExistsSerializer

return AuthFactorSerializer
serializer_class = AuthFactorSerializer

# pylint: disable-next=missing-function-docstring
def get_queryset(self):
queryset = AuthFactor.objects.all()
user = self.request.teacher_user

if (
self.action in ["list", "destroy", "check_if_exists"]
self.action in ["list", "destroy"]
and user.teacher.school
and user.teacher.is_admin
):
Expand All @@ -58,20 +49,6 @@ def get_permissions(self):

return [IsTeacher()]

@action(detail=False, methods=["post"])
def check_if_exists(self, request: Request[User]):
"""Check if an auth factor exists for the requesting user."""
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)

return Response(
{
"auth_factor_exists": self.get_queryset()
.filter(**serializer.validated_data)
.exists()
}
)

@action(detail=False, methods=["get"])
def get_otp_secret(self, request: Request[User]):
"""Get the secret for the user's one-time-password."""
Expand Down
50 changes: 12 additions & 38 deletions src/api/views/auth_factor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
from django.db.models import Count

from ..permissions import HasAuthFactor
from ..serializers import (
AuthFactorSerializer,
CheckIfAuthFactorExistsSerializer,
)
from .auth_factor import AuthFactorViewSet

# pylint: disable=missing-class-docstring
Expand All @@ -40,22 +36,6 @@ def setUp(self):
)
assert self.mfa_non_admin_school_teacher_user.auth_factors.exists()

# test: get serializer class

def test_get_serializer_class__list(self):
"""Listing auth factors uses the general serializer."""
self.assert_get_serializer_class(AuthFactorSerializer, action="list")

def test_get_serializer_class__create(self):
"""Creating an auth factor uses the general serializer."""
self.assert_get_serializer_class(AuthFactorSerializer, action="create")

def test_get_serializer_class__check_if_exists(self):
"""Checking if an auth factor exists uses a dedicated serializer."""
self.assert_get_serializer_class(
CheckIfAuthFactorExistsSerializer, action="check_if_exists"
)

# test: get queryset

def test_get_queryset__list__admin(self):
Expand Down Expand Up @@ -238,6 +218,18 @@ def test_list__user(self):
filters={"user": str(user.pk)},
)

def test_list__type(self):
"""Can list enabled auth-factors, filtered by type."""
user = self.mfa_non_admin_school_teacher_user
auth_factor = user.auth_factors.first()
assert auth_factor

self.client.login_as(user)
self.client.list(
[auth_factor],
filters={"type": auth_factor.type},
)

def test_create__otp(self):
"""Can enable OTP."""
teacher_user = TeacherUser.objects.exclude(
Expand Down Expand Up @@ -267,24 +259,6 @@ def test_destroy(self):
self.client.login_as(user)
self.client.destroy(auth_factor)

def test_check_if_exists(self):
"""Can successfully check if the requesting user has an auth factor."""
user = self.mfa_non_admin_school_teacher_user
auth_factor = user.auth_factors.first()
assert auth_factor

self.client.login_as(user)

response = self.client.post(
self.reverse_action("check_if_exists"),
data={
"user": auth_factor.user.pk,
"type": auth_factor.type,
},
)

self.assertDictEqual(response.json(), {"auth_factor_exists": True})

def test_get_otp_secret(self):
"""Can successfully generate a OTP provisioning URI."""
user = TeacherUser.objects.exclude(
Expand Down

0 comments on commit e08db77

Please sign in to comment.