diff --git a/app/ioc.py b/app/ioc.py index c111396dd..870657023 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -66,6 +66,7 @@ from ldap_protocol.kerberos.service import KerberosService from ldap_protocol.kerberos.template_render import KRBTemplateRenderer from ldap_protocol.ldap_requests.contexts import ( + LDAPAbandonRequestContext, LDAPAddRequestContext, LDAPBindRequestContext, LDAPDeleteRequestContext, @@ -214,7 +215,7 @@ async def get_kadmin_http( yield KadminHTTPClient(client) @provide(scope=Scope.REQUEST) - async def get_kadmin( + def get_kadmin( self, client: KadminHTTPClient, kadmin_class: type[AbstractKadmin], @@ -269,14 +270,14 @@ async def get_dns_http_client( yield DNSManagerHTTPClient(client) @provide(scope=Scope.REQUEST) - async def get_dns_mngr( + def get_dns_mngr( self, settings: DNSManagerSettings, dns_manager_class: type[AbstractDNSManager], http_client: DNSManagerHTTPClient, - ) -> AsyncIterator[AbstractDNSManager]: + ) -> AbstractDNSManager: """Get DNSManager class.""" - yield dns_manager_class(settings=settings, http_client=http_client) + return dns_manager_class(settings=settings, http_client=http_client) @provide(scope=Scope.APP) async def get_redis_for_sessions( @@ -293,7 +294,7 @@ async def get_redis_for_sessions( await client.aclose() @provide(scope=Scope.APP) - async def get_session_storage( + def get_session_storage( self, client: SessionStorageClient, settings: Settings, @@ -306,7 +307,7 @@ async def get_session_storage( ) @provide() - async def get_normalized_audit_event( + def get_normalized_audit_event( self, ) -> type[NormalizedAuditEvent]: """Get normalized audit event class.""" @@ -327,13 +328,13 @@ async def get_audit_redis_client( await client.aclose() @provide(scope=Scope.APP) - async def get_raw_audit_manager( + def get_raw_audit_manager( self, client: AuditRedisClient, settings: Settings, - ) -> AsyncIterator[RawAuditManager]: + ) -> RawAuditManager: """Get raw audit manager.""" - yield RawAuditManager( + return RawAuditManager( client, settings.RAW_EVENT_STREAM_NAME, settings.EVENT_HANDLER_GROUP, @@ -342,13 +343,13 @@ async def get_raw_audit_manager( ) @provide(scope=Scope.APP) - async def get_normalized_audit_manager( + def get_normalized_audit_manager( self, client: AuditRedisClient, settings: Settings, - ) -> AsyncIterator[NormalizedAuditManager]: + ) -> NormalizedAuditManager: """Get raw audit manager.""" - yield NormalizedAuditManager( + return NormalizedAuditManager( client, settings.NORMALIZED_EVENT_STREAM_NAME, settings.EVENT_SENDER_GROUP, @@ -361,7 +362,7 @@ async def get_normalized_audit_manager( audit_destination_dao = provide(AuditDestinationDAO, scope=Scope.REQUEST) @provide(scope=Scope.REQUEST) - async def get_dhcp_manager_repository( + def get_dhcp_manager_repository( self, session: AsyncSession, ) -> DHCPManagerRepository: @@ -377,20 +378,20 @@ async def get_dhcp_manager_state( return await dhcp_manager_repository.ensure_state() @provide(scope=Scope.REQUEST) - async def get_dhcp_mngr_class( + def get_dhcp_mngr_class( self, dhcp_state: DHCPManagerState, ) -> type[AbstractDHCPManager]: """Get DHCP manager type.""" - return await get_dhcp_manager_class(dhcp_state) + return get_dhcp_manager_class(dhcp_state) @provide(scope=Scope.REQUEST) - async def get_dhcp_api_repository_class( + def get_dhcp_api_repository_class( self, dhcp_state: DHCPManagerState, ) -> type[DHCPAPIRepository]: """Get DHCP API repository type.""" - return await get_dhcp_api_repository_class(dhcp_state) + return get_dhcp_api_repository_class(dhcp_state) @provide(scope=Scope.APP) async def get_dhcp_http_client( @@ -404,7 +405,7 @@ async def get_dhcp_http_client( yield DHCPManagerHTTPClient(http_client) @provide(scope=Scope.REQUEST) - async def get_dhcp_api_repository( + def get_dhcp_api_repository( self, http_client: DHCPManagerHTTPClient, dhcp_api_repository_class: type[DHCPAPIRepository], @@ -413,7 +414,7 @@ async def get_dhcp_api_repository( return dhcp_api_repository_class(http_client) @provide(scope=Scope.REQUEST) - async def get_dhcp_mngr( + def get_dhcp_mngr( self, dhcp_manager_class: type[AbstractDHCPManager], dhcp_api_repository: DHCPAPIRepository, @@ -458,7 +459,7 @@ async def get_dhcp_mngr( ) password_utils = provide(PasswordUtils, scope=Scope.RUNTIME) - access_manager = provide(AccessManager, scope=Scope.REQUEST) + access_manager = provide(AccessManager, scope=Scope.RUNTIME) role_dao = provide(RoleDAO, scope=Scope.REQUEST) ace_dao = provide(AccessControlEntryDAO, scope=Scope.REQUEST) role_use_case = provide(RoleUseCase, scope=Scope.REQUEST) @@ -503,12 +504,16 @@ class LDAPContextProvider(Provider): LDAPModifyDNRequestContext, scope=Scope.REQUEST, ) + unbind_request_context = provide( + LDAPUnbindRequestContext, + scope=Scope.REQUEST, + ) search_request_context = provide( LDAPSearchRequestContext, scope=Scope.REQUEST, ) - unbind_request_context = provide( - LDAPUnbindRequestContext, + abandon_request_context = provide( + LDAPAbandonRequestContext, scope=Scope.REQUEST, ) @@ -535,7 +540,7 @@ class HTTPProvider(LDAPContextProvider): ) @provide() - async def get_audit_monitor( + def get_audit_monitor( self, session: AsyncSession, audit_use_case: "AuditUseCase", @@ -595,7 +600,7 @@ def get_permissions_provider( return auth_provider @provide() - async def get_identity_provider( + def get_identity_provider( self, request: Request, session_storage: SessionStorage, @@ -816,7 +821,7 @@ async def get_client( yield MFAHTTPClient(client) @provide(provides=MultifactorAPI) - async def get_http_mfa( + def get_http_mfa( self, credentials: MFA_HTTP_Creds, client: MFAHTTPClient, @@ -838,7 +843,7 @@ async def get_http_mfa( ) @provide(provides=LDAPMultiFactorAPI) - async def get_ldap_mfa( + def get_ldap_mfa( self, credentials: MFA_LDAP_Creds, client: MFAHTTPClient, diff --git a/app/ldap_protocol/dhcp/__init__.py b/app/ldap_protocol/dhcp/__init__.py index 27df7d0c0..cf26f1903 100644 --- a/app/ldap_protocol/dhcp/__init__.py +++ b/app/ldap_protocol/dhcp/__init__.py @@ -26,7 +26,7 @@ from .stub import StubDHCPAPIRepository, StubDHCPManager -async def get_dhcp_manager_class( +def get_dhcp_manager_class( dhcp_state: DHCPManagerState, ) -> type[AbstractDHCPManager]: """Get an instance of the DHCP manager.""" @@ -35,7 +35,7 @@ async def get_dhcp_manager_class( return StubDHCPManager -async def get_dhcp_api_repository_class( +def get_dhcp_api_repository_class( dhcp_state: DHCPManagerState, ) -> type[DHCPAPIRepository]: """Get an instance of the DHCP API repository.""" diff --git a/app/ldap_protocol/ldap_requests/abandon.py b/app/ldap_protocol/ldap_requests/abandon.py index 3facb0562..b9569ca0e 100644 --- a/app/ldap_protocol/ldap_requests/abandon.py +++ b/app/ldap_protocol/ldap_requests/abandon.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, ClassVar from ldap_protocol.asn1parser import ASN1Row +from ldap_protocol.ldap_requests.contexts import LDAPAbandonRequestContext from ldap_protocol.objects import ProtocolRequests from .base import BaseRequest @@ -16,6 +17,7 @@ class AbandonRequest(BaseRequest): """Abandon protocol.""" + CONTEXT_TYPE: ClassVar[type] = LDAPAbandonRequestContext PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ABANDON message_id: int @@ -27,7 +29,7 @@ def from_data( """Create structure from ASN1Row dataclass list.""" return cls(message_id=1) - async def handle(self) -> AsyncGenerator: + async def handle(self, ctx: LDAPAbandonRequestContext) -> AsyncGenerator: # noqa: ARG002 """Handle message with current user.""" await asyncio.sleep(0) return diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index fec2a8679..aa5d3b3ee 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -65,6 +65,7 @@ class AddRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ADD + CONTEXT_TYPE: ClassVar[type] = LDAPAddRequestContext entry: str = Field(..., description="Any `DistinguishedName`") attributes: list[PartialAttribute] diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index 3123e6247..445ce3bae 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -24,6 +24,7 @@ from ldap_protocol.dependency import resolve_deps from ldap_protocol.dialogue import LDAPSession from ldap_protocol.ldap_responses import BaseResponse, LDAPResult +from ldap_protocol.objects import ProtocolRequests from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.audit.events.factory import ( RawAuditEventBuilderRedis, @@ -62,6 +63,7 @@ class _APIProtocol: ... class BaseRequest(ABC, _APIProtocol, BaseModel): """Base request builder.""" + CONTEXT_TYPE: ClassVar[type] handle: ClassVar[handler] from_data: ClassVar[serializer] __event_data: dict = {} @@ -113,38 +115,39 @@ async def handle_tcp( container: AsyncContainer, ) -> AsyncIterator[BaseResponse]: """Hanlde response with tcp.""" - kwargs = await resolve_deps(func=self.handle, container=container) - responses = [] + ctx = await container.get(self.CONTEXT_TYPE) # type: ignore - async for response in self.handle(**kwargs): + responses = [] + async for response in self.handle(ctx=ctx): responses.append(response) yield response - ldap_session = await container.get(LDAPSession) - settings = await container.get(Settings) - audit_use_case = await container.get(AuditUseCase) - - if await audit_use_case.check_event_processing_enabled( - self.PROTOCOL_OP, - ): - username = getattr( - ldap_session.user, - "user_principal_name", - "ANONYMOUS", - ) - event = RawAuditEventBuilderRedis.from_ldap_request( - self, - responses=responses, - username=username, - ip=ldap_session.ip, - protocol="TCP_LDAP", - settings=settings, - context=self.get_event_data(), - ) + if self.PROTOCOL_OP != ProtocolRequests.SEARCH: + ldap_session = await container.get(LDAPSession) + settings = await container.get(Settings) + audit_use_case = await container.get(AuditUseCase) + + if await audit_use_case.check_event_processing_enabled( + self.PROTOCOL_OP, + ): + username = getattr( + ldap_session.user, + "user_principal_name", + "ANONYMOUS", + ) + event = RawAuditEventBuilderRedis.from_ldap_request( + self, + responses=responses, + username=username, + ip=ldap_session.ip, + protocol="TCP_LDAP", + settings=settings, + context=self.get_event_data(), + ) - ldap_session.event_task_group.create_task( - audit_use_case.manager.send_event(event), - ) + ldap_session.event_task_group.create_task( + audit_use_case.manager.send_event(event), + ) async def _handle_api( self, @@ -156,7 +159,8 @@ async def _handle_api( :param AsyncSession session: db session :return list[BaseResponse]: list of handled responses """ - kwargs = await resolve_deps(func=self.handle, container=container) + ctx = await container.get(self.CONTEXT_TYPE) # type: ignore + ldap_session = await container.get(LDAPSession) settings = await container.get(Settings) audit_use_case = await container.get(AuditUseCase) @@ -168,7 +172,7 @@ async def _handle_api( else: log_api.info(f"{get_class_name(self)}[{un}]") - responses = [response async for response in self.handle(**kwargs)] + responses = [response async for response in self.handle(ctx=ctx)] if settings.DEBUG: for response in responses: diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index a303f0fdf..445b2f25c 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -43,6 +43,7 @@ class BindRequest(BaseRequest): """Bind request fields mapping.""" PROTOCOL_OP: ClassVar[int] = ProtocolRequests.BIND + CONTEXT_TYPE: ClassVar[type] = LDAPBindRequestContext version: int name: str @@ -230,6 +231,7 @@ class UnbindRequest(BaseRequest): """Remove user from ldap_session.""" PROTOCOL_OP: ClassVar[int] = ProtocolRequests.UNBIND + CONTEXT_TYPE: ClassVar[type] = LDAPUnbindRequestContext @classmethod def from_data( diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index 4926d826e..98f6e1a9b 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -123,3 +123,7 @@ class LDAPModifyDNRequestContext: access_manager: AccessManager role_use_case: RoleUseCase attribute_value_validator: AttributeValueValidator + + +@dataclass +class LDAPAbandonRequestContext: ... diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index 5c731e9b7..334df621a 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -43,6 +43,7 @@ class DeleteRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.DELETE + CONTEXT_TYPE: ClassVar[type] = LDAPDeleteRequestContext entry: str diff --git a/app/ldap_protocol/ldap_requests/extended.py b/app/ldap_protocol/ldap_requests/extended.py index 85ca1f31b..c3967889e 100644 --- a/app/ldap_protocol/ldap_requests/extended.py +++ b/app/ldap_protocol/ldap_requests/extended.py @@ -308,6 +308,7 @@ class ExtendedRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.EXTENDED + CONTEXT_TYPE: ClassVar[type] = LDAPExtendedRequestContext request_name: LDAPOID request_value: SerializeAsAny[BaseExtendedValue] diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 7334ba225..8ba60b963 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -102,6 +102,7 @@ class ModifyRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY + CONTEXT_TYPE: ClassVar[type] = LDAPModifyRequestContext object: str changes: list[Changes] diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index c1ff681d3..0cf547833 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -68,6 +68,7 @@ class ModifyDNRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY_DN + CONTEXT_TYPE: ClassVar[type] = LDAPModifyDNRequestContext entry: str newrdn: str diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index 01ec77169..1f9579dc2 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -14,7 +14,12 @@ from pydantic import Field, PrivateAttr, field_serializer from sqlalchemy import func, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, selectinload, with_loader_criteria +from sqlalchemy.orm import ( + contains_eager, + joinedload, + selectinload, + with_loader_criteria, +) from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from sqlalchemy.sql.expression import Select @@ -100,6 +105,7 @@ class SearchRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.SEARCH + CONTEXT_TYPE: ClassVar[type] = LDAPSearchRequestContext base_object: str = Field("", description="Any `DistinguishedName`") scope: Scope @@ -339,7 +345,7 @@ def _mutate_query_with_attributes_to_load( if self.entity_type_name: query = ( query.join(qa(Directory.entity_type)) - .options(selectinload(qa(Directory.entity_type))) + .options(contains_eager(qa(Directory.entity_type))) ) # fmt: skip if self.all_attrs: @@ -369,8 +375,8 @@ def _build_query( query = ( select(Directory) .join(qa(Directory.user), isouter=True) - .options(joinedload(qa(Directory.user))) - .options(selectinload(qa(Directory.group))) + .options(contains_eager(qa(Directory.user))) + .options(joinedload(qa(Directory.group))) ) query = self._mutate_query_with_attributes_to_load(query) @@ -423,7 +429,7 @@ def _build_query( if self.member: query = query.options( - selectinload(qa(Directory.group)).selectinload( + joinedload(qa(Directory.group)).selectinload( qa(Group.members), ), ) @@ -501,7 +507,6 @@ async def _fill_attrs( ) if self.member_of: - logger.debug(f"Member of group: {directory.groups}") for group in directory.groups: attrs["memberOf"].append(group.directory.path_dn) @@ -541,9 +546,9 @@ async def tree_view( # noqa: C901 access_manager: AccessManager, ) -> AsyncGenerator[SearchResultEntry, None]: """Yield all resulted directories.""" - directories = await session.stream_scalars(query) + directories = await session.scalars(query) - async for directory in directories: + for directory in directories: attrs = defaultdict(list) obj_classes = [] diff --git a/tests/conftest.py b/tests/conftest.py index b97a8ce4a..a90d4fee6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -87,6 +87,7 @@ from ldap_protocol.kerberos.template_render import KRBTemplateRenderer from ldap_protocol.ldap_requests.bind import BindRequest from ldap_protocol.ldap_requests.contexts import ( + LDAPAbandonRequestContext, LDAPAddRequestContext, LDAPBindRequestContext, LDAPDeleteRequestContext, @@ -669,6 +670,11 @@ async def get_audit_monitor( LDAPSearchRequestContext, scope=Scope.REQUEST, ) + abandon_request_context = provide( + LDAPAbandonRequestContext, + scope=Scope.REQUEST, + ) + unbind_request_context = provide( LDAPUnbindRequestContext, scope=Scope.REQUEST,