diff --git a/.nuget/nuget.exe b/.nuget/nuget.exe new file mode 100644 index 0000000000..5e246fb40d Binary files /dev/null and b/.nuget/nuget.exe differ diff --git a/Source/Csla/DataPortalOperationAttributes.cs b/Source/Csla/DataPortalOperationAttributes.cs index 945d49e6d7..709bcb7bc0 100644 --- a/Source/Csla/DataPortalOperationAttributes.cs +++ b/Source/Csla/DataPortalOperationAttributes.cs @@ -23,6 +23,13 @@ public class InjectAttribute : Attribute /// service is not registered. /// public bool AllowNull { get; set; } + + /// + /// Gets or sets the key used to identify a keyed service registration. + /// When specified, the service is retrieved using GetKeyedService or GetRequiredKeyedService. + /// Requires .NET 8.0 or higher. + /// + public object? Key { get; set; } } /// diff --git a/Source/Csla/Reflection/ServiceProviderMethodCaller.cs b/Source/Csla/Reflection/ServiceProviderMethodCaller.cs index d24d5b16de..1acc85d9be 100644 --- a/Source/Csla/Reflection/ServiceProviderMethodCaller.cs +++ b/Source/Csla/Reflection/ServiceProviderMethodCaller.cs @@ -553,11 +553,40 @@ private static ParameterInfo[] GetDIParameters(System.Reflection.MethodInfo meth { throw new NullReferenceException(nameof(service)); } - // Use GetService for optional (allows null) or GetRequiredService for required (throws if not registered) - plist[index] = method.AllowNull[index] - ? service.GetService(item.ParameterType) - : service.GetRequiredService(item.ParameterType); - + + var serviceKey = method.ServiceKeys[index]; + if (serviceKey != null) + { +#if NET8_0_OR_GREATER + // Use keyed service injection for .NET 8+ + if (method.AllowNull[index]) + { + // For optional keyed services, cast to IKeyedServiceProvider + if (service is IKeyedServiceProvider keyedProvider) + { + plist[index] = keyedProvider.GetKeyedService(item.ParameterType, serviceKey); + } + else + { + throw new InvalidOperationException("Service provider must implement IKeyedServiceProvider to support keyed services."); + } + } + else + { + // For required keyed services, use extension method + plist[index] = service.GetRequiredKeyedService(item.ParameterType, serviceKey); + } +#else + throw new NotSupportedException("Keyed service injection is only supported on .NET 8.0 or higher."); +#endif + } + else + { + // Use GetService for optional (allows null) or GetRequiredService for required (throws if not registered) + plist[index] = method.AllowNull[index] + ? service.GetService(item.ParameterType) + : service.GetRequiredService(item.ParameterType); + } } else { diff --git a/Source/Csla/Reflection/ServiceProviderMethodInfo.cs b/Source/Csla/Reflection/ServiceProviderMethodInfo.cs index f59bfe3863..5595201d38 100644 --- a/Source/Csla/Reflection/ServiceProviderMethodInfo.cs +++ b/Source/Csla/Reflection/ServiceProviderMethodInfo.cs @@ -22,7 +22,7 @@ namespace Csla.Reflection /// public class ServiceProviderMethodInfo { - [MemberNotNullWhen(true, nameof(DynamicMethod), nameof(Parameters), nameof(IsInjected), nameof(AllowNull), nameof(DataPortalMethodInfo))] + [MemberNotNullWhen(true, nameof(DynamicMethod), nameof(Parameters), nameof(IsInjected), nameof(AllowNull), nameof(ServiceKeys), nameof(DataPortalMethodInfo))] private bool Initialized { get; set; } /// @@ -55,6 +55,11 @@ public class ServiceProviderMethodInfo /// public bool[]? AllowNull { get; private set; } /// + /// Gets an array of keys for injected parameters. + /// Null entries indicate non-keyed services. + /// + public object?[]? ServiceKeys { get; private set; } + /// /// Gets a value indicating whether the method /// returns type Task /// @@ -83,7 +88,7 @@ public ServiceProviderMethodInfo(System.Reflection.MethodInfo methodInfo) /// Initializes and caches the metastate values /// necessary to invoke the method /// - [MemberNotNull(nameof(DynamicMethod), nameof(Parameters), nameof(IsInjected), nameof(AllowNull), nameof(DataPortalMethodInfo))] + [MemberNotNull(nameof(DynamicMethod), nameof(Parameters), nameof(IsInjected), nameof(AllowNull), nameof(ServiceKeys), nameof(DataPortalMethodInfo))] public void PrepForInvocation() { if (!Initialized) @@ -97,6 +102,7 @@ public void PrepForInvocation() TakesParamArray = (Parameters.Length == 1 && Parameters[0].ParameterType.Equals(typeof(object[]))); IsInjected = new bool[Parameters.Length]; AllowNull = new bool[Parameters.Length]; + ServiceKeys = new object?[Parameters.Length]; int index = 0; foreach (var item in Parameters) @@ -106,6 +112,7 @@ public void PrepForInvocation() { IsInjected[index] = true; AllowNull[index] = injectAttribute.AllowNull || ParameterAllowsNull(item); + ServiceKeys[index] = injectAttribute.Key; } index++; } diff --git a/Source/tests/csla.netcore.test/DataPortal/ServiceProviderMethodCallerTests.cs b/Source/tests/csla.netcore.test/DataPortal/ServiceProviderMethodCallerTests.cs index 49096e3d81..e006e6690b 100644 --- a/Source/tests/csla.netcore.test/DataPortal/ServiceProviderMethodCallerTests.cs +++ b/Source/tests/csla.netcore.test/DataPortal/ServiceProviderMethodCallerTests.cs @@ -504,6 +504,92 @@ public async Task InvokeMethodWithRequiredServiceInjection_ThrowsWhenServiceNotR await FluentActions.Invoking(async () => await portal.CreateAsync()) .Should().ThrowAsync(); } + +#if NET8_0_OR_GREATER + [TestMethod] + public void FindMethodWithKeyedServiceInjection() + { + var obj = new KeyedServiceInjection(); + var method = _systemUnderTest.FindDataPortalMethod(obj, null); + + method.Should().NotBeNull(); + method.PrepForInvocation(); + method.Parameters.Should().HaveCount(1); + method.IsInjected.Should().HaveCount(1); + method.IsInjected![0].Should().BeTrue(); + method.ServiceKeys.Should().HaveCount(1); + method.ServiceKeys![0].Should().Be("serviceA"); + } + + [TestMethod] + public async Task InvokeMethodWithKeyedServiceInjection() + { + var contextWithService = TestDIContextFactory.CreateDefaultContext(services => + { + services.AddKeyedTransient("serviceA"); + }); + + var portal = contextWithService.CreateDataPortal(); + var obj = await portal.CreateAsync(); + + obj.Should().NotBeNull(); + obj.Data.Should().Be("Service A data"); + } + + [TestMethod] + public void FindMethodWithKeyedServiceAndCriteriaInjection() + { + var obj = new KeyedServiceWithCriteriaInjection(); + var method = _systemUnderTest.FindDataPortalMethod(obj, [123]); + + method.Should().NotBeNull(); + method.PrepForInvocation(); + method.Parameters.Should().HaveCount(2); + method.IsInjected.Should().HaveCount(2); + method.IsInjected![0].Should().BeFalse(); // First param is criteria + method.IsInjected![1].Should().BeTrue(); // Second param is injected + method.ServiceKeys.Should().HaveCount(2); + method.ServiceKeys![0].Should().BeNull(); // First param is not injected + method.ServiceKeys![1].Should().Be("serviceB"); + } + + [TestMethod] + public async Task InvokeMethodWithKeyedServiceAndCriteriaInjection() + { + var contextWithService = TestDIContextFactory.CreateDefaultContext(services => + { + services.AddKeyedTransient("serviceB"); + }); + + var portal = contextWithService.CreateDataPortal(); + var obj = await portal.CreateAsync(42); + + obj.Should().NotBeNull(); + obj.Id.Should().Be(42); + obj.Data.Should().Be("Service B data"); + } + + [TestMethod] + public async Task InvokeMethodWithOptionalKeyedServiceInjection_ServiceNotRegistered() + { + var portal = _diContext.CreateDataPortal(); + var obj = await portal.CreateAsync(); + + obj.Should().NotBeNull(); + obj.Data.Should().Be("Keyed service is null as expected"); + } + + [TestMethod] + public async Task InvokeMethodWithKeyedServiceInjection_ThrowsWhenServiceNotRegistered() + { + var portal = _diContext.CreateDataPortal(); + + // This should throw because the keyed service is required but not registered + // The exception will be wrapped in DataPortalException + await FluentActions.Invoking(async () => await portal.CreateAsync()) + .Should().ThrowAsync(); + } +#endif } #region Classes for testing various scenarios of loading/finding data portal methods @@ -920,5 +1006,83 @@ private void Create([Inject] IOptionalService requiredService) } } +#if NET8_0_OR_GREATER + // Tests for keyed service injection (NET8+ only) + public class KeyedServiceInjection : BusinessBase + { + public static readonly PropertyInfo DataProperty = RegisterProperty(nameof(Data)); + public string Data + { + get => GetProperty(DataProperty); + set => SetProperty(DataProperty, value); + } + + [Create] + private void Create([Inject(Key = "serviceA")] IOptionalService keyedService) + { + using (BypassPropertyChecks) + { + Data = keyedService.GetData(); + } + } + } + + public class KeyedServiceWithCriteriaInjection : BusinessBase + { + public static readonly PropertyInfo DataProperty = RegisterProperty(nameof(Data)); + public string Data + { + get => GetProperty(DataProperty); + set => SetProperty(DataProperty, value); + } + + public static readonly PropertyInfo IdProperty = RegisterProperty(nameof(Id)); + public int Id + { + get => GetProperty(IdProperty); + set => SetProperty(IdProperty, value); + } + + [Create] + private void Create(int id, [Inject(Key = "serviceB")] IOptionalService keyedService) + { + using (BypassPropertyChecks) + { + Id = id; + Data = keyedService.GetData(); + } + } + } + + public class OptionalKeyedServiceInjection : BusinessBase + { + public static readonly PropertyInfo DataProperty = RegisterProperty(nameof(Data)); + public string Data + { + get => GetProperty(DataProperty); + set => SetProperty(DataProperty, value); + } + + [Create] + private void Create([Inject(Key = "nonExistent", AllowNull = true)] IOptionalService? keyedService) + { + using (BypassPropertyChecks) + { + Data = keyedService == null ? "Keyed service is null as expected" : keyedService.GetData(); + } + } + } + + public class ServiceAImplementation : IOptionalService + { + public string GetData() => "Service A data"; + } + + public class ServiceBImplementation : IOptionalService + { + public string GetData() => "Service B data"; + } +#endif + #endregion } \ No newline at end of file