From 8ac3ee75f064f154c35c64e53700515a00eeb8c4 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sat, 9 Nov 2024 15:02:09 -0800 Subject: [PATCH 1/4] Fix a GC hole in 'IContextCallback' dispatch logic --- src/WinRT.Runtime/Interop/IContextCallback.cs | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/src/WinRT.Runtime/Interop/IContextCallback.cs b/src/WinRT.Runtime/Interop/IContextCallback.cs index 9a87c29fa..f272f3737 100644 --- a/src/WinRT.Runtime/Interop/IContextCallback.cs +++ b/src/WinRT.Runtime/Interop/IContextCallback.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using WinRT; using WinRT.Interop; @@ -18,11 +19,36 @@ internal struct ComCallData } #if NET && CsWinRT_LANG_11_FEATURES - internal unsafe struct CallbackData + internal sealed unsafe class CallbackData { + [ThreadStatic] + private static CallbackData TlsInstance; + public delegate* Callback; public object State; + public GCHandle Handle; + + private CallbackData() + { + // Create a handle to access the object from a native callback invoked on another thread. + // The handle is weak to ensure that the object does not leak (or it would keep itself + // alive). The target is guaranteed to be alive because callers will use 'GC.KeepAlive'. + Handle = GCHandle.Alloc(this, GCHandleType.Weak); + } + + ~CallbackData() + { + Handle.Free(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static CallbackData GetOrCreate() + { + return TlsInstance ??= new CallbackData(); + } } + + #endif #if NET && CsWinRT_LANG_11_FEATURES @@ -39,25 +65,18 @@ public static void ContextCallback(IntPtr contextCallbackPtr, delegate*pUserDefined; + CallbackData callbackData = Unsafe.As(GCHandle.FromIntPtr(comCallData->pUserDefined).Target); - callbackData->Callback(callbackData->State); + callbackData.Callback(callbackData.State); return 0; // S_OK } @@ -76,6 +95,11 @@ static int InvokeCallback(ComCallData* comCallData) &iid, /* iMethod */ 5, IntPtr.Zero); + + // This call is critical to ensure that the callback data is kept alive until we get here. + // This prevents its finalizer to run (that finalizer would free the GC handle used in the + // native callback to get back the target callback data that contains the dispatch parameters). + GC.KeepAlive(callbackData); if (hresult < 0) { From ce15fe8d8f44f790dda1db8f748d420e6d452713 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sun, 10 Nov 2024 13:17:51 -0800 Subject: [PATCH 2/4] Apply suggestions from code review --- src/WinRT.Runtime/Interop/IContextCallback.cs | 87 +++++++++---------- 1 file changed, 40 insertions(+), 47 deletions(-) diff --git a/src/WinRT.Runtime/Interop/IContextCallback.cs b/src/WinRT.Runtime/Interop/IContextCallback.cs index f272f3737..9a914501f 100644 --- a/src/WinRT.Runtime/Interop/IContextCallback.cs +++ b/src/WinRT.Runtime/Interop/IContextCallback.cs @@ -4,6 +4,7 @@ using System; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Threading; using WinRT; using WinRT.Interop; @@ -19,36 +20,14 @@ internal struct ComCallData } #if NET && CsWinRT_LANG_11_FEATURES - internal sealed unsafe class CallbackData + internal unsafe struct CallbackData { [ThreadStatic] - private static CallbackData TlsInstance; + public static object PerThreadObject; public delegate* Callback; - public object State; - public GCHandle Handle; - - private CallbackData() - { - // Create a handle to access the object from a native callback invoked on another thread. - // The handle is weak to ensure that the object does not leak (or it would keep itself - // alive). The target is guaranteed to be alive because callers will use 'GC.KeepAlive'. - Handle = GCHandle.Alloc(this, GCHandleType.Weak); - } - - ~CallbackData() - { - Handle.Free(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static CallbackData GetOrCreate() - { - return TlsInstance ??= new CallbackData(); - } + public object* StatePtr; } - - #endif #if NET && CsWinRT_LANG_11_FEATURES @@ -61,22 +40,19 @@ internal unsafe struct IContextCallbackVftbl public static void ContextCallback(IntPtr contextCallbackPtr, delegate* callback, delegate* onFailCallback, object state) { - ComCallData comCallData; - comCallData.dwDispid = 0; - comCallData.dwReserved = 0; - - CallbackData callbackData = CallbackData.GetOrCreate(); - - comCallData.pUserDefined = GCHandle.ToIntPtr(callbackData.Handle); - + // Native method that invokes the callback on the target context. The state object + // is guaranteed to be pinned, so we can access it from a pointer. Note that the + // object will be stored in a static field, and it will not be on the stack of the + // original thread, so it's safe with respect to cross-thread access of managed objects. + // See: https://github.com/dotnet/runtime/blob/main/docs/design/specs/Memory-model.md#cross-thread-access-to-local-variables. [UnmanagedCallersOnly] static int InvokeCallback(ComCallData* comCallData) { try { - CallbackData callbackData = Unsafe.As(GCHandle.FromIntPtr(comCallData->pUserDefined).Target); + CallbackData* callbackData = (CallbackData*)comCallData->pUserDefined; - callbackData.Callback(callbackData.State); + callbackData->Callback(*callbackData->StatePtr); return 0; // S_OK } @@ -86,20 +62,37 @@ static int InvokeCallback(ComCallData* comCallData) } } - Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA; + ComCallData comCallData; + comCallData.dwDispid = 0; + comCallData.dwReserved = 0; + + CallbackData.PerThreadObject = state; + + int hresult; - int hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4( - contextCallbackPtr, - (IntPtr)(delegate* unmanaged)&InvokeCallback, - &comCallData, - &iid, - /* iMethod */ 5, - IntPtr.Zero); + fixed (object* statePtr = &CallbackData.PerThreadObject) + { + CallbackData callbackData; + callbackData.Callback = callback; + callbackData.StatePtr = statePtr; + + Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA; + + // Add a memory barrier to be extra safe that the target thread will be able to see + // the write we just did on 'PerThreadObject' with the state to pass to the callback. + Thread.MemoryBarrier(); + + hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4( + contextCallbackPtr, + (IntPtr)(delegate* unmanaged)&InvokeCallback, + &comCallData, + &iid, + /* iMethod */ 5, + IntPtr.Zero); + } - // This call is critical to ensure that the callback data is kept alive until we get here. - // This prevents its finalizer to run (that finalizer would free the GC handle used in the - // native callback to get back the target callback data that contains the dispatch parameters). - GC.KeepAlive(callbackData); + // Reset the static field to avoid keeping the state alive for longer + Volatile.Write(ref CallbackData.PerThreadObject, null); if (hresult < 0) { From 9a66dfe6e1583fa92de3e565d860f23ea51bb3fc Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sun, 10 Nov 2024 13:46:46 -0800 Subject: [PATCH 3/4] Add comments about reentrancy --- src/WinRT.Runtime/Interop/IContextCallback.cs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/WinRT.Runtime/Interop/IContextCallback.cs b/src/WinRT.Runtime/Interop/IContextCallback.cs index 9a914501f..4e3f243cb 100644 --- a/src/WinRT.Runtime/Interop/IContextCallback.cs +++ b/src/WinRT.Runtime/Interop/IContextCallback.cs @@ -70,6 +70,11 @@ static int InvokeCallback(ComCallData* comCallData) int hresult; + // We use a thread local static field to efficiently store the state that's used by the callback. Note that this + // is safe with respect to reentrancy, as the target callback will never try to switch back on the original thread. + // We're only ever switching once on the original context, only to release the object reference that is passed as + // state. There is no way for that to possibly switch back on the starting thread. As such, using a thread static + // field to pass the state to the target context (we need to store it somewhere on the managed heap) is fine. fixed (object* statePtr = &CallbackData.PerThreadObject) { CallbackData callbackData; From 7dd4c5dea6397b3cf78cea09e80e746c65e04dcf Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Sun, 10 Nov 2024 18:21:12 -0800 Subject: [PATCH 4/4] Set 'pUserDefined' again, add comments --- src/WinRT.Runtime/Interop/IContextCallback.cs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/WinRT.Runtime/Interop/IContextCallback.cs b/src/WinRT.Runtime/Interop/IContextCallback.cs index 4e3f243cb..3ca4c8de5 100644 --- a/src/WinRT.Runtime/Interop/IContextCallback.cs +++ b/src/WinRT.Runtime/Interop/IContextCallback.cs @@ -60,12 +60,10 @@ static int InvokeCallback(ComCallData* comCallData) { return e.HResult; } - } - - ComCallData comCallData; - comCallData.dwDispid = 0; - comCallData.dwReserved = 0; + } + // Store the state object in the thread static to pass to the callback. + // We don't need a volatile write here, we have a memory barrier below. CallbackData.PerThreadObject = state; int hresult; @@ -81,6 +79,11 @@ static int InvokeCallback(ComCallData* comCallData) callbackData.Callback = callback; callbackData.StatePtr = statePtr; + ComCallData comCallData; + comCallData.dwDispid = 0; + comCallData.dwReserved = 0; + comCallData.pUserDefined = (IntPtr)(void*)&callbackData; + Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA; // Add a memory barrier to be extra safe that the target thread will be able to see