diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 339692e1..d6aeb276 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,6 +74,7 @@ jobs: pip install -U pip wheel setuptools pip install -U -r requirements-test.txt pip install -e .[saml,openvpn_status] + pip install --upgrade --no-deps --no-cache-dir --force-reinstall "https://github.com/openwisp/openwisp-users/tarball/issues/497-export-users" pip install ${{ matrix.django-version }} - name: Start InfluxDB and Redis container diff --git a/docs/user/rest-api.rst b/docs/user/rest-api.rst index 0efc1882..68ed71a3 100644 --- a/docs/user/rest-api.rst +++ b/docs/user/rest-api.rst @@ -803,6 +803,48 @@ Param Description phone_number string ============ =========== +Update Registered User Method ++++++++++++++++++++++++++++++ + +**Requires the user auth token (Bearer Token)**. + +Allows users to update their registered user method for an organization. +The method can only be updated when it is currently set to +``pending_verification``. Once updated, it cannot be changed again via +this endpoint. + +This endpoint is used during cross-organization login when a user +authenticates to a new organization. The user must complete verification +for that organization before they can create account with the new +organization. + +.. code-block:: text + + /api/v1/radius/organization//account/registration-method/ + +Responds only to **POST**. + +Parameters: + +====== =========== +Param Description +====== =========== +method string (\*) +====== =========== + +(\*) ``method`` must be one of the available +:ref:`registration/verification methods +`, excluding +``pending_verification``. + +**Success Response (200 OK)**: + +.. code-block:: json + + { + "method": "mobile_phone" + } + .. _radius_batch_user_creation: Batch user creation diff --git a/docs/user/settings.rst b/docs/user/settings.rst index 4df3d1c2..3ee5d040 100644 --- a/docs/user/settings.rst +++ b/docs/user/settings.rst @@ -696,6 +696,9 @@ verification method. The following choices are available by default: - ``mobile_phone``: Mobile phone number :ref:`verification via SMS ` - ``social_login``: :doc:`social login feature ` +- ``pending_verification``: Transitional state used when a user + authenticates to a new organization but has not yet completed + verification for that organization. .. note:: diff --git a/openwisp_radius/admin.py b/openwisp_radius/admin.py index ac095a54..92775786 100644 --- a/openwisp_radius/admin.py +++ b/openwisp_radius/admin.py @@ -7,6 +7,7 @@ from django.contrib.admin.utils import model_ngettext from django.contrib.auth import get_user_model from django.core.exceptions import PermissionDenied +from django.forms.models import BaseInlineFormSet from django.http import HttpResponseRedirect from django.templatetags.static import static from django.urls import reverse @@ -534,11 +535,30 @@ def has_change_permission(self, request, obj=None): return False +class RegisteredUserFormset(BaseInlineFormSet): + def get_unique_error_message(self, unique_check): + # Django inline formsets perform their own uniqueness validation + # (BaseModelFormSet.validate_unique) *before* model-level validation runs. + # Because of this, the custom `violation_error_message` defined on + # `UniqueConstraint` is never surfaced in the admin UI. + # + # Overriding this method allows us to replace Django’s generic + # "Please correct the duplicate data for ." message with a + # domain-specific, user-friendly error that matches our constraint. + if unique_check == ("user", "organization"): + return _( + "A user cannot have more than one registration record in the" + " same organization." + ) + + class RegisteredUserInline(StackedInline): model = RegisteredUser form = AlwaysHasChangedForm + formset = RegisteredUserFormset extra = 0 readonly_fields = ("modified",) + fields = ("organization", "method", "is_verified", "modified") def has_delete_permission(self, request, obj=None): return False @@ -549,12 +569,17 @@ def has_delete_permission(self, request, obj=None): RadiusUserGroupInline, PhoneTokenInline, ] -UserAdmin.list_filter += (RegisteredUserFilter, "registered_user__method") +UserAdmin.list_filter += (RegisteredUserFilter, "registered_users__method") def get_is_verified(self, obj): try: - value = "yes" if obj.registered_user.is_verified else "no" + if not obj.registered_users.exists(): + value = "unknown" + elif obj.registered_users.filter(is_verified=True).exists(): + value = "yes" + else: + value = "no" except Exception: value = "unknown" icon_url = static(f"admin/img/icon-{value}.svg") @@ -564,7 +589,6 @@ def get_is_verified(self, obj): UserAdmin.get_is_verified = get_is_verified UserAdmin.get_is_verified.short_description = _("Verified") UserAdmin.list_display.insert(3, "get_is_verified") -UserAdmin.list_select_related = ("registered_user",) class OrganizationRadiusSettingsInline(admin.StackedInline): diff --git a/openwisp_radius/api/freeradius_views.py b/openwisp_radius/api/freeradius_views.py index b69232e5..25404e83 100644 --- a/openwisp_radius/api/freeradius_views.py +++ b/openwisp_radius/api/freeradius_views.py @@ -7,7 +7,7 @@ from django.contrib.auth.models import AnonymousUser from django.core.cache import cache from django.db import IntegrityError -from django.db.models import Q +from django.db.models import Exists, OuterRef, Q from django.utils.translation import gettext_lazy as _ from django_filters import rest_framework as filters from django_filters.rest_framework import DjangoFilterBackend @@ -57,6 +57,7 @@ RadiusToken = load_model("RadiusToken") RadiusAccounting = load_model("RadiusAccounting") +RegisteredUser = load_model("RegisteredUser") OrganizationRadiusSettings = load_model("OrganizationRadiusSettings") OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") Organization = swapper.load_model("openwisp_users", "Organization") @@ -290,7 +291,7 @@ def get_user(self, request, username, password): """ conditions = self._get_user_query_conditions(request) try: - user = auth_backend.get_users(username).filter(conditions)[0] + user = auth_backend.get_users(username).filter(conditions).distinct()[0] except IndexError: return None # ensure user is member of the authenticated org @@ -405,23 +406,45 @@ def _check_counters(self, data, user, group, group_checks): def _get_user_query_conditions(self, request): is_active = Q(is_active=True) needs_verification = self._needs_identity_verification({"pk": request._auth}) - # if no identity verification enabled for this org, - # just ensure user is active if not needs_verification: return is_active - # if identity verification is enabled - is_verified = Q(registered_user__is_verified=True) + organization_id = request._auth AUTHORIZE_UNVERIFIED = registration.AUTHORIZE_UNVERIFIED - # and no method should authorize unverified users - # ensure user is active AND verified + # Use subqueries to ensure org-specific records take precedence over + # global (organization=NULL) records. + # A JOIN-based filter would allow a user to pass if ANY registered_users + # row matched, causing a bypass when a global verified record coexisted + # with an org-specific unverified record. + # + # Strategy: check if org-specific record exists and satisfies criteria; + # if not, fall back to checking the global record. This matches the + # behavior in api/utils.py:IDVerificationHelper.is_identity_verified_strong. + org_specific = RegisteredUser.objects.filter( + user=OuterRef("pk"), + organization_id=organization_id, + ) + global_only = RegisteredUser.objects.filter( + user=OuterRef("pk"), + organization_id__isnull=True, + ) + + # is_verified: user passes if org-specific record is verified, or if + # no org-specific record exists and the global record is verified. + has_org_verified = Exists(org_specific.filter(is_verified=True)) + has_global_verified = Exists(global_only.filter(is_verified=True)) + no_org_specific = ~Exists(org_specific.values("pk")) + is_verified = has_org_verified | (no_org_specific & has_global_verified) + if not AUTHORIZE_UNVERIFIED: return is_active & is_verified - # in case some methods are allowed to authorize unverified users - # ensure user is active AND - # (user is verified OR user uses one of these methods) - else: - authorize_unverified = Q(registered_user__method__in=AUTHORIZE_UNVERIFIED) - return is_active & (is_verified | authorize_unverified) + + # authorize_unverified: user passes if org-specific record uses a + # special method, or if no org-specific record exists and the global + # record uses a special method. + has_org_special = Exists(org_specific.filter(method__in=AUTHORIZE_UNVERIFIED)) + has_global_special = Exists(global_only.filter(method__in=AUTHORIZE_UNVERIFIED)) + authorize_unverified = has_org_special | (no_org_specific & has_global_special) + return is_active & (is_verified | authorize_unverified) def authenticate_user(self, request, user, password): """ diff --git a/openwisp_radius/api/serializers.py b/openwisp_radius/api/serializers.py index b9b01165..8b878f4a 100644 --- a/openwisp_radius/api/serializers.py +++ b/openwisp_radius/api/serializers.py @@ -36,7 +36,7 @@ from .. import settings as app_settings from ..base.forms import PasswordResetForm from ..counters.exceptions import SkipCheck -from ..registration import REGISTRATION_METHOD_CHOICES +from ..registration import get_registration_choices from ..utils import ( get_group_checks, get_organization_radius_settings, @@ -571,7 +571,7 @@ class RegisterSerializer( 'verification in its "Organization RADIUS Settings."' ), default="", - choices=REGISTRATION_METHOD_CHOICES, + choices=get_registration_choices(), ) def validate_phone_number(self, phone_number): @@ -688,9 +688,11 @@ def save(self, request): # the custom_signup method contains the openwisp specific logic self.custom_signup(request, user) # create a RegisteredUser object for every user that registers through API - RegisteredUser.objects.create( + org = self.context["view"].organization + RegisteredUser.objects.get_or_create( user=user, - method=self.validated_data["method"], + organization=org, + defaults={"method": self.validated_data["method"]}, ) setup_user_email(request, user, []) return user @@ -753,8 +755,55 @@ def save(self): # yet, tha will be done by the phone token validation view # once the phone number has been validated # at this point we flag the user as unverified again - self.user.registered_user.is_verified = False - self.user.registered_user.save() + org = self.context["view"].organization + reg_user, _ = RegisteredUser.get_or_create_for_user_and_org( + user=self.user, + organization=org, + defaults={"is_verified": False, "method": ""}, + ) + reg_user.is_verified = False + reg_user.save() + + +class UpdateRegisteredUserMethodSerializer(ValidatedModelSerializer): + method = serializers.ChoiceField( + choices=get_registration_choices(), + help_text=_( + "The registration method to set for the user. " + "Cannot be 'pending_verification'." + ), + ) + + class Meta: + model = RegisteredUser + fields = ["method"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields["method"].choices = get_registration_choices() + + def validate_method(self, value): + if value == "pending_verification": + raise serializers.ValidationError( + _("'pending_verification' cannot be set as a registration method.") + ) + return value + + def validate(self, attrs): + if self.instance.method != "pending_verification": + raise serializers.ValidationError( + { + "method": _( + "Method can only be updated from pending verification state." + ) + } + ) + return attrs + + def update(self, instance, validated_data): + instance.method = validated_data["method"] + instance.save() + return instance class RadiusUserSerializer(serializers.ModelSerializer): @@ -762,11 +811,8 @@ class RadiusUserSerializer(serializers.ModelSerializer): Used to return information about the logged in user """ - is_verified = serializers.BooleanField(source="registered_user.is_verified") - method = serializers.CharField( - source="registered_user.method", - allow_null=True, - ) + is_verified = serializers.SerializerMethodField() + method = serializers.SerializerMethodField() password_expired = serializers.BooleanField(source="has_password_expired") radius_user_token = serializers.CharField(source="radius_token.key", default=None) @@ -786,3 +832,30 @@ class Meta: "password_expired", "radius_user_token", ] + + def _get_registered_user(self, obj): + if not hasattr(self, "_registered_user_cache"): + self._registered_user_cache = {} + if obj.pk not in self._registered_user_cache: + view = self.context.get("view") + organization = getattr(view, "organization", None) + reg_user = None + # We iterate over .all() instead of using .filter() because callers + # of this serializer (e.g. validate_auth_token) prefetch + # "registered_users" via prefetch_related. Using .all() hits the + # in-memory prefetch cache (0 DB queries), whereas .filter() would + # bypass the cache and issue a new query every time. + for ru in obj.registered_users.all(): + if organization and ru.organization_id == organization.pk: + reg_user = ru + break + self._registered_user_cache[obj.pk] = reg_user + return self._registered_user_cache[obj.pk] + + def get_is_verified(self, obj): + reg_user = self._get_registered_user(obj) + return reg_user.is_verified if reg_user else None + + def get_method(self, obj): + reg_user = self._get_registered_user(obj) + return reg_user.method if reg_user else None diff --git a/openwisp_radius/api/urls.py b/openwisp_radius/api/urls.py index 88d02572..3ca6407d 100644 --- a/openwisp_radius/api/urls.py +++ b/openwisp_radius/api/urls.py @@ -77,6 +77,11 @@ def get_api_urls(api_views=None): api_views.change_phone_number, name="phone_number_change", ), + path( + "radius/organization//account/registration-method/", + api_views.update_registered_user_registration_method, + name="update_registered_user_registration_method", + ), path("radius/batch/", api_views.batch, name="batch"), path( "radius/organization//batch//pdf/", diff --git a/openwisp_radius/api/utils.py b/openwisp_radius/api/utils.py index 6d742c57..447ca7c5 100644 --- a/openwisp_radius/api/utils.py +++ b/openwisp_radius/api/utils.py @@ -9,6 +9,7 @@ Organization = load_model("openwisp_users", "Organization") OrganizationRadiusSettings = load_model("openwisp_radius", "OrganizationRadiusSettings") +RegisteredUser = load_model("openwisp_radius", "RegisteredUser") class ErrorDictMixin(object): @@ -30,8 +31,14 @@ def _needs_identity_verification(self, organization_filter_kwargs={}, org=None): except ObjectDoesNotExist: return app_settings.NEEDS_IDENTITY_VERIFICATION - def is_identity_verified_strong(self, user): - try: - return user.registered_user.is_identity_verified_strong - except ObjectDoesNotExist: + def is_identity_verified_strong(self, user, organization=None): + reg_user = None + # We use all() to utilize the prefetch cache, otherwise + # it would cause an additional query to fetch the registered user + for ru in user.registered_users.all(): + if organization and ru.organization_id == organization.pk: + reg_user = ru + break + if reg_user is None: return False + return reg_user.is_identity_verified_strong diff --git a/openwisp_radius/api/views.py b/openwisp_radius/api/views.py index 07c4bd37..9d86add9 100644 --- a/openwisp_radius/api/views.py +++ b/openwisp_radius/api/views.py @@ -11,8 +11,8 @@ from django.contrib.sites.shortcuts import get_current_site from django.core.cache import cache from django.core.exceptions import ValidationError +from django.db import IntegrityError, transaction from django.db.models import Q -from django.db.utils import IntegrityError from django.http import Http404, HttpResponse from django.utils import timezone from django.utils.decorators import method_decorator @@ -35,6 +35,7 @@ ListCreateAPIView, RetrieveAPIView, RetrieveUpdateDestroyAPIView, + get_object_or_404, ) from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import ( @@ -74,6 +75,7 @@ RadiusBatchSerializer, RadiusGroupSerializer, RadiusUserGroupSerializer, + UpdateRegisteredUserMethodSerializer, UserRadiusUsageSerializer, ValidatePhoneTokenSerializer, ) @@ -92,6 +94,7 @@ Organization = swapper.load_model("openwisp_users", "Organization") OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") PhoneToken = load_model("PhoneToken") +RegisteredUser = load_model("RegisteredUser") RadiusAccounting = load_model("RadiusAccounting") RadiusToken = load_model("RadiusToken") RadiusBatch = load_model("RadiusBatch") @@ -315,13 +318,13 @@ def post(self, request, *args, **kwargs): self.update_user_details(user) context = {"view": self, "request": request} serializer = self.serializer_class(instance=token, context=context) - response = RadiusUserSerializer(user).data + response = RadiusUserSerializer(user, context=context).data response.update(serializer.data) status_code = 200 if user.is_active else 401 # If identity verification is required, check if user is verified if self._needs_identity_verification( {"slug": kwargs["slug"]} - ) and not self.is_identity_verified_strong(user): + ) and not self.is_identity_verified_strong(user, self.organization): status_code = 401 return Response(response, status=status_code) @@ -335,24 +338,23 @@ def validate_membership(self, user): if get_organization_radius_settings( self.organization, "registration_enabled" ): - if self._needs_identity_verification( - org=self.organization - ) and not self.is_identity_verified_strong(user): - raise PermissionDenied try: - org_user = OrganizationUser( - user=user, organization=self.organization - ) - org_user.full_clean() - org_user.save() + with transaction.atomic(): + OrganizationUser.objects.get_or_create( + user=user, organization=self.organization + ) + RegisteredUser.objects.get_or_create( + user=user, + organization=self.organization, + defaults={"method": "pending_verification"}, + ) except ValidationError as error: raise serializers.ValidationError( {"non_field_errors": error.message_dict.pop("__all__")} ) else: message = _( - "{organization} does not allow self registration " - "of new accounts." + "{organization} does not allow self registration of new accounts." ).format(organization=self.organization.name) raise PermissionDenied(message) @@ -383,9 +385,15 @@ def post(self, request, *args, **kwargs): response = {"response_code": "BLANK_OR_INVALID_TOKEN"} if request_token: try: - token = UserToken.objects.select_related( - "user", "user__registered_user" - ).get(key=request_token) + token = ( + UserToken.objects.select_related( + "user", + ) + .prefetch_related( + "user__registered_users", + ) + .get(key=request_token) + ) except UserToken.DoesNotExist: pass else: @@ -395,7 +403,7 @@ def post(self, request, *args, **kwargs): ) # user may be in the process of changing the phone number # in that case show the new phone number (which is not verified yet) - if not self.is_identity_verified_strong(user): + if not self.is_identity_verified_strong(user, self.organization): phone_token = ( PhoneToken.objects.filter(user=user) .order_by("-created") @@ -404,8 +412,8 @@ def post(self, request, *args, **kwargs): user.phone_number = ( phone_token.phone_number if phone_token else user.phone_number ) - response = RadiusUserSerializer(user).data context = {"view": self, "request": request} + response = RadiusUserSerializer(user, context=context).data token_data = rest_auth_settings.api_settings.TOKEN_SERIALIZER( token, context=context ).data @@ -638,7 +646,7 @@ def create(self, *args, **kwargs): try: phone_token.full_clean() if kwargs.get("enforce_unverified", True): - phone_token._validate_already_verified() + phone_token._validate_already_verified(organization=self.organization) except ValidationError as e: error_dict = self._get_error_dict(e) raise serializers.ValidationError(error_dict) @@ -747,15 +755,24 @@ def post(self, request, *args, **kwargs): _("No verification code found in the system for this user.") ) try: - is_valid = phone_token.is_valid(serializer.data["code"]) + is_valid = phone_token.is_valid( + serializer.data["code"], organization=self.organization + ) except PhoneTokenException as e: return self._error_response(str(e)) if not is_valid: return self._error_response(_("Invalid code.")) else: - user.registered_user.is_verified = True - user.registered_user.method = "mobile_phone" - user.is_active = True + reg_user, __ = RegisteredUser.get_or_create_for_user_and_org( + user=user, + organization=self.organization, + defaults={ + "is_verified": True, + "method": "mobile_phone", + }, + ) + reg_user.is_verified = True + reg_user.method = "mobile_phone" # Update username if phone_number is used as username if user.username == user.phone_number: user.username = phone_token.phone_number @@ -763,7 +780,7 @@ def post(self, request, *args, **kwargs): # we can write it to the user field user.phone_number = phone_token.phone_number user.save() - user.registered_user.save() + reg_user.save() # delete any radius token cache key if present cache.delete(f"rt-{phone_token.phone_number}") return Response(None, status=200) @@ -813,6 +830,49 @@ def create_phone_token(self, *args, **kwargs): change_phone_number = ChangePhoneNumberView.as_view() +class UpdateRegisteredUserMethodView(DispatchOrgMixin, GenericAPIView): + authentication_classes = (BearerAuthentication, SessionAuthentication) + permission_classes = (IsAuthenticated,) + serializer_class = UpdateRegisteredUserMethodSerializer + + @swagger_auto_schema( + operation_description=(""" + **Requires the user auth token (Bearer Token).** + Allows users to update their registered user method for an organization. + The method can only be updated when it is currently + set to 'pending_verification'. + Once updated, it cannot be changed again via this endpoint. + """), + responses={ + 200: "Method updated successfully", + 400: ( + "Invalid request (method is not 'pending_verification' " + "or invalid method value)" + ), + 401: "Authentication required", + 404: "RegisteredUser not found for this user and organization", + }, + ) + def post(self, request, slug): + user = request.user + reg_user = get_object_or_404( + RegisteredUser, + user_id=user.pk, + organization=self.organization, + ) + serializer = self.get_serializer( + instance=reg_user, data=request.data, partial=True + ) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response( + {"method": serializer.instance.method}, status=status.HTTP_200_OK + ) + + +update_registered_user_registration_method = UpdateRegisteredUserMethodView.as_view() + + class RadiusAccountingFilter(AccountingFilter): called_station_id = CharFilter( field_name="called_station_id", method="filter_mac_address" diff --git a/openwisp_radius/base/admin_filters.py b/openwisp_radius/base/admin_filters.py index 5fd73991..d8c1f7d4 100644 --- a/openwisp_radius/base/admin_filters.py +++ b/openwisp_radius/base/admin_filters.py @@ -15,7 +15,9 @@ def lookups(self, request, model_admin): def queryset(self, request, queryset): if self.value() == "unknown": - return queryset.filter(registered_user__isnull=True) + return queryset.filter(registered_users__isnull=True) elif self.value(): - return queryset.filter(registered_user__is_verified=self.value() == "true") + return queryset.filter( + registered_users__is_verified=self.value() == "true" + ).distinct() return queryset diff --git a/openwisp_radius/base/models.py b/openwisp_radius/base/models.py index 808b8640..300fcc87 100644 --- a/openwisp_radius/base/models.py +++ b/openwisp_radius/base/models.py @@ -15,7 +15,7 @@ from django.conf import settings from django.contrib.auth import get_user_model from django.core.cache import cache -from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.exceptions import ValidationError from django.core.mail import send_mail from django.db import models, transaction from django.db.models import ProtectedError, Q @@ -175,6 +175,9 @@ _LOGIN_URL_HELP_TEXT = _("Enter the URL where users can log in to the wifi service") _STATUS_URL_HELP_TEXT = _("Enter the URL where users can log out from the wifi service") _PASSWORD_RESET_URL_HELP_TEXT = _("Enter the URL where users can reset their password") +_REGISTRATION_UNIQUE_VALIDATION_ERROR = _( + "A user cannot have more than one registration record in the same organization." +) OPTIONAL_SETTINGS = app_settings.OPTIONAL_REGISTRATION_FIELDS @@ -1058,10 +1061,22 @@ def save_user(self, user): OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") RegisteredUser = swapper.load_model("openwisp_radius", "RegisteredUser") user.save() - registered_user = RegisteredUser(user=user, method="manual") - if self.organization.radius_settings.needs_identity_verification: + radius_settings = self.organization.radius_settings + registered_user, created = RegisteredUser.get_or_create_for_user_and_org( + user=user, + organization=self.organization, + defaults={ + "method": "manual", + "is_verified": radius_settings.needs_identity_verification, + }, + ) + if ( + not created + and self.organization.radius_settings.needs_identity_verification + ): + registered_user.method = "manual" registered_user.is_verified = True - registered_user.save() + registered_user.save() self.users.add(user) if OrganizationUser.objects.filter( user=user, organization=self.organization @@ -1559,28 +1574,33 @@ def send_token(self): ) sms_message.send(meta_data=org_radius_settings.sms_meta_data) - def is_valid(self, token): + def is_valid(self, token, organization=None): self.attempts += 1 try: - self.verified = self.__check(token) + self.verified = self.__check(token, organization=organization) except exceptions.PhoneTokenException as phone_error: self.save() raise phone_error self.save() return self.verified - def _validate_already_verified(self): - try: - if self.user.registered_user.is_verified: - logger.warning(f"User {self.user.pk} is already verified") - raise exceptions.UserAlreadyVerified( - _("This user has been already verified.") - ) - except ObjectDoesNotExist: - pass + def _validate_already_verified(self, organization=None): + RegisteredUser = swapper.load_model("openwisp_radius", "RegisteredUser") + if organization is not None: + reg_user = RegisteredUser.get_for_user_and_org(self.user, organization) + is_verified = reg_user is not None and reg_user.is_verified + else: + is_verified = RegisteredUser.objects.filter( + user=self.user, is_verified=True + ).exists() + if is_verified: + logger.warning(f"User {self.user.pk} is already verified") + raise exceptions.UserAlreadyVerified( + _("This user has been already verified.") + ) - def __check(self, token): - self._validate_already_verified() + def __check(self, token, organization=None): + self._validate_already_verified(organization=organization) if self.attempts > app_settings.SMS_TOKEN_MAX_ATTEMPTS: logger.warning( f"User {self.user} has reached the max " @@ -1602,12 +1622,18 @@ def __check(self, token): return token == self.token -class AbstractRegisteredUser(models.Model): - user = models.OneToOneField( +class AbstractRegisteredUser(UUIDModel): + user = models.ForeignKey( settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name="registered_user", - primary_key=True, + related_name="registered_users", + ) + organization = models.ForeignKey( + swapper.get_model_name("openwisp_users", "Organization"), + on_delete=models.CASCADE, + related_name="registered_users", + verbose_name=_("organization"), + help_text=_("Organization associated with this registered user entry."), ) method = models.CharField( _("registration method"), @@ -1639,7 +1665,7 @@ class AbstractRegisteredUser(models.Model): default=False, ) modified = AutoLastModifiedField(_("Last verification change"), editable=True) - _weak_verification_methods = {"", "email"} + _weak_verification_methods = {"", "email", "pending_verification"} @property def is_identity_verified_strong(self): @@ -1649,19 +1675,44 @@ class Meta: abstract = True verbose_name = _("Registration Information") verbose_name_plural = verbose_name + constraints = [ + models.UniqueConstraint( + fields=["user", "organization"], + name="unique_registered_user_per_org", + violation_error_message=_REGISTRATION_UNIQUE_VALIDATION_ERROR, + ), + ] + + @classmethod + def get_or_create_for_user_and_org(cls, user, organization, defaults=None): + defaults = defaults or {} + return cls.objects.get_or_create( + user=user, organization=organization, defaults=defaults + ) + + @classmethod + def get_for_user_and_org(cls, user, organization): + try: + return cls.objects.get(user=user, organization=organization) + except cls.DoesNotExist: + return None @classmethod def unverify_inactive_users(cls): if not app_settings.UNVERIFY_INACTIVE_USERS: return - # Exclude users who have unspecified, manual, or email + # Exclude users who have unspecified, manual, email, or pending_verification # registration method because such users don't have an option # to re-verify. See https://github.com/openwisp/openwisp-radius/issues/517 - cls.objects.exclude(method__in=["", "manual", "email"]).filter( + cls.objects.exclude( + method__in=["", "manual", "email", "pending_verification"] + ).filter( user__is_staff=False, user__last_login__lt=timezone.now() - timedelta(days=app_settings.UNVERIFY_INACTIVE_USERS), - ).update(is_verified=False) + ).update( + is_verified=False + ) @classmethod def delete_inactive_users(cls): diff --git a/openwisp_radius/integrations/monitoring/tasks.py b/openwisp_radius/integrations/monitoring/tasks.py index f251edc3..e46affd3 100644 --- a/openwisp_radius/integrations/monitoring/tasks.py +++ b/openwisp_radius/integrations/monitoring/tasks.py @@ -75,9 +75,9 @@ def _write_user_signup_metric_for_all(metric_key): ) ) # Some manually created users, like superuser may not have a - # RegisteredUser object. We would could them with "unspecified" method + # RegisteredUser object. We would count them with "unspecified" method users_without_registereduser_query = User.objects.filter( - registered_user__isnull=True + registered_users__isnull=True ) if metric_key == "user_signups": users_without_registereduser_query = users_without_registereduser_query.filter( @@ -97,6 +97,8 @@ def _write_user_signup_metric_for_all(metric_key): for method, count in total_registered_users.items(): method = clean_registration_method(method) + if method is None: + continue metric = get_metric_func(organization_id="__all__", registration_method=method) metric_data.append((metric, {"value": count})) Metric.batch_write(metric_data) @@ -116,6 +118,7 @@ def _write_user_signup_metrics_for_orgs(metric_key): # count of users who registered with that organization and method. registered_users_query = RegisteredUser.objects.exclude( user__openwisp_users_organizationuser__created__gt=end_time, + method="pending_verification", ) if metric_key == "user_signups": @@ -131,7 +134,7 @@ def _write_user_signup_metrics_for_orgs(metric_key): # which do not have related RegisteredUser object. Add the count # of such users with the "unspecified" method. users_without_registereduser_query = OrganizationUser.objects.filter( - user__registered_user__isnull=True + user__registered_users__isnull=True ) if metric_key == "user_signups": users_without_registereduser_query = users_without_registereduser_query.filter( @@ -182,18 +185,22 @@ def post_save_radiusaccounting( called_station_id, time=None, ): - try: - registration_method = ( - RegisteredUser.objects.only("method").get(user__username=username).method - ) - except RegisteredUser.DoesNotExist: + registration_method = ( + RegisteredUser.objects.only("method") + .filter(user__username=username, organization_id=organization_id) + .first() + ) + if registration_method is None: logger.info( f'RegisteredUser object not found for "{username}".' ' The metric will be written with "unspecified" registration method!' ) registration_method = "unspecified" else: - registration_method = clean_registration_method(registration_method) + registration_method = registration_method.method + registration_method = clean_registration_method(registration_method) + if registration_method is None: + registration_method = "unspecified" device_lookup = Q(mac_address__iexact=called_station_id.replace("-", ":")) extra_tags = { "method": registration_method, diff --git a/openwisp_radius/integrations/monitoring/tests/test_metrics.py b/openwisp_radius/integrations/monitoring/tests/test_metrics.py index 8a3f6dd7..2cc598ee 100644 --- a/openwisp_radius/integrations/monitoring/tests/test_metrics.py +++ b/openwisp_radius/integrations/monitoring/tests/test_metrics.py @@ -21,8 +21,29 @@ @tag("radius_monitoring") class TestMetrics(CreateDeviceMonitoringMixin, BaseTransactionTestCase): + def _read_chart(self, chart, **kwargs): + return chart.read( + additional_query_kwargs={"additional_params": kwargs}, + ) + + def _assert_pending_verification_excluded(self, points): + pending_verification_traces = [ + trace_points + for trace_name, trace_points in points["traces"] + if trace_name == "pending_verification" + ] + self.assertEqual(pending_verification_traces, []) + self.assertNotIn( + "pending_verification", + points.get("summary", {}), + ) + def _create_registered_user(self, **kwargs): - options = {"is_verified": False, "method": "mobile_phone"} + options = { + "is_verified": False, + "method": "mobile_phone", + "organization": self.default_org, + } options.update(**kwargs) if "user" not in options: options["user"] = self._create_user() @@ -238,6 +259,7 @@ def test_post_save_radius_accounting_device_not_found(self, mocked_logger): convert_called_station_id feature, but it is not configured properly leaving all called_station_id unconverted. """ + cache.clear() user = self._create_user() reg_user = self._create_registered_user(user=user) options = _RADACCT.copy() @@ -254,7 +276,6 @@ def test_post_save_radius_accounting_device_not_found(self, mocked_logger): options["stop_time"] = options["start_time"] # Remove calls for user registration from mocked logger mocked_logger.reset_mock() - self._create_radius_accounting(**options) self.assertEqual( self.metric_model.objects.filter( @@ -368,14 +389,102 @@ def test_post_save_radius_accounting_registereduser_not_found(self, mocked_logge ' The metric will be written with "unspecified" registration method!' ) + def test_post_save_radiusaccounting_pending_verification(self): + """ + Test that when a user has a RegisteredUser with method="pending_verification", + the metric is written with "unspecified" instead of None. + """ + user = self._create_user() + self._create_registered_user(user=user, method="pending_verification") + device = self._create_device() + device_loc = self._create_device_location( + content_object=device, + location=self._create_location(organization=device.organization), + ) + options = _RADACCT.copy() + options.update( + { + "unique_id": "pending_001", + "username": user.username, + "called_station_id": device.mac_address.replace("-", ":").upper(), + "calling_station_id": "00:00:00:00:00:00", + "input_octets": "8000000000", + "output_octets": "9000000000", + } + ) + options["stop_time"] = options["start_time"] + self._create_radius_accounting(**options) + self.assertEqual( + self.metric_model.objects.filter( + configuration="radius_acc", + name="RADIUS Accounting", + key="radius_acc", + object_id=str(device.id), + content_type=ContentType.objects.get_for_model(self.device_model), + extra_tags={ + "called_station_id": device.mac_address, + "calling_station_id": sha1_hash("00:00:00:00:00:00"), + "location_id": str(device_loc.location.id), + "method": "unspecified", + "organization_id": str(self.default_org.id), + }, + ).count(), + 1, + ) + + def test_post_save_radiusaccounting_does_not_fallback_to_other_org( + self, + ): + """ + Test that a RegisteredUser from another organization is not used + when accounting is written for the current organization. + """ + user = self._create_user() + self._create_registered_user( + user=user, organization=self.default_org, method="mobile_phone" + ) + org2 = self._create_org(name="metrics-org-2", slug="metrics-org-2") + self._create_org_user(user=user, organization=org2) + self._create_registered_user(user=user, organization=org2, method="email") + device = self._create_device() + device_loc = self._create_device_location( + content_object=device, + location=self._create_location(organization=device.organization), + ) + options = _RADACCT.copy() + options.update( + { + "unique_id": "org_spec_001", + "username": user.username, + "called_station_id": device.mac_address.replace("-", ":").upper(), + "calling_station_id": "00:00:00:00:00:00", + "input_octets": "8000000000", + "output_octets": "9000000000", + } + ) + options["stop_time"] = options["start_time"] + self._create_radius_accounting(**options) + self.assertEqual( + self.metric_model.objects.filter( + configuration="radius_acc", + name="RADIUS Accounting", + key="radius_acc", + object_id=str(device.id), + content_type=ContentType.objects.get_for_model(self.device_model), + extra_tags={ + "called_station_id": device.mac_address, + "calling_station_id": sha1_hash("00:00:00:00:00:00"), + "location_id": str(device_loc.location.id), + "method": "mobile_phone", + "organization_id": str(self.default_org.id), + }, + ).count(), + 1, + ) + def test_write_user_registration_metrics(self): from ..tasks import write_user_registration_metrics - def _read_chart(chart, **kwargs): - return chart.read( - additional_query_kwargs={"additional_params": kwargs}, - ) - # The TransactionTestCase truncates all the data after each test. # The general metrics and charts which are created by migrations # get deleted after each test. Therefore, we create them again here. @@ -393,21 +502,25 @@ def _read_chart(chart, **kwargs): write_user_registration_metrics.delay() user_signup_chart = user_signup_metric.chart_set.first() - all_points = _read_chart(user_signup_chart, organization_id=["__all__"]) + all_points = self._read_chart( + user_signup_chart, organization_id=["__all__"] + ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart(user_signup_chart, organization_id=[str(org.id)]) + org_points = self._read_chart( + user_signup_chart, organization_id=[str(org.id)] + ) self.assertEqual(len(org_points["traces"]), 0) total_user_signup_chart = total_user_signup_metric.chart_set.first() - all_points = _read_chart( + all_points = self._read_chart( total_user_signup_chart, organization_id=["__all__"] ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=[str(org.id)] ) self.assertEqual(len(org_points["traces"]), 0) @@ -421,23 +534,27 @@ def _read_chart(chart, **kwargs): write_user_registration_metrics.delay() user_signup_chart = user_signup_metric.chart_set.first() - all_points = _read_chart(user_signup_chart, organization_id=["__all__"]) + all_points = self._read_chart( + user_signup_chart, organization_id=["__all__"] + ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart(user_signup_chart, organization_id=[str(org.id)]) + org_points = self._read_chart( + user_signup_chart, organization_id=[str(org.id)] + ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) total_user_signup_chart = total_user_signup_metric.chart_set.first() - all_points = _read_chart( + all_points = self._read_chart( total_user_signup_chart, organization_id=["__all__"] ) self.assertEqual(all_points["traces"][0][0], "unspecified") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual(all_points["summary"], {"unspecified": 1}) - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=[str(org.id)] ) self.assertEqual(all_points["traces"][0][0], "unspecified") @@ -454,13 +571,17 @@ def _read_chart(chart, **kwargs): write_user_registration_metrics.delay() user_signup_chart = user_signup_metric.chart_set.first() - all_points = _read_chart(user_signup_chart, organization_id=["__all__"]) + all_points = self._read_chart( + user_signup_chart, organization_id=["__all__"] + ) self.assertEqual(all_points["traces"][0][0], "mobile_phone") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual( all_points["summary"], {"mobile_phone": 1, "unspecified": 0} ) - org_points = _read_chart(user_signup_chart, organization_id=[str(org.id)]) + org_points = self._read_chart( + user_signup_chart, organization_id=[str(org.id)] + ) self.assertEqual(all_points["traces"][0][0], "mobile_phone") self.assertEqual(all_points["traces"][0][1][-1], 1) self.assertEqual( @@ -468,7 +589,7 @@ def _read_chart(chart, **kwargs): ) total_user_signup_chart = total_user_signup_metric.chart_set.first() - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=["__all__"] ) self.assertEqual(org_points["traces"][0][0], "mobile_phone") @@ -476,7 +597,7 @@ def _read_chart(chart, **kwargs): self.assertEqual( org_points["summary"], {"mobile_phone": 1, "unspecified": 0} ) - org_points = _read_chart( + org_points = self._read_chart( total_user_signup_chart, organization_id=[str(org.id)] ) self.assertEqual(all_points["traces"][0][0], "mobile_phone") @@ -484,3 +605,33 @@ def _read_chart(chart, **kwargs): self.assertEqual( all_points["summary"], {"mobile_phone": 1, "unspecified": 0} ) + + def test_pending_verification_excluded_from_metrics(self): + from ..tasks import write_user_registration_metrics + + cache.clear() + create_general_metrics(None, None) + org = self._create_org(name="pending_verification_test_org") + user_signup_metric = self.metric_model.objects.get(key="user_signups") + total_user_signup_metric = self.metric_model.objects.get(key="tot_user_signups") + user = self._create_org_user(organization=org).user + self._create_registered_user( + user=user, organization=org, method="pending_verification" + ) + write_user_registration_metrics.delay() + + user_signup_chart = user_signup_metric.chart_set.first() + org_points = self._read_chart(user_signup_chart, organization_id=[str(org.pk)]) + all_points = self._read_chart(user_signup_chart, organization_id=["__all__"]) + self._assert_pending_verification_excluded(org_points) + self._assert_pending_verification_excluded(all_points) + + total_user_signup_chart = total_user_signup_metric.chart_set.first() + org_points = self._read_chart( + total_user_signup_chart, organization_id=[str(org.pk)] + ) + all_points = self._read_chart( + total_user_signup_chart, organization_id=["__all__"] + ) + self._assert_pending_verification_excluded(org_points) + self._assert_pending_verification_excluded(all_points) diff --git a/openwisp_radius/integrations/monitoring/utils.py b/openwisp_radius/integrations/monitoring/utils.py index 6fdb8eee..2528f479 100644 --- a/openwisp_radius/integrations/monitoring/utils.py +++ b/openwisp_radius/integrations/monitoring/utils.py @@ -51,4 +51,6 @@ def sha1_hash(input_string): def clean_registration_method(method): if method == "": method = "unspecified" + elif method == "pending_verification": + return None return method diff --git a/openwisp_radius/management/commands/base/delete_unverified_users.py b/openwisp_radius/management/commands/base/delete_unverified_users.py index ebefc038..eceb2ce7 100644 --- a/openwisp_radius/management/commands/base/delete_unverified_users.py +++ b/openwisp_radius/management/commands/base/delete_unverified_users.py @@ -2,6 +2,7 @@ from django.contrib.auth import get_user_model from django.core.management import BaseCommand +from django.db.models import Count, Q from django.utils.timezone import now from openwisp_radius.utils import load_model @@ -33,14 +34,23 @@ def handle(self, *args, **options): if exclude_methods: exclude_methods = exclude_methods.split(",") - qs = User.objects.filter( - date_joined__lt=days, - registered_user__isnull=False, - registered_user__is_verified=False, - is_staff=False, + qs = ( + User.objects.filter( + date_joined__lt=days, + registered_users__isnull=False, + is_staff=False, + ) + .annotate( + num_verified=Count( + "registered_users", + filter=Q(registered_users__is_verified=True), + ) + ) + .filter(num_verified=0) + .distinct() ) if exclude_methods: - qs = qs.exclude(registered_user__method__in=exclude_methods) + qs = qs.exclude(registered_users__method__in=exclude_methods) for user in qs.iterator(): if not RadiusAccounting.objects.filter(username=user.username).exists(): diff --git a/openwisp_radius/migrations/0043_registereduser_add_uuid.py b/openwisp_radius/migrations/0043_registereduser_add_uuid.py new file mode 100644 index 00000000..3cff024b --- /dev/null +++ b/openwisp_radius/migrations/0043_registereduser_add_uuid.py @@ -0,0 +1,155 @@ +import uuid + +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields +import swapper +from django.conf import settings +from django.db import migrations, models + +from . import copy_registered_users_ctcr_forward, copy_registered_users_ctcr_reverse + + +def copy_registered_users_forward(apps, schema_editor): + copy_registered_users_ctcr_forward(apps, schema_editor, app_label="openwisp_radius") + + +def copy_registered_users_reverse(apps, schema_editor): + copy_registered_users_ctcr_reverse(apps, schema_editor, app_label="openwisp_radius") + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + swapper.dependency("openwisp_users", "Organization"), + ("openwisp_radius", "0042_set_existing_batches_completed"), + ] + + operations = [ + migrations.SeparateDatabaseAndState( + state_operations=[ + migrations.AddField( + model_name="registereduser", + name="id", + field=models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + migrations.AlterField( + model_name="registereduser", + name="user", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="registered_users", + to=settings.AUTH_USER_MODEL, + ), + ), + migrations.AddField( + model_name="registereduser", + name="organization", + field=models.ForeignKey( + blank=True, + help_text=( + "Organization associated with this registered user entry." + ), + null=True, + related_name="registered_users", + on_delete=django.db.models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + ], + database_operations=[ + migrations.CreateModel( + name="RegisteredUserNew", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "method", + models.CharField( + blank=True, + default="", + help_text=( + "users can sign up in different ways, some " + "methods are valid as indirect identity " + "verification (eg: mobile phone SIM card in " + "most countries)" + ), + max_length=64, + verbose_name="registration method", + ), + ), + ( + "is_verified", + models.BooleanField( + default=False, + help_text=( + "whether the user has completed any identity " + "verification process sucessfully" + ), + verbose_name="verified", + ), + ), + ( + "modified", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="Last verification change", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "organization", + models.ForeignKey( + blank=True, + help_text=( + "Organization associated with this registered user" + " entry." + ), + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=swapper.get_model_name( + "openwisp_users", "Organization" + ), + verbose_name="organization", + ), + ), + ], + options={ + "verbose_name": "Registration Information", + "verbose_name_plural": "Registration Information", + }, + ), + migrations.RunPython( + copy_registered_users_forward, + copy_registered_users_reverse, + ), + migrations.DeleteModel(name="RegisteredUser"), + migrations.RenameModel( + old_name="RegisteredUserNew", + new_name="RegisteredUser", + ), + ], + ), + ] diff --git a/openwisp_radius/migrations/0044_registered_user_multitenant_data.py b/openwisp_radius/migrations/0044_registered_user_multitenant_data.py new file mode 100644 index 00000000..da104a51 --- /dev/null +++ b/openwisp_radius/migrations/0044_registered_user_multitenant_data.py @@ -0,0 +1,31 @@ +from django.db import migrations + +from . import ( + migrate_registered_users_multitenant_forward, + migrate_registered_users_multitenant_reverse, +) + + +def migrate_registered_users_forward(apps, schema_editor): + migrate_registered_users_multitenant_forward( + apps, schema_editor, app_label="openwisp_radius" + ) + + +def migrate_registered_users_reverse(apps, schema_editor): + migrate_registered_users_multitenant_reverse( + apps, schema_editor, app_label="openwisp_radius" + ) + + +class Migration(migrations.Migration): + dependencies = [ + ("openwisp_radius", "0043_registereduser_add_uuid"), + ] + + operations = [ + migrations.RunPython( + migrate_registered_users_forward, + migrate_registered_users_reverse, + ), + ] diff --git a/openwisp_radius/migrations/0045_registered_user_multitenant_constraints.py b/openwisp_radius/migrations/0045_registered_user_multitenant_constraints.py new file mode 100644 index 00000000..6330406f --- /dev/null +++ b/openwisp_radius/migrations/0045_registered_user_multitenant_constraints.py @@ -0,0 +1,32 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("openwisp_radius", "0044_registered_user_multitenant_data"), + ] + + operations = [ + migrations.AlterField( + model_name="registereduser", + name="organization", + field=models.ForeignKey( + help_text="Organization associated with this registered user entry.", + on_delete=models.deletion.CASCADE, + related_name="registered_users", + to="openwisp_users.organization", + verbose_name="organization", + ), + ), + migrations.AddConstraint( + model_name="registereduser", + constraint=models.UniqueConstraint( + fields=["user", "organization"], + name="unique_registered_user_per_org", + violation_error_message=( + "A user cannot have more than one registration record in the same" + " organization." + ), + ), + ), + ] diff --git a/openwisp_radius/migrations/__init__.py b/openwisp_radius/migrations/__init__.py index 45c9abf2..7e414de8 100644 --- a/openwisp_radius/migrations/__init__.py +++ b/openwisp_radius/migrations/__init__.py @@ -1,12 +1,16 @@ import uuid +from collections import defaultdict import swapper from django.conf import settings from django.contrib.auth.management import create_permissions from django.contrib.auth.models import Permission +from django.db.models import Case, IntegerField, Value, When from ..utils import create_default_groups +BATCH_SIZE = 1000 + def get_swapped_model(apps, app_name, model_name): model_path = swapper.get_model_name(app_name, model_name) @@ -14,6 +18,227 @@ def get_swapped_model(apps, app_name, model_name): return apps.get_model(app, model) +def _batched_iterator(iterator, batch_size=BATCH_SIZE): + batch = [] + for item in iterator: + batch.append(item) + if len(batch) >= batch_size: + yield batch + batch = [] + if batch: + yield batch + + +def _flush_bulk_create(model, objects, batch_size=BATCH_SIZE): + if objects: + model.objects.bulk_create(objects, batch_size=batch_size) + objects.clear() + + +def _flush_bulk_update(model, objects, fields, batch_size=BATCH_SIZE): + if objects: + model.objects.bulk_update(objects, fields=fields, batch_size=batch_size) + objects.clear() + + +def _registered_user_extra_kwargs(registered_user, extra_fields=()): + return { + field_name: getattr(registered_user, field_name) for field_name in extra_fields + } + + +def _registered_user_method_priority_case(): + # Strong methods (anything that is not '' or 'email') must rank above the + # weak fallbacks so rollback restores the strongest verification state. + return Case( + When(method="", then=Value(0)), + When(method="email", then=Value(1)), + default=Value(2), + output_field=IntegerField(), + ) + + +def _registered_user_method_priority(registered_user): + if registered_user.method == "": + return 0 + if registered_user.method == "email": + return 1 + return 2 + + +def _registered_user_strength(registered_user): + return ( + int(registered_user.is_verified), + _registered_user_method_priority(registered_user), + registered_user.modified, + ) + + +def copy_registered_users_ctcr_forward( + apps, + schema_editor, + app_label, + new_model_name="RegisteredUserNew", + extra_fields=(), +): + RegisteredUser = apps.get_model(app_label, "RegisteredUser") + RegisteredUserNew = apps.get_model(app_label, new_model_name) + if RegisteredUser._meta.swapped: + return + + new_objects = [] + queryset = RegisteredUser.objects.order_by("user_id") + for registered_user in queryset.iterator(chunk_size=BATCH_SIZE): + copied = RegisteredUserNew( + id=uuid.uuid4(), + user_id=registered_user.user_id, + organization=None, + method=registered_user.method, + is_verified=registered_user.is_verified, + **_registered_user_extra_kwargs(registered_user, extra_fields), + ) + copied.modified = registered_user.modified + new_objects.append(copied) + if len(new_objects) >= BATCH_SIZE: + _flush_bulk_create(RegisteredUserNew, new_objects) + _flush_bulk_create(RegisteredUserNew, new_objects) + + +def copy_registered_users_ctcr_reverse( + apps, + schema_editor, + app_label, + new_model_name="RegisteredUserNew", + extra_fields=(), +): + RegisteredUser = apps.get_model(app_label, "RegisteredUser") + RegisteredUserNew = apps.get_model(app_label, new_model_name) + if RegisteredUser._meta.swapped: + return + + restored_objects = [] + previous_user_id = None + # Annotate each row with an explicit verification priority so that stronger + # methods (anything that is not '' or 'email') sort before weaker ones. + # Lexical ordering of 'method' would place '' first, picking the weakest. + method_priority = _registered_user_method_priority_case() + queryset = RegisteredUserNew.objects.annotate( + method_priority=method_priority + ).order_by("user_id", "-is_verified", "-method_priority", "-modified") + for registered_user in queryset.iterator(chunk_size=BATCH_SIZE): + if registered_user.user_id == previous_user_id: + continue + previous_user_id = registered_user.user_id + restored = RegisteredUser( + user_id=registered_user.user_id, + method=registered_user.method, + is_verified=registered_user.is_verified, + **_registered_user_extra_kwargs(registered_user, extra_fields), + ) + restored.modified = registered_user.modified + restored_objects.append(restored) + if len(restored_objects) >= BATCH_SIZE: + _flush_bulk_create(RegisteredUser, restored_objects) + _flush_bulk_create(RegisteredUser, restored_objects) + + +def migrate_registered_users_multitenant_forward( + apps, schema_editor, app_label, extra_fields=() +): + RegisteredUser = apps.get_model(app_label, "RegisteredUser") + OrganizationUser = get_swapped_model(apps, "openwisp_users", "OrganizationUser") + + queryset = RegisteredUser.objects.filter(organization__isnull=True).order_by( + "user_id" + ) + iterator = queryset.iterator(chunk_size=BATCH_SIZE) + for batch in _batched_iterator(iterator, BATCH_SIZE): + user_ids = [registered_user.user_id for registered_user in batch] + memberships = defaultdict(set) + membership_qs = OrganizationUser.objects.filter( + user_id__in=user_ids + ).values_list("user_id", "organization_id") + for user_id, organization_id in membership_qs.iterator(chunk_size=BATCH_SIZE): + memberships[user_id].add(organization_id) + + existing_pairs = set( + RegisteredUser.objects.filter( + user_id__in=user_ids, + organization__isnull=False, + ).values_list("user_id", "organization_id") + ) + + to_create = [] + for registered_user in batch: + organization_ids = sorted(memberships.get(registered_user.user_id, ())) + if not organization_ids: + continue + extra_kwargs = _registered_user_extra_kwargs(registered_user, extra_fields) + for organization_id in organization_ids: + pair = (registered_user.user_id, organization_id) + if pair in existing_pairs: + continue + existing_pairs.add(pair) + copied = RegisteredUser( + id=uuid.uuid4(), + user_id=registered_user.user_id, + organization_id=organization_id, + is_verified=registered_user.is_verified, + method=registered_user.method, + **extra_kwargs, + ) + copied.modified = registered_user.modified + to_create.append(copied) + + _flush_bulk_create(RegisteredUser, to_create) + + +def migrate_registered_users_multitenant_reverse( + apps, schema_editor, app_label, extra_fields=() +): + # Keep the strongest RegisteredUser per user and delete the weaker duplicates. + # Ranking is by: verified over unverified, stronger method over weaker method, + # then newer modified timestamps over older ones. + RegisteredUser = apps.get_model(app_label, "RegisteredUser") + # Process users in batches so the migration scales to large tables without + # issuing one query per user. + user_ids_qs = ( + RegisteredUser.objects.order_by().values_list("user_id", flat=True).distinct() + ) + for user_id_batch in _batched_iterator( + user_ids_qs.iterator(chunk_size=BATCH_SIZE), BATCH_SIZE + ): + # Annotate each row with an explicit verification priority so that stronger + # methods (anything that is not '' or 'email') sort before weaker ones. + method_priority = _registered_user_method_priority_case() + ranked_registered_users = ( + RegisteredUser.objects.filter( + user_id__in=user_id_batch, + ) + .annotate(method_priority=method_priority) + .order_by("user_id", "-is_verified", "-method_priority", "-modified") + ) + to_delete_pks = [] + current_user_id = None + for registered_user in ranked_registered_users.iterator(chunk_size=BATCH_SIZE): + # Rows for the same user are consecutive because of the ordering + # above, and the first row in each group is the strongest one. + # Every later row for that user is therefore a weaker duplicate. + is_duplicate_for_user = registered_user.user_id == current_user_id + if is_duplicate_for_user: + to_delete_pks.append(registered_user.pk) + else: + current_user_id = registered_user.user_id + if len(to_delete_pks) >= BATCH_SIZE: + RegisteredUser.objects.filter(pk__in=to_delete_pks).delete() + to_delete_pks.clear() + + # Delete all weaker rows for the batch at once rather than issuing a + # separate delete for each user. + if to_delete_pks: + RegisteredUser.objects.filter(pk__in=to_delete_pks).delete() + + def delete_old_radius_token(apps, schema_editor): RadiusToken = get_swapped_model(apps, "openwisp_radius", "RadiusToken") RadiusToken.objects.all().delete() diff --git a/openwisp_radius/registration.py b/openwisp_radius/registration.py index e376232d..178120ff 100644 --- a/openwisp_radius/registration.py +++ b/openwisp_radius/registration.py @@ -10,6 +10,7 @@ ("manual", _("Manually created")), ("email", _("Email")), ("mobile_phone", _("Mobile phone")), + ("pending_verification", _("Pending Verification")), ] AUTHORIZE_UNVERIFIED = [] diff --git a/openwisp_radius/saml/backends.py b/openwisp_radius/saml/backends.py index f61d5d55..3c55a657 100644 --- a/openwisp_radius/saml/backends.py +++ b/openwisp_radius/saml/backends.py @@ -1,4 +1,3 @@ -from django.core.exceptions import ObjectDoesNotExist from djangosaml2.backends import Saml2Backend from .. import settings as app_settings @@ -12,20 +11,27 @@ def _update_user(self, user, attributes, attribute_mapping, force_save=False): ): # Skip updating user's username if the user didn't signed up # with SAML registration method. - try: - attribute_mapping = attribute_mapping.copy() - if user.registered_user.method != "saml": - for key, value in attribute_mapping.items(): - if "username" in value: - break - if len(value) == 1: - attribute_mapping.pop(key, None) - else: - attribute_mapping[key] = [] - for attr in value: - if attr != "username": - attribute_mapping[key].append(attr) - - except ObjectDoesNotExist: - pass + attribute_mapping = attribute_mapping.copy() + # Check if any of the user's registered_users records + # were NOT created via SAML. + # NOTE: This uses a global check (any org) rather than org-specific. + # This is intentionally conservative: if a user has ever signed up + # via a non-SAML method in any org, their username won't be updated + # during SAML login in any org. This prevents the SAML identity + # provider from overwriting a username set or preferred by the user + # elsewhere. Since the User model is shared across organizations, + # updating the username based solely on one org's SAML flow could + # unexpectedly change the user's identity in other orgs. + has_non_saml = user.registered_users.exclude(method="saml").exists() + if has_non_saml: + for key, value in attribute_mapping.items(): + if "username" in value: + break + if len(value) == 1: + attribute_mapping.pop(key, None) + else: + attribute_mapping[key] = [] + for attr in value: + if attr != "username": + attribute_mapping[key].append(attr) return super()._update_user(user, attributes, attribute_mapping, force_save) diff --git a/openwisp_radius/saml/views.py b/openwisp_radius/saml/views.py index 95bf5a25..2e953518 100644 --- a/openwisp_radius/saml/views.py +++ b/openwisp_radius/saml/views.py @@ -9,6 +9,7 @@ from django.contrib.auth import get_user_model, logout from django.contrib.auth.mixins import LoginRequiredMixin from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError +from django.db import transaction from django.shortcuts import get_object_or_404, redirect, render from django.urls import reverse from django.views.generic import UpdateView @@ -67,20 +68,21 @@ def post_login_hook(self, request, user, session_info): org = self.get_organization_from_relay_state() is_member = user.is_member(org) # add user to organization - if not is_member: - orgUser = OrganizationUser(organization=org, user=user) - orgUser.full_clean() - orgUser.save() - try: - user.registered_user - except ObjectDoesNotExist: - registered_user = RegisteredUser( - user=user, method="saml", is_verified=app_settings.SAML_IS_VERIFIED + with transaction.atomic(): + if not is_member: + orgUser = OrganizationUser(organization=org, user=user) + orgUser.full_clean() + orgUser.save() + registered_user, created = RegisteredUser.objects.get_or_create( + user=user, + organization=org, + defaults={ + "method": "saml", + "is_verified": app_settings.SAML_IS_VERIFIED, + }, ) - registered_user.full_clean() - registered_user.save() - # The user is just created, it will not have an email address - if user.email: + if created and user.email: + # The user is just created, it will not have an email address try: email_address = EmailAddress( user=user, email=user.email, primary=True, verified=True @@ -89,8 +91,8 @@ def post_login_hook(self, request, user, session_info): email_address.save() except ValidationError: logger.exception( - f'Failed email validation for "{user}"' - " during SAML user creation" + f'Failed email validation for "{user}" during' + " SAML user creation" ) def customize_relay_state(self, relay_state): diff --git a/openwisp_radius/settings.py b/openwisp_radius/settings.py index e7f908dd..dea0d461 100644 --- a/openwisp_radius/settings.py +++ b/openwisp_radius/settings.py @@ -232,10 +232,13 @@ def get_default_password_reset_url(urls): if not hasattr(settings, "OPENWISP_USERS_EXPORT_USERS_COMMAND_CONFIG"): from openwisp_users import settings as ow_users_settings - ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["fields"].extend( - ["registered_user.method", "registered_user.is_verified"] + ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["fields"].append( + { + "name": "registered_users", + "fields": ("organization_id", "method", "is_verified"), + } ) - ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["select_related"].extend( - ["registered_user"] + ow_users_settings.EXPORT_USERS_COMMAND_CONFIG["prefetch_related"].extend( + ["registered_users"] ) BATCH_ASYNC_THRESHOLD = get_settings_value("BATCH_ASYNC_THRESHOLD", 15) diff --git a/openwisp_radius/social/views.py b/openwisp_radius/social/views.py index cc50a3f8..21db4fe2 100644 --- a/openwisp_radius/social/views.py +++ b/openwisp_radius/social/views.py @@ -1,5 +1,6 @@ import swapper -from django.core.exceptions import ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import PermissionDenied +from django.db import transaction from django.http import HttpResponse, HttpResponseRedirect from django.shortcuts import get_object_or_404 from django.utils.translation import gettext_lazy as _ @@ -42,18 +43,19 @@ def authorize(self, request, org, *args, **kwargs): user = request.user is_member = user.is_member(org) # add user to organization - if not is_member: - orgUser = OrganizationUser(organization=org, user=user) - orgUser.full_clean() - orgUser.save() - try: - user.registered_user - except ObjectDoesNotExist: - registered_user = RegisteredUser( - user=user, method="social_login", is_verified=False + with transaction.atomic(): + if not is_member: + orgUser = OrganizationUser(organization=org, user=user) + orgUser.full_clean() + orgUser.save() + registered_user, created = RegisteredUser.objects.get_or_create( + user=user, + organization=org, + defaults={"method": "social_login", "is_verified": False}, ) - registered_user.full_clean() - registered_user.save() + if not created: + registered_user.full_clean() + registered_user.save() def get_redirect_url(self, request, organization): """ diff --git a/openwisp_radius/tests/mixins.py b/openwisp_radius/tests/mixins.py index 1852116d..01e39c19 100644 --- a/openwisp_radius/tests/mixins.py +++ b/openwisp_radius/tests/mixins.py @@ -97,10 +97,10 @@ def _get_user_edit_form_inline_params(self, user, organization): "phonetoken_set-MIN_NUM_FORMS": 0, "phonetoken_set-MAX_NUM_FORMS": 0, # registered user inline - "registered_user-TOTAL_FORMS": 0, - "registered_user-INITIAL_FORMS": 0, - "registered_user-MIN_NUM_FORMS": 0, - "registered_user-MAX_NUM_FORMS": 0, + "registered_users-TOTAL_FORMS": 0, + "registered_users-INITIAL_FORMS": 0, + "registered_users-MIN_NUM_FORMS": 0, + "registered_users-MAX_NUM_FORMS": 0, # radius token inline "radius_token-TOTAL_FORMS": "0", "radius_token-INITIAL_FORMS": "0", diff --git a/openwisp_radius/tests/test_admin.py b/openwisp_radius/tests/test_admin.py index bc829810..4e430f8e 100644 --- a/openwisp_radius/tests/test_admin.py +++ b/openwisp_radius/tests/test_admin.py @@ -671,16 +671,19 @@ def test_backward_compatible_default_password_reset_url(self): f"admin:{self.app_label_users}_organization_add", ) PASSWORD_RESET_URLS = {"default": default_password_reset_url} - with mock.patch.object( - app_settings, - "DEFAULT_PASSWORD_RESET_URL", - app_settings.get_default_password_reset_url(PASSWORD_RESET_URLS), - ), mock.patch.object( - # The default value is set on project startup, hence - # it also requires mocking. - OrganizationRadiusSettings._meta.get_field("password_reset_url"), - "fallback", - app_settings.DEFAULT_PASSWORD_RESET_URL, + with ( + mock.patch.object( + app_settings, + "DEFAULT_PASSWORD_RESET_URL", + app_settings.get_default_password_reset_url(PASSWORD_RESET_URLS), + ), + mock.patch.object( + # The default value is set on project startup, hence + # it also requires mocking. + OrganizationRadiusSettings._meta.get_field("password_reset_url"), + "fallback", + app_settings.DEFAULT_PASSWORD_RESET_URL, + ), ): response = self.client.get(url) self.assertContains(response, default_password_reset_url) @@ -1359,7 +1362,7 @@ def test_inline_registered_user(self): with self.subTest("Inline exists"): response = self.client.get(url) - self.assertContains(response, "id_registered_user-TOTAL_FORMS") + self.assertContains(response, "id_registered_users-TOTAL_FORMS") with self.subTest("Register new choice"): register_registration_method("national_id", "National ID") @@ -1407,6 +1410,66 @@ def test_inline_registered_user(self): register_registration_method("github", "GitHub", strong_identity=False) self.assertIn("github", RegisteredUser._weak_verification_methods) + def test_admin_prevents_duplicate_registered_user_same_org(self): + user = self._create_user(username="dup_test_user", email="dup@test.org") + reg_user = RegisteredUser.objects.create( + user=user, organization=self.default_org, is_verified=True + ) + user_change_url = reverse( + f"admin:{User._meta.app_label}_user_change", args=[user.pk] + ) + response = self.client.get(user_change_url) + self.assertEqual(response.status_code, 200) + data = { + "username": "dup_test_user", + "email": "dup@test.org", + "registered_users-TOTAL_FORMS": "2", + "registered_users-INITIAL_FORMS": "1", + "registered_users-MIN_NUM_FORMS": "0", + "registered_users-MAX_NUM_FORMS": "1000", + "registered_users-0-id": str(reg_user.pk), + "registered_users-0-user": str(user.pk), + "registered_users-0-organization": str(self.default_org.pk), + "registered_users-0-method": "", + "registered_users-0-is_verified": "on", + "registered_users-1-id": "", + "registered_users-1-user": str(user.pk), + "registered_users-1-organization": str(self.default_org.pk), + "registered_users-1-method": "", + "registered_users-1-is_verified": "on", + } + response = self.client.post(user_change_url, data) + self.assertContains(response, "errors") + self.assertContains( + response, + "A user cannot have more than one registration record in the" + " same organization.", + ) + self.assertEqual( + RegisteredUser.objects.filter( + user=user, organization=self.default_org + ).count(), + 1, + ) + + def test_user_admin_shows_multiple_registered_user_records(self): + user = self._create_user(username="multiuser", email="multi@test.org") + org2 = self._create_org(name="org2", slug="org2") + RegisteredUser.objects.create( + user=user, organization=self.default_org, is_verified=True + ) + RegisteredUser.objects.create(user=user, organization=org2, is_verified=False) + user_url = reverse(f"admin:{User._meta.app_label}_user_change", args=[user.pk]) + response = self.client.get(user_url) + self.assertEqual(response.status_code, 200) + self.assertContains( + response, + ( + '' + ), + ) + def test_get_is_verified_user_admin_list(self): unknown = User.objects.first() self.assertIsNotNone(unknown) @@ -1416,7 +1479,10 @@ def test_get_is_verified_user_admin_list(self): verified.full_clean() verified.save() RegisteredUser.objects.create( - user=verified, method="mobile_phone", is_verified=True + user=verified, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) unverified = User.objects.create( username="unverified", password="unverified", email="unverified@test.com" @@ -1424,7 +1490,10 @@ def test_get_is_verified_user_admin_list(self): unverified.full_clean() unverified.save() RegisteredUser.objects.create( - user=unverified, method="mobile_phone", is_verified=False + user=unverified, + organization=self.default_org, + method="mobile_phone", + is_verified=False, ) app_label = User._meta.app_label url = reverse(f"admin:{app_label}_user_changelist") @@ -1449,7 +1518,10 @@ def test_registered_user_filter(self): verified.full_clean() verified.save() RegisteredUser.objects.create( - user=verified, method="mobile_phone", is_verified=True + user=verified, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) unverified = User.objects.create( username="unverified", password="unverified", email="unverified@test.com" @@ -1457,7 +1529,10 @@ def test_registered_user_filter(self): unverified.full_clean() unverified.save() RegisteredUser.objects.create( - user=unverified, method="mobile_phone", is_verified=False + user=unverified, + organization=self.default_org, + method="mobile_phone", + is_verified=False, ) app_label = User._meta.app_label url = reverse(f"admin:{app_label}_user_changelist") diff --git a/openwisp_radius/tests/test_api/test_api.py b/openwisp_radius/tests/test_api/test_api.py index d0a6f3d5..2b1d8b72 100644 --- a/openwisp_radius/tests/test_api/test_api.py +++ b/openwisp_radius/tests/test_api/test_api.py @@ -41,6 +41,7 @@ RadiusBatch = load_model("RadiusBatch") RadiusUserGroup = load_model("RadiusUserGroup") RadiusGroup = load_model("RadiusGroup") +RegisteredUser = load_model("RegisteredUser") OrganizationRadiusSettings = load_model("OrganizationRadiusSettings") Organization = swapper.load_model("openwisp_users", "Organization") OrganizationUser = swapper.load_model("openwisp_users", "OrganizationUser") @@ -60,10 +61,34 @@ def _radius_batch_post_request(self, data, username="admin", password="tester"): login_payload = {"username": username, "password": password} login_url = reverse("radius:user_auth_token", args=[self.default_org.slug]) login_response = self.client.post(login_url, data=login_payload) - header = f'Bearer {login_response.json()["key"]}' + header = f"Bearer {login_response.json()['key']}" url = reverse("radius:batch") return self.client.post(url, data, HTTP_AUTHORIZATION=header) + def _get_update_method_url(self, org=None): + if org is None: + org = self.default_org + return reverse( + "radius:update_registered_user_registration_method", args=[org.slug] + ) + + def _create_pending_verification_user(self, username_suffix=""): + user = self._create_user( + username=f"pendinguser{username_suffix}", + password="tester", + email=f"pendinguser{username_suffix}@test.com", + ) + org2 = self._create_org(name="org2") + OrganizationUser.objects.create(user=user, organization=org2) + RegisteredUser.objects.create( + user=user, + organization=org2, + method="pending_verification", + is_verified=False, + ) + user_token = Token.objects.create(user=user) + return user, org2, user_token + def test_batch_bad_request_400(self): self.assertEqual(RadiusBatch.objects.count(), 0) data = self._radius_batch_prefix_data(number_of_users=-1) @@ -159,7 +184,10 @@ def test_register_201(self): user = User.objects.get(email=self._test_email) self.assertTrue(user.is_member(self.default_org)) self.assertTrue(user.is_active) - self.assertFalse(user.registered_user.is_verified) + self.assertEqual( + user.registered_users.get(organization=self.default_org).is_verified, + False, + ) def test_register_400_password(self): response = self._register_user( @@ -319,19 +347,27 @@ def test_register_duplicate_different_org(self): def test_radius_user_serializer(self): self._register_user() try: - user = User.objects.select_related("radius_token", "registered_user").get( - email=self._test_email + user = ( + User.objects.select_related("radius_token") + .prefetch_related("registered_users") + .get(email=self._test_email) ) - admin = User.objects.select_related("radius_token", "registered_user").get( - username="admin" + admin = ( + User.objects.select_related("radius_token") + .prefetch_related("registered_users") + .get(username="admin") ) except User.DoesNotExist as e: self.fail(f"user not found: {e}") with self.assertNumQueries(0): - data = RadiusUserSerializer(user).data + # Organization is required to get the RegisteredUser object + view = mock.MagicMock() + view.organization = self.default_org + data = RadiusUserSerializer(user, context={"view": view}).data with self.subTest("test full data"): + registered_user = user.registered_users.get(organization=self.default_org) self.assertEqual( data, { @@ -343,9 +379,9 @@ def test_radius_user_serializer(self): "birth_date": user.birth_date, "location": user.location, "is_active": user.is_active, - "is_verified": user.registered_user.is_verified, "password_expired": user.has_password_expired(), - "method": user.registered_user.method, + "is_verified": registered_user.is_verified, + "method": registered_user.method, "radius_user_token": user.radius_token.key, }, ) @@ -370,6 +406,44 @@ def test_radius_user_serializer(self): }, ) + with self.subTest("org-specific record is returned for the current org"): + user2 = self._create_user(username="user2", email="user2@test.com") + self._create_org_user(user=user2, organization=self.default_org) + RegisteredUser.objects.create( + user=user2, + organization=self.default_org, + is_verified=True, + method="mobile_phone", + ) + url = reverse("radius:user_auth_token", args=[self.default_org.slug]) + r = self.client.post(url, {"username": "user2", "password": "tester"}) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.data["is_verified"], True) + self.assertEqual(r.data["method"], "mobile_phone") + + with self.subTest("other-organization record is not used as fallback"): + user3 = self._create_user(username="user3", email="user3@test.com") + self._create_org_user(user=user3, organization=self.default_org) + org2 = self._create_org(name="serializer-org2", slug="serializer-org2") + self._create_org_user(user=user3, organization=org2) + RegisteredUser.objects.create( + user=user3, organization=org2, is_verified=True, method="email" + ) + url = reverse("radius:user_auth_token", args=[self.default_org.slug]) + r = self.client.post(url, {"username": "user3", "password": "tester"}) + self.assertEqual(r.status_code, 200) + self.assertIsNone(r.data["is_verified"]) + self.assertIsNone(r.data["method"]) + + with self.subTest("returns None when no RegisteredUser records exist"): + user4 = self._create_user(username="user4", email="user4@test.com") + self._create_org_user(user=user4, organization=self.default_org) + url = reverse("radius:user_auth_token", args=[self.default_org.slug]) + r = self.client.post(url, {"username": "user4", "password": "tester"}) + self.assertEqual(r.status_code, 200) + self.assertIsNone(r.data["is_verified"]) + self.assertIsNone(r.data["method"]) + # The fallback value is set on project startup, hence it also requires mocking. @mock.patch.object( OrganizationRadiusSettings._meta.get_field("first_name"), @@ -916,7 +990,7 @@ def test_user_accounting_list_200(self): response = self.client.post( auth_url, {"username": "tester", "password": "tester"} ) - authorization = f'Bearer {response.data["key"]}' + authorization = f"Bearer {response.data['key']}" stop_time = "2018-03-02T11:43:24.020460+01:00" data1 = self.acct_post_data data1.update( @@ -1556,6 +1630,121 @@ def test_radius_user_group_serializer_without_view_context(self): self.assertEqual(serializer._user, None) self.assertEqual(serializer.fields["group"].queryset.count(), 0) + def test_update_registered_user_method_success(self): + user, org2, user_token = self._create_pending_verification_user( + username_suffix="_success" + ) + url = self._get_update_method_url(org2) + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["method"], "mobile_phone") + registered_user = RegisteredUser.objects.get(user=user, organization=org2) + self.assertEqual(registered_user.method, "mobile_phone") + self.assertEqual(registered_user.is_verified, False) + + def test_update_registered_user_method_with_valid_methods(self): + user, org2, user_token = self._create_pending_verification_user( + username_suffix="_valid" + ) + url = self._get_update_method_url(org2) + for method in ["", "manual", "email", "mobile_phone"]: + with self.subTest(method=method): + registered_user = RegisteredUser.objects.get( + user=user, organization=org2 + ) + registered_user.method = "pending_verification" + registered_user.save() + response = self.client.post( + url, + {"method": method}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["method"], method) + + def test_update_registered_user_method_validation_errors(self): + user, org2, user_token = self._create_pending_verification_user() + url = self._get_update_method_url(org2) + with self.subTest("reject_pending_verification_as_input"): + response = self.client.post( + url, + {"method": "pending_verification"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + + with self.subTest("reject_invalid_method"): + response = self.client.post( + url, + {"method": "invalid_method"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + + with self.subTest("reject_non_pending_state"): + registered_user = RegisteredUser.objects.get(user=user, organization=org2) + registered_user.method = "mobile_phone" + registered_user.save() + response = self.client.post( + url, + {"method": "email"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 400) + self.assertIn("pending verification", response.data["method"][0]) + + def test_update_registered_user_method_404_cases(self): + with self.subTest("not_found_without_registered_user"): + user = self._create_user(username="noreguser", password="tester") + user_token = Token.objects.create(user=user) + url = self._get_update_method_url() + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 404) + + with self.subTest("only_owner_can_update"): + user, org2, user_token = self._create_pending_verification_user( + username_suffix="_owner" + ) + other_user = self._create_user( + username="otheruser", password="tester", email="otheruser@test.com" + ) + other_user_token = Token.objects.create(user=other_user) + url = self._get_update_method_url(org2) + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {other_user_token.key}", + ) + self.assertEqual(response.status_code, 404) + + with self.subTest("invalid_org"): + user, _, user_token = self._create_pending_verification_user( + username_suffix="_invalid_org" + ) + url = reverse( + "radius:update_registered_user_registration_method", + args=["nonexistent-org-slug"], + ) + response = self.client.post( + url, + {"method": "mobile_phone"}, + HTTP_AUTHORIZATION=f"Bearer {user_token.key}", + ) + self.assertEqual(response.status_code, 404) + + def test_update_registered_user_method_requires_authentication(self): + url = self._get_update_method_url() + response = self.client.post(url, {"method": "mobile_phone"}) + self.assertEqual(response.status_code, 401) + class TestTransactionApi(AcctMixin, ApiTokenMixin, BaseTransactionTestCase): def test_user_radius_usage_view(self): @@ -1565,7 +1754,7 @@ def test_user_radius_usage_view(self): response = self.client.post( auth_url, {"username": "tester", "password": "tester"} ) - authorization = f'Bearer {response.data["key"]}' + authorization = f"Bearer {response.data['key']}" self.assertEqual(response.status_code, 200) with self.subTest("Test user has not used any data"): response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization) diff --git a/openwisp_radius/tests/test_api/test_freeradius_api.py b/openwisp_radius/tests/test_api/test_freeradius_api.py index 64968edd..55ea5e49 100644 --- a/openwisp_radius/tests/test_api/test_freeradius_api.py +++ b/openwisp_radius/tests/test_api/test_freeradius_api.py @@ -172,7 +172,7 @@ def test_authorize_fail_auth_details_incomplete(self): f"?uuid={str(self.default_org.pk)}", ]: with self.subTest(querystring): - post_url = f'{reverse("radius:authorize")}{querystring}' + post_url = f"{reverse('radius:authorize')}{querystring}" response = self.client.post( post_url, {"username": "tester", "password": "tester"} ) @@ -206,6 +206,134 @@ def test_authorize_unverified_user(self): self.assertEqual(response.status_code, 200) self.assertIsNone(response.data) + def test_authorize_verified_user(self): + org_user = self._get_org_user() + user = org_user.user + org_settings = OrganizationRadiusSettings.objects.get( + organization=self._get_org() + ) + org_settings.needs_identity_verification = True + org_settings.save() + + with self.subTest("org-specific verified record passes authorization"): + RegisteredUser.objects.create( + user=user, organization=self._get_org(), is_verified=True + ) + response = self._authorize_user(auth_header=self.auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, {"control:Auth-Type": "Accept"}) + + with self.subTest("other-organization record does not pass authorization"): + RegisteredUser.objects.filter(user=user).delete() + org2 = self._create_org(name="verified-org-2", slug="verified-org-2") + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create( + user=user, organization=org2, is_verified=True + ) + response = self._authorize_user(auth_header=self.auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, None) + + def test_multi_org_user_different_verification_states(self): + org1 = self._get_org() + org_settings = OrganizationRadiusSettings.objects.get(organization=org1) + org_settings.needs_identity_verification = True + org_settings.save() + org2 = self._create_org(name="org2", slug="org2") + org2_settings = OrganizationRadiusSettings.objects.get_or_create( + organization=org2 + )[0] + org2_settings.needs_identity_verification = True + org2_settings.full_clean() + org2_settings.save() + user = self._get_user_with_org() + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create(user=user, organization=org1, is_verified=True) + auth_header_org1 = f"Bearer {org1.pk} {org1.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org1 + ) + self.assertEqual(response.data["control:Auth-Type"], "Accept") + + auth_header_org2 = f"Bearer {org2.pk} {org2.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org2 + ) + self.assertIsNone(response.data) + + def test_other_org_record_is_not_used_as_fallback(self): + org1 = self._get_org() + org2 = self._create_org(name="org2", slug="org2") + org2_settings = OrganizationRadiusSettings.objects.get_or_create( + organization=org2 + )[0] + org2_settings.needs_identity_verification = True + org2_settings.full_clean() + org2_settings.save() + user = self._get_user_with_org() + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create(user=user, organization=org2, is_verified=True) + org_settings = OrganizationRadiusSettings.objects.get(organization=org1) + org_settings.needs_identity_verification = True + org_settings.save() + + auth_header_org1 = f"Bearer {org1.pk} {org1.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org1 + ) + self.assertEqual(response.data, None) + + auth_header_org2 = f"Bearer {org2.pk} {org2.radius_settings.token}" + response = self._authorize_user( + username=user.username, auth_header=auth_header_org2 + ) + self.assertEqual(response.data["control:Auth-Type"], "Accept") + + def test_other_org_verified_with_org_unverified(self): + """ + A user with a verified record in another org should not be + authorized for an org where they have an org-specific unverified record. + """ + org = self._get_org() + org_settings = OrganizationRadiusSettings.objects.get(organization=org) + org_settings.needs_identity_verification = True + org_settings.save() + user = self._get_user_with_org() + org2 = self._create_org(name="org2-priority", slug="org2-priority") + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create(user=user, organization=org, is_verified=False) + RegisteredUser.objects.create(user=user, organization=org2, is_verified=True) + auth_header = f"Bearer {org.pk} {org.radius_settings.token}" + response = self._authorize_user(username=user.username, auth_header=auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, None) + + @mock.patch.object(registration, "AUTHORIZE_UNVERIFIED", ["mobile_phone"]) + def test_other_org_special_method_with_org_unverified_not_authorized(self): + """ + When AUTHORIZE_UNVERIFIED is set, the org-specific + record still takes precedence. A user with org-specific unverified record + using a non-special method should NOT be authorized even if they have a + verified record in another organization with a special method. + """ + org = self._get_org() + org_settings = OrganizationRadiusSettings.objects.get(organization=org) + org_settings.needs_identity_verification = True + org_settings.save() + user = self._get_user_with_org() + org2 = self._create_org(name="org2-special", slug="org2-special") + self._create_org_user(organization=org2, user=user) + RegisteredUser.objects.create( + user=user, organization=org, method="email", is_verified=False + ) + RegisteredUser.objects.create( + user=user, organization=org2, method="mobile_phone", is_verified=True + ) + auth_header = f"Bearer {org.pk} {org.radius_settings.token}" + response = self._authorize_user(username=user.username, auth_header=auth_header) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, None) + def test_authorize_radius_token_unverified_user(self): user = self._get_org_user() org_settings = OrganizationRadiusSettings.objects.get( @@ -258,7 +386,7 @@ def test_postauth_radius_token_accept_201(self): def test_postauth_accept_201_querystring(self): self.assertEqual(RadiusPostAuth.objects.all().count(), 0) params = self._get_postauth_params() - post_url = f'{reverse("radius:postauth")}{self.token_querystring}' + post_url = f"{reverse('radius:postauth')}{self.token_querystring}" response = self.client.post(post_url, params) params["password"] = "" self.assertEqual(RadiusPostAuth.objects.filter(**params).count(), 1) @@ -1227,7 +1355,7 @@ def test_accounting_when_nas_using_pfsense_started(self): self.assertIsNone(response.data) def test_get_authorize_view(self): - url = f'{reverse("radius:authorize")}{self.token_querystring}' + url = f"{reverse('radius:authorize')}{self.token_querystring}" r = self.client.get(url, HTTP_ACCEPT="text/html") self.assertEqual(r.status_code, 405) expected = f'
email > empty. + """ + user = self._create_user(username="method-priority-user") + org1 = self._create_org(name="method-org-1", slug="method-org-1") + org2 = self._create_org(name="method-org-2", slug="method-org-2") + org3 = self._create_org(name="method-org-3", slug="method-org-3") + modified_base = timezone.now() + # All unverified, same timestamp - method should decide + with freeze_time(modified_base): + RegisteredUser.objects.create( + user=user, + organization=org1, + is_verified=False, + method="", + ) + RegisteredUser.objects.create( + user=user, + organization=org2, + is_verified=False, + method="email", + ) + RegisteredUser.objects.create( + user=user, + organization=org3, + is_verified=False, + method="mobile_phone", + ) + # Rollback: mobile_phone should win (highest method priority) + migrate_registered_users_multitenant_reverse( + apps, None, app_label="openwisp_radius" + ) + surviving_record = RegisteredUser.objects.get(user=user) + self.assertEqual(surviving_record.organization, org3) + self.assertEqual(surviving_record.method, "mobile_phone") + self.assertEqual(RegisteredUser.objects.filter(user=user).count(), 1) + + def test_multitenant_reverse_full_cleanup(self): + """ + Test that duplicate org-scoped records are reduced to one per user. + """ + user1 = self._create_user( + username="cleanup-user-1", email="cleanup1@example.com" + ) + user2 = self._create_user( + username="cleanup-user-2", email="cleanup2@example.com" + ) + org1 = self._create_org(name="cleanup-org-1", slug="cleanup-org-1") + org2 = self._create_org(name="cleanup-org-2", slug="cleanup-org-2") + # Create multiple org-scoped records for multiple users + for user, org in [(user1, org1), (user1, org2), (user2, org1)]: + RegisteredUser.objects.create( + user=user, + organization=org, + is_verified=False, + method="email", + ) + self.assertEqual( + RegisteredUser.objects.filter(user=user1).count(), + 2, + ) + migrate_registered_users_multitenant_reverse( + apps, None, app_label="openwisp_radius" + ) + self.assertEqual( + RegisteredUser.objects.filter(user=user1).count(), + 1, + ) + self.assertEqual( + RegisteredUser.objects.filter(user=user2).count(), + 1, + ) diff --git a/openwisp_radius/tests/test_models.py b/openwisp_radius/tests/test_models.py index c618a29e..d784b643 100644 --- a/openwisp_radius/tests/test_models.py +++ b/openwisp_radius/tests/test_models.py @@ -42,6 +42,7 @@ RadiusBatch = load_model("RadiusBatch") OrganizationRadiusSettings = load_model("OrganizationRadiusSettings") Organization = swapper.load_model("openwisp_users", "Organization") +RegisteredUser = load_model("RegisteredUser") class TestNas(BaseTestCase): @@ -1218,5 +1219,47 @@ def test_sessions_with_multiple_orgs(self, mocked_radclient): self.assertEqual(org2_session.groupname, f"{org2.slug}-users") +class TestRegisteredUser(BaseTestCase): + def test_get_for_user_and_org(self): + user = self._create_user() + org1 = self._create_org(name="ru-test-org-1", slug="ru-test-org-1") + org2 = self._create_org(name="ru-test-org-2", slug="ru-test-org-2") + + with self.subTest("returns None when no records exist"): + result = RegisteredUser.get_for_user_and_org(user, org1) + self.assertIsNone(result) + + with self.subTest("returns only the requested organization record"): + org2_ru = RegisteredUser.objects.create( + user=user, organization=org2, is_verified=True + ) + result = RegisteredUser.get_for_user_and_org(user, org1) + self.assertIsNone(result) + result = RegisteredUser.get_for_user_and_org(user, org2) + self.assertEqual(result, org2_ru) + self.assertEqual(result.is_verified, True) + + def test_clean_requires_unique_org_specific_registered_user(self): + user = self._create_user() + org = self._create_org(name="dup-test-org", slug="dup-test-org") + other_org = self._create_org(name="dup-test-org-2", slug="dup-test-org-2") + + with self.subTest("duplicate org-specific raises ValidationError"): + RegisteredUser.objects.create(user=user, organization=org) + duplicate = RegisteredUser(user=user, organization=org) + with self.assertRaises(ValidationError): + duplicate.full_clean() + + with self.subTest("different organizations are allowed"): + record = RegisteredUser(user=user, organization=other_org) + record.full_clean() + + def test_clean_requires_organization(self): + user = self._create_user() + + with self.assertRaises(ValidationError): + RegisteredUser(user=user).full_clean() + + del BaseTestCase del BaseTransactionTestCase diff --git a/openwisp_radius/tests/test_saml/test_views.py b/openwisp_radius/tests/test_saml/test_views.py index 0c662970..adcaf1fd 100644 --- a/openwisp_radius/tests/test_saml/test_views.py +++ b/openwisp_radius/tests/test_saml/test_views.py @@ -152,8 +152,9 @@ def test_relay_state_relative_path(self): @capture_any_output() def test_user_registered_with_non_saml_method(self): + org = Organization.objects.get(slug="default") user = self._create_user(username="test-user", email="org_user@example.com") - RegisteredUser.objects.create(user=user, method="manual") + RegisteredUser.objects.create(user=user, method="manual", organization=org) relay_state = self._get_relay_state( redirect_url="https://captive-portal.example.com", org_slug="default" ) diff --git a/openwisp_radius/tests/test_selenium.py b/openwisp_radius/tests/test_selenium.py index 7b059345..8291f46d 100644 --- a/openwisp_radius/tests/test_selenium.py +++ b/openwisp_radius/tests/test_selenium.py @@ -21,6 +21,7 @@ @tag("selenium_tests") +@tag("no_parallel") class BasicTest( SeleniumTestMixin, FileMixin, StaticLiveServerTestCase, TestOrganizationMixin ): diff --git a/openwisp_radius/tests/test_social.py b/openwisp_radius/tests/test_social.py index 19ceafdb..07454792 100644 --- a/openwisp_radius/tests/test_social.py +++ b/openwisp_radius/tests/test_social.py @@ -2,7 +2,6 @@ from allauth.socialaccount.models import SocialAccount from django.contrib.auth import get_user_model -from django.core.exceptions import ObjectDoesNotExist from django.urls import reverse from rest_framework.authtoken.models import Token from swapper import load_model @@ -14,6 +13,7 @@ from .mixins import ApiTokenMixin, BaseTestCase RadiusToken = load_model("openwisp_radius", "RadiusToken") +RegisteredUser = load_model("openwisp_radius", "RegisteredUser") OrganizationRadiusSettings = load_model("openwisp_radius", "OrganizationRadiusSettings") Organization = load_model("openwisp_users", "Organization") User = get_user_model() @@ -102,13 +102,13 @@ def test_redirect_cp_301(self): user = User.objects.filter(username="socialuser").first() self.assertTrue(user.is_member(self.default_org)) try: - reg_user = user.registered_user - except ObjectDoesNotExist: + reg_user = user.registered_users.get(organization=self.default_org) + except RegisteredUser.DoesNotExist: self.fail("RegisteredUser instance not found") self.assertEqual(reg_user.method, "social_login") # social login is not a legally valid identity verification method # so this should be always False when users sign up with this method - self.assertFalse(reg_user.is_verified) + self.assertEqual(reg_user.is_verified, False) def test_authorize_using_radius_user_token_200(self): self.test_redirect_cp_301() diff --git a/openwisp_radius/tests/test_tasks.py b/openwisp_radius/tests/test_tasks.py index 8aadb051..230d2335 100644 --- a/openwisp_radius/tests/test_tasks.py +++ b/openwisp_radius/tests/test_tasks.py @@ -139,9 +139,7 @@ def test_delete_unverified_users(self): management.call_command("batch_add_users", **options) User.objects.update(date_joined=now() - timedelta(days=3)) for user in User.objects.all(): - user.registered_user.is_verified = False - user.registered_user.method = "email" - user.registered_user.save(update_fields=["is_verified", "method"]) + user.registered_users.update(is_verified=False, method="email") self.assertEqual(User.objects.count(), 3) tasks.delete_unverified_users.delay(older_than_days=2) self.assertEqual(User.objects.count(), 0) @@ -320,19 +318,35 @@ def test_unverify_inactive_users(self, *args): User.objects.exclude(id=active_user.id).update( last_login=today - timedelta(days=60) ) - RegisteredUser.objects.create(user=admin, is_verified=True) - RegisteredUser.objects.create(user=active_user, is_verified=True) RegisteredUser.objects.create( - user=unspecified_user, method="", is_verified=True + user=admin, organization=self.default_org, is_verified=True ) RegisteredUser.objects.create( - user=manually_registered_user, method="manual", is_verified=True + user=active_user, organization=self.default_org, is_verified=True ) RegisteredUser.objects.create( - user=email_registered_user, method="email", is_verified=True + user=unspecified_user, + organization=self.default_org, + method="", + is_verified=True, ) RegisteredUser.objects.create( - user=mobile_registered_user, method="mobile_phone", is_verified=True + user=manually_registered_user, + organization=self.default_org, + method="manual", + is_verified=True, + ) + RegisteredUser.objects.create( + user=email_registered_user, + organization=self.default_org, + method="email", + is_verified=True, + ) + RegisteredUser.objects.create( + user=mobile_registered_user, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) tasks.unverify_inactive_users.delay() @@ -342,12 +356,38 @@ def test_unverify_inactive_users(self, *args): manually_registered_user.refresh_from_db() email_registered_user.refresh_from_db() mobile_registered_user.refresh_from_db() - self.assertEqual(admin.registered_user.is_verified, True) - self.assertEqual(active_user.registered_user.is_verified, True) - self.assertEqual(unspecified_user.registered_user.is_verified, True) - self.assertEqual(manually_registered_user.registered_user.is_verified, True) - self.assertEqual(email_registered_user.registered_user.is_verified, True) - self.assertEqual(mobile_registered_user.registered_user.is_verified, False) + self.assertEqual( + admin.registered_users.get(organization=self.default_org).is_verified, + True, + ) + self.assertEqual( + active_user.registered_users.get(organization=self.default_org).is_verified, + True, + ) + self.assertEqual( + unspecified_user.registered_users.get( + organization=self.default_org + ).is_verified, + True, + ) + self.assertEqual( + manually_registered_user.registered_users.get( + organization=self.default_org + ).is_verified, + True, + ) + self.assertEqual( + email_registered_user.registered_users.get( + organization=self.default_org + ).is_verified, + True, + ) + self.assertEqual( + mobile_registered_user.registered_users.get( + organization=self.default_org + ).is_verified, + False, + ) @mock.patch.object(app_settings, "DELETE_INACTIVE_USERS", 30) def test_delete_inactive_users(self, *args): diff --git a/openwisp_radius/tests/test_token.py b/openwisp_radius/tests/test_token.py index 3a03115b..6e89a688 100644 --- a/openwisp_radius/tests/test_token.py +++ b/openwisp_radius/tests/test_token.py @@ -65,7 +65,10 @@ def _create_token( def test_is_already_verified(self): token = self._create_token() RegisteredUser.objects.create( - user=token.user, method="mobile_phone", is_verified=True + user=token.user, + organization=self.default_org, + method="mobile_phone", + is_verified=True, ) token.refresh_from_db() diff --git a/openwisp_radius/tests/test_users_integration.py b/openwisp_radius/tests/test_users_integration.py index dcefb721..2a0010ce 100644 --- a/openwisp_radius/tests/test_users_integration.py +++ b/openwisp_radius/tests/test_users_integration.py @@ -96,11 +96,23 @@ def test_radiustoken_inline(self): @capture_stdout() def test_export_users_command(self): temp_file = NamedTemporaryFile(delete=False) - user = self._create_org_user().user - RegisteredUser.objects.create( - user=user, method="mobile_phone", is_verified=False + org_user = self._create_org_user() + user = org_user.user + org2 = self._create_org(name="Test Organization 2") + self._create_org_user(organization=org2, user=user) + org1_reg_user = RegisteredUser.objects.create( + user=user, + organization=org_user.organization, + method="mobile_phone", + is_verified=False, ) - with self.assertNumQueries(1): + org2_reg_user = RegisteredUser.objects.create( + user=user, + organization=org2, + method="mobile_phone", + is_verified=True, + ) + with self.assertNumQueries(3): call_command("export_users", filename=temp_file.name) with open(temp_file.name, "r") as file: @@ -108,10 +120,19 @@ def test_export_users_command(self): csv_data = list(csv_reader) self.assertEqual(len(csv_data), 2) - self.assertIn("registered_user.method", csv_data[0]) - self.assertIn("registered_user.is_verified", csv_data[0]) - self.assertEqual(csv_data[1][-2], "mobile_phone") - self.assertEqual(csv_data[1][-1], "False") + self.assertIn( + "registered_users (organization_id, method, is_verified)", csv_data[0] + ) + self.assertEqual( + csv_data[1][-1], + ( + f"({org1_reg_user.organization_id},{org1_reg_user.method}," + f"{org1_reg_user.is_verified})" + "\n" + f"({org2_reg_user.organization_id},{org2_reg_user.method}," + f"{org2_reg_user.is_verified})" + ), + ) def test_radiususergroup_inline(self): """ diff --git a/runtests b/runtests index 60761e1d..188b58c4 100755 --- a/runtests +++ b/runtests @@ -3,7 +3,7 @@ set -e # Standard tests coverage run runtests.py --parallel \ - --exclude-tag=no_parallel >/dev/null 2>&1 \ + --exclude-tag=no_parallel 2>&1 \ || ./runtests.py --exclude-tag=no_parallel # Test extensibility diff --git a/tests/openwisp2/sample_radius/api/views.py b/tests/openwisp2/sample_radius/api/views.py index 6bb68b99..d5b468ef 100644 --- a/tests/openwisp2/sample_radius/api/views.py +++ b/tests/openwisp2/sample_radius/api/views.py @@ -24,6 +24,9 @@ RadiusUserGroupListCreateView, ) from openwisp_radius.api.views import RegisterView as BaseRegisterView +from openwisp_radius.api.views import ( + UpdateRegisteredUserMethodView as BaseUpdateRegisteredUserMethodView, +) from openwisp_radius.api.views import UserAccountingView as BaseUserAccountingView from openwisp_radius.api.views import UserRadiusUsageView as BaseUserRadiusUsageView from openwisp_radius.api.views import ValidateAuthTokenView as BaseValidateAuthTokenView @@ -104,6 +107,10 @@ class RadiusAccountingView(BaseRadiusAccountingView): pass +class UpdateRegisteredUserMethodView(BaseUpdateRegisteredUserMethodView): + pass + + authorize = AuthorizeView.as_view() postauth = PostAuthView.as_view() accounting = AccountingView.as_view() @@ -126,3 +133,4 @@ class RadiusAccountingView(BaseRadiusAccountingView): radius_group_detail = RadiusGroupDetailView.as_view() radius_user_group_list = RadiusUserGroupListCreateView.as_view() radius_user_group_detail = RadiusUserGroupDetailView.as_view() +update_registered_user_registration_method = UpdateRegisteredUserMethodView.as_view() diff --git a/tests/openwisp2/sample_radius/migrations/0032_registered_user_multitenant.py b/tests/openwisp2/sample_radius/migrations/0032_registered_user_multitenant.py new file mode 100644 index 00000000..8f78ded5 --- /dev/null +++ b/tests/openwisp2/sample_radius/migrations/0032_registered_user_multitenant.py @@ -0,0 +1,211 @@ +import uuid + +import django +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields +import swapper +from django.conf import settings +from django.db import migrations, models + +from openwisp_radius.migrations import ( + copy_registered_users_ctcr_forward, + copy_registered_users_ctcr_reverse, + migrate_registered_users_multitenant_forward, + migrate_registered_users_multitenant_reverse, +) + + +def copy_registered_users_forward(apps, schema_editor): + copy_registered_users_ctcr_forward( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +def copy_registered_users_reverse(apps, schema_editor): + copy_registered_users_ctcr_reverse( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +def migrate_registered_users_forward(apps, schema_editor): + migrate_registered_users_multitenant_forward( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +def migrate_registered_users_reverse(apps, schema_editor): + migrate_registered_users_multitenant_reverse( + apps, + schema_editor, + app_label="sample_radius", + extra_fields=("details",), + ) + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + swapper.dependency("openwisp_users", "Organization"), + ("sample_radius", "0031_radiusbatch_status"), + ] + + operations = [ + migrations.SeparateDatabaseAndState( + database_operations=[ + migrations.CreateModel( + name="RegisteredUserNew", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "details", + models.CharField( + blank=True, + max_length=64, + null=True, + ), + ), + ( + "method", + models.CharField( + blank=True, + default="", + help_text=( + "users can sign up in different ways, some " + "methods are valid as indirect identity " + "verification (eg: mobile phone SIM card in " + "most countries)" + ), + max_length=64, + verbose_name="registration method", + ), + ), + ( + "is_verified", + models.BooleanField( + default=False, + help_text=( + "whether the user has completed any identity " + "verification process sucessfully" + ), + verbose_name="verified", + ), + ), + ( + "modified", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, + editable=False, + verbose_name="Last verification change", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "organization", + models.ForeignKey( + blank=True, + help_text=( + "The organization this registration info belongs" + " to." + ), + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to=swapper.get_model_name( + "openwisp_users", "Organization" + ), + verbose_name="organization", + ), + ), + ], + options={ + "verbose_name": "Registration Information", + "verbose_name_plural": "Registration Information", + }, + ), + migrations.RunPython( + copy_registered_users_forward, + copy_registered_users_reverse, + ), + migrations.DeleteModel(name="RegisteredUser"), + migrations.RenameModel( + old_name="RegisteredUserNew", + new_name="RegisteredUser", + ), + ], + state_operations=[ + migrations.AddField( + model_name="registereduser", + name="id", + field=models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + migrations.AlterField( + model_name="registereduser", + name="user", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="registered_users", + to=settings.AUTH_USER_MODEL, + ), + ), + migrations.AddField( + model_name="registereduser", + name="organization", + field=models.ForeignKey( + help_text=( + "Organization associated with this registered user entry." + ), + related_name="registered_users", + on_delete=django.db.models.deletion.CASCADE, + to=swapper.get_model_name("openwisp_users", "Organization"), + verbose_name="organization", + ), + ), + ], + ), + migrations.RunPython( + migrate_registered_users_forward, + migrate_registered_users_reverse, + ), + migrations.AddConstraint( + model_name="registereduser", + constraint=models.UniqueConstraint( + fields=["user", "organization"], + name="unique_registered_user_per_org", + violation_error_message=( + "A user cannot have more than one registration" + " record in the same organization." + ), + ), + ), + ]