From 80d628ba5432e583e1c56e0bd89c4655cf63c3a2 Mon Sep 17 00:00:00 2001 From: William Candillon Date: Wed, 11 Mar 2026 04:00:37 +0100 Subject: [PATCH] =?UTF-8?q?fix(=F0=9F=90=9B):=20fix=20runtime=20promise=20?= =?UTF-8?q?invalidation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/webgpu/cpp/jsi/JSIConverter.h | 24 ++++--- packages/webgpu/cpp/jsi/Promise.cpp | 52 +++++++++++---- packages/webgpu/cpp/jsi/Promise.h | 65 +++++++++++++++++-- .../webgpu/cpp/rnwgpu/RNWebGPUManager.cpp | 17 ++++- packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h | 2 + packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h | 31 +++++---- 6 files changed, 150 insertions(+), 41 deletions(-) diff --git a/packages/webgpu/cpp/jsi/JSIConverter.h b/packages/webgpu/cpp/jsi/JSIConverter.h index 387a57813..1302f07a5 100644 --- a/packages/webgpu/cpp/jsi/JSIConverter.h +++ b/packages/webgpu/cpp/jsi/JSIConverter.h @@ -205,15 +205,21 @@ template <> struct JSIConverter { } static jsi::Value toJSI(jsi::Runtime& runtime, rnwgpu::async::AsyncTaskHandle&& handle) { - return rnwgpu::Promise::createPromise(runtime, [handle = std::move(handle)](jsi::Runtime& runtime, - std::shared_ptr promise) mutable { - if (!handle.valid()) { - promise->resolve(jsi::Value::undefined()); - return; - } - - handle.attachPromise(promise); - }); + auto context = rnwgpu::RuntimeContext::getMainContext(); + if (!context) { + throw std::runtime_error("RuntimeContext not set - cannot create Promise"); + } + return rnwgpu::Promise::createPromise( + context, + [handle = std::move(handle)](jsi::Runtime& runtime, + std::shared_ptr promise) mutable { + if (!handle.valid()) { + promise->resolve(jsi::Value::undefined()); + return; + } + + handle.attachPromise(promise); + }); } }; diff --git a/packages/webgpu/cpp/jsi/Promise.cpp b/packages/webgpu/cpp/jsi/Promise.cpp index 5d11e58b5..5a2f2511c 100644 --- a/packages/webgpu/cpp/jsi/Promise.cpp +++ b/packages/webgpu/cpp/jsi/Promise.cpp @@ -10,35 +10,65 @@ namespace rnwgpu { namespace jsi = facebook::jsi; -Promise::Promise(jsi::Runtime& runtime, jsi::Function&& resolver, jsi::Function&& rejecter) - : runtime(runtime), _resolver(std::move(resolver)), _rejecter(std::move(rejecter)) {} +// Static member definition +std::shared_ptr RuntimeContext::_mainContext = nullptr; -jsi::Value Promise::createPromise(jsi::Runtime& runtime, RunPromise run) { +Promise::Promise(std::weak_ptr context, jsi::Function&& resolver, + jsi::Function&& rejecter) + : _context(std::move(context)), _resolver(std::move(resolver)), + _rejecter(std::move(rejecter)) {} + +jsi::Value Promise::createPromise(std::shared_ptr context, + RunPromise run) { + auto runtime = context->getRuntime(); // Get Promise ctor from global - auto promiseCtor = runtime.global().getPropertyAsFunction(runtime, "Promise"); + auto promiseCtor = runtime->global().getPropertyAsFunction(*runtime, "Promise"); auto promiseCallback = jsi::Function::createFromHostFunction( - runtime, jsi::PropNameID::forUtf8(runtime, "PromiseCallback"), 2, - [=](jsi::Runtime& runtime, const jsi::Value& thisValue, const jsi::Value* arguments, size_t count) -> jsi::Value { + *runtime, jsi::PropNameID::forUtf8(*runtime, "PromiseCallback"), 2, + [context, run](jsi::Runtime& runtime, const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) -> jsi::Value { // Call function auto resolver = arguments[0].asObject(runtime).asFunction(runtime); auto rejecter = arguments[1].asObject(runtime).asFunction(runtime); - auto promise = std::make_shared(runtime, std::move(resolver), std::move(rejecter)); + auto promise = std::make_shared(context, std::move(resolver), + std::move(rejecter)); run(runtime, promise); return jsi::Value::undefined(); }); - return promiseCtor.callAsConstructor(runtime, promiseCallback); + return promiseCtor.callAsConstructor(*runtime, promiseCallback); } void Promise::resolve(jsi::Value&& result) { - _resolver.call(runtime, std::move(result)); + auto context = _context.lock(); + if (!context || !context->isValid()) { + // Runtime has been torn down, silently ignore + return; + } + auto runtime = context->getRuntime(); + _resolver.call(*runtime, std::move(result)); } void Promise::reject(std::string message) { - jsi::JSError error(runtime, message); - _rejecter.call(runtime, error.value()); + auto context = _context.lock(); + if (!context || !context->isValid()) { + // Runtime has been torn down, silently ignore + return; + } + auto runtime = context->getRuntime(); + jsi::JSError error(*runtime, message); + _rejecter.call(*runtime, error.value()); +} + +jsi::Runtime* Promise::getRuntime() const { + auto context = _context.lock(); + if (!context || !context->isValid()) { + return nullptr; + } + return context->getRuntime(); } } // namespace rnwgpu diff --git a/packages/webgpu/cpp/jsi/Promise.h b/packages/webgpu/cpp/jsi/Promise.h index b6b8d111a..c588c93e2 100644 --- a/packages/webgpu/cpp/jsi/Promise.h +++ b/packages/webgpu/cpp/jsi/Promise.h @@ -10,26 +10,81 @@ namespace rnwgpu { namespace jsi = facebook::jsi; +/** + * A context that wraps a runtime pointer and can be invalidated + * when the runtime is torn down (e.g., during hot reload). + */ +class RuntimeContext : public std::enable_shared_from_this { +public: + explicit RuntimeContext(jsi::Runtime& runtime) : _runtime(&runtime) {} + + void invalidate() { _runtime = nullptr; } + + jsi::Runtime* getRuntime() const { return _runtime; } + + bool isValid() const { return _runtime != nullptr; } + + /** + * Set the main runtime context (called during module initialization). + */ + static void setMainContext(std::shared_ptr context) { + _mainContext = std::move(context); + } + + /** + * Get the main runtime context (may be nullptr if not set or invalidated). + */ + static std::shared_ptr getMainContext() { + return _mainContext; + } + +private: + jsi::Runtime* _runtime; + static std::shared_ptr _mainContext; +}; + class Promise { public: - Promise(jsi::Runtime& runtime, jsi::Function&& resolver, jsi::Function&& rejecter); + Promise(std::weak_ptr context, jsi::Function&& resolver, + jsi::Function&& rejecter); void resolve(jsi::Value&& result); void reject(std::string error); -public: - jsi::Runtime& runtime; + /** + * Get the runtime pointer, or nullptr if the runtime has been torn down. + * Use this to safely construct jsi::Value before calling resolve(). + */ + jsi::Runtime* getRuntime() const; + + /** + * Resolve with a value constructed by the factory function. + * The factory is only called if the runtime is still valid. + * Usage: promise->resolveWith([&](jsi::Runtime& rt) { return + * JSIConverter::toJSI(rt, value); }); + */ + template void resolveWith(F&& valueFactory) { + auto context = _context.lock(); + if (!context || !context->isValid()) { + return; + } + auto* runtime = context->getRuntime(); + _resolver.call(*runtime, valueFactory(*runtime)); + } private: + std::weak_ptr _context; jsi::Function _resolver; jsi::Function _rejecter; public: - using RunPromise = std::function promise)>; + using RunPromise = + std::function promise)>; /** Create a new Promise and runs the given `run` function. */ - static jsi::Value createPromise(jsi::Runtime& runtime, RunPromise run); + static jsi::Value createPromise(std::shared_ptr context, + RunPromise run); }; } // namespace rnwgpu diff --git a/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.cpp b/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.cpp index 5a2decc09..a0741324e 100644 --- a/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.cpp +++ b/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.cpp @@ -55,14 +55,18 @@ RNWebGPUManager::RNWebGPUManager( std::shared_ptr jsCallInvoker, std::shared_ptr platformContext) : _jsRuntime(jsRuntime), _jsCallInvoker(jsCallInvoker), - _platformContext(platformContext) { + _platformContext(platformContext), + _runtimeContext(std::make_shared(*jsRuntime)) { // Register main runtime for RuntimeAwareCache BaseRuntimeAwareCache::setMainJsRuntime(_jsRuntime); + // Set the main runtime context for Promise creation + RuntimeContext::setMainContext(_runtimeContext); + auto gpu = std::make_shared(*_jsRuntime); - auto rnWebGPU = - std::make_shared(gpu, _platformContext, _jsCallInvoker); + auto rnWebGPU = std::make_shared(gpu, _platformContext, + _jsCallInvoker, _runtimeContext); _gpu = gpu->get(); _jsRuntime->global().setProperty(*_jsRuntime, "RNWebGPU", RNWebGPU::create(*_jsRuntime, rnWebGPU)); @@ -218,6 +222,13 @@ void RNWebGPUManager::installWebGPUWorkletHelpers(jsi::Runtime &runtime) { } RNWebGPUManager::~RNWebGPUManager() { + // Invalidate the runtime context first to prevent any pending promises + // from accessing the torn-down runtime (e.g., during hot reload) + if (_runtimeContext) { + _runtimeContext->invalidate(); + } + // Clear the global context reference + RuntimeContext::setMainContext(nullptr); _jsRuntime = nullptr; _jsCallInvoker = nullptr; } diff --git a/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h b/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h index 2043c9658..3630ed1ba 100644 --- a/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h +++ b/packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h @@ -4,6 +4,7 @@ #include "GPU.h" #include "PlatformContext.h" +#include "Promise.h" #include "SurfaceRegistry.h" namespace facebook { @@ -37,6 +38,7 @@ class RNWebGPUManager { private: jsi::Runtime *_jsRuntime; std::shared_ptr _jsCallInvoker; + std::shared_ptr _runtimeContext; public: wgpu::Instance _gpu; diff --git a/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h b/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h index 59fe14bd8..9e4540a7b 100644 --- a/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h +++ b/packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h @@ -60,9 +60,10 @@ class RNWebGPU : public NativeObject { explicit RNWebGPU(std::shared_ptr gpu, std::shared_ptr platformContext, - std::shared_ptr callInvoker) + std::shared_ptr callInvoker, + std::shared_ptr runtimeContext) : NativeObject(CLASS_NAME), _gpu(gpu), _platformContext(platformContext), - _callInvoker(callInvoker) {} + _callInvoker(callInvoker), _runtimeContext(runtimeContext) {} std::shared_ptr getGPU() { return _gpu; } @@ -117,9 +118,10 @@ class RNWebGPU : public NativeObject { // Copy bytes on the JS thread — the ArrayBuffer pointer is into // JS-owned memory that can be GC'd std::vector dataCopy(data.begin(), data.end()); + auto runtimeContext = _runtimeContext; return Promise::createPromise( - runtime, + runtimeContext, [platformContext, callInvoker, dataCopy = std::move(dataCopy)]( jsi::Runtime & /*runtime*/, @@ -130,9 +132,10 @@ class RNWebGPU : public NativeObject { auto imageBitmap = std::make_shared(imageData); callInvoker->invokeAsync([promise, imageBitmap]() { - promise->resolve( - JSIConverter>::toJSI( - promise->runtime, imageBitmap)); + promise->resolveWith([&](jsi::Runtime &rt) { + return JSIConverter>::toJSI( + rt, imageBitmap); + }); }); }, [callInvoker, promise](std::string error) { @@ -149,21 +152,22 @@ class RNWebGPU : public NativeObject { std::string blobId = blob->blobId; double offset = blob->offset; double size = blob->size; + auto runtimeContext = _runtimeContext; return Promise::createPromise( - runtime, + runtimeContext, [platformContext, callInvoker, blobId, offset, size](jsi::Runtime & /*runtime*/, std::shared_ptr promise) { platformContext->createImageBitmapAsync( blobId, offset, size, [callInvoker, promise](ImageData imageData) { auto imageBitmap = std::make_shared(imageData); - callInvoker->invokeAsync( - [promise, imageBitmap]() { - promise->resolve( - JSIConverter>::toJSI( - promise->runtime, imageBitmap)); - }); + callInvoker->invokeAsync([promise, imageBitmap]() { + promise->resolveWith([&](jsi::Runtime &rt) { + return JSIConverter>::toJSI( + rt, imageBitmap); + }); + }); }, [callInvoker, promise](std::string error) { callInvoker->invokeAsync( @@ -198,6 +202,7 @@ class RNWebGPU : public NativeObject { std::shared_ptr _gpu; std::shared_ptr _platformContext; std::shared_ptr _callInvoker; + std::shared_ptr _runtimeContext; }; } // namespace rnwgpu