Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions packages/webgpu/cpp/jsi/JSIConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,21 @@ template <> struct JSIConverter<rnwgpu::async::AsyncTaskHandle> {
}

static jsi::Value toJSI(jsi::Runtime& runtime, rnwgpu::async::AsyncTaskHandle&& handle) {
return rnwgpu::Promise::createPromise(runtime, [handle = std::move(handle)](jsi::Runtime& runtime,
std::shared_ptr<rnwgpu::Promise> promise) mutable {
if (!handle.valid()) {
promise->resolve(jsi::Value::undefined());
return;
}

handle.attachPromise(promise);
});
auto context = rnwgpu::RuntimeContext::getMainContext();
if (!context) {
throw std::runtime_error("RuntimeContext not set - cannot create Promise");
}
return rnwgpu::Promise::createPromise(
context,
[handle = std::move(handle)](jsi::Runtime& runtime,
std::shared_ptr<rnwgpu::Promise> promise) mutable {
if (!handle.valid()) {
promise->resolve(jsi::Value::undefined());
return;
}

handle.attachPromise(promise);
});
}
};

Expand Down
52 changes: 41 additions & 11 deletions packages/webgpu/cpp/jsi/Promise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,65 @@ namespace rnwgpu {

namespace jsi = facebook::jsi;

Promise::Promise(jsi::Runtime& runtime, jsi::Function&& resolver, jsi::Function&& rejecter)
: runtime(runtime), _resolver(std::move(resolver)), _rejecter(std::move(rejecter)) {}
// Static member definition
std::shared_ptr<RuntimeContext> RuntimeContext::_mainContext = nullptr;

jsi::Value Promise::createPromise(jsi::Runtime& runtime, RunPromise run) {
Promise::Promise(std::weak_ptr<RuntimeContext> context, jsi::Function&& resolver,
jsi::Function&& rejecter)
: _context(std::move(context)), _resolver(std::move(resolver)),
_rejecter(std::move(rejecter)) {}

jsi::Value Promise::createPromise(std::shared_ptr<RuntimeContext> context,
RunPromise run) {
auto runtime = context->getRuntime();
// Get Promise ctor from global
auto promiseCtor = runtime.global().getPropertyAsFunction(runtime, "Promise");
auto promiseCtor = runtime->global().getPropertyAsFunction(*runtime, "Promise");

auto promiseCallback = jsi::Function::createFromHostFunction(
runtime, jsi::PropNameID::forUtf8(runtime, "PromiseCallback"), 2,
[=](jsi::Runtime& runtime, const jsi::Value& thisValue, const jsi::Value* arguments, size_t count) -> jsi::Value {
*runtime, jsi::PropNameID::forUtf8(*runtime, "PromiseCallback"), 2,
[context, run](jsi::Runtime& runtime, const jsi::Value& thisValue,
const jsi::Value* arguments,
size_t count) -> jsi::Value {
// Call function
auto resolver = arguments[0].asObject(runtime).asFunction(runtime);
auto rejecter = arguments[1].asObject(runtime).asFunction(runtime);
auto promise = std::make_shared<Promise>(runtime, std::move(resolver), std::move(rejecter));
auto promise = std::make_shared<Promise>(context, std::move(resolver),
std::move(rejecter));
run(runtime, promise);

return jsi::Value::undefined();
});

return promiseCtor.callAsConstructor(runtime, promiseCallback);
return promiseCtor.callAsConstructor(*runtime, promiseCallback);
}

void Promise::resolve(jsi::Value&& result) {
_resolver.call(runtime, std::move(result));
auto context = _context.lock();
if (!context || !context->isValid()) {
// Runtime has been torn down, silently ignore
return;
}
auto runtime = context->getRuntime();
_resolver.call(*runtime, std::move(result));
}

void Promise::reject(std::string message) {
jsi::JSError error(runtime, message);
_rejecter.call(runtime, error.value());
auto context = _context.lock();
if (!context || !context->isValid()) {
// Runtime has been torn down, silently ignore
return;
}
auto runtime = context->getRuntime();
jsi::JSError error(*runtime, message);
_rejecter.call(*runtime, error.value());
}

jsi::Runtime* Promise::getRuntime() const {
auto context = _context.lock();
if (!context || !context->isValid()) {
return nullptr;
}
return context->getRuntime();
}

} // namespace rnwgpu
65 changes: 60 additions & 5 deletions packages/webgpu/cpp/jsi/Promise.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,81 @@ namespace rnwgpu {

namespace jsi = facebook::jsi;

/**
* A context that wraps a runtime pointer and can be invalidated
* when the runtime is torn down (e.g., during hot reload).
*/
class RuntimeContext : public std::enable_shared_from_this<RuntimeContext> {
public:
explicit RuntimeContext(jsi::Runtime& runtime) : _runtime(&runtime) {}

void invalidate() { _runtime = nullptr; }

jsi::Runtime* getRuntime() const { return _runtime; }

bool isValid() const { return _runtime != nullptr; }

/**
* Set the main runtime context (called during module initialization).
*/
static void setMainContext(std::shared_ptr<RuntimeContext> context) {
_mainContext = std::move(context);
}

/**
* Get the main runtime context (may be nullptr if not set or invalidated).
*/
static std::shared_ptr<RuntimeContext> getMainContext() {
return _mainContext;
}

private:
jsi::Runtime* _runtime;
static std::shared_ptr<RuntimeContext> _mainContext;
};

class Promise {
public:
Promise(jsi::Runtime& runtime, jsi::Function&& resolver, jsi::Function&& rejecter);
Promise(std::weak_ptr<RuntimeContext> context, jsi::Function&& resolver,
jsi::Function&& rejecter);

void resolve(jsi::Value&& result);
void reject(std::string error);

public:
jsi::Runtime& runtime;
/**
* Get the runtime pointer, or nullptr if the runtime has been torn down.
* Use this to safely construct jsi::Value before calling resolve().
*/
jsi::Runtime* getRuntime() const;

/**
* Resolve with a value constructed by the factory function.
* The factory is only called if the runtime is still valid.
* Usage: promise->resolveWith([&](jsi::Runtime& rt) { return
* JSIConverter<T>::toJSI(rt, value); });
*/
template <typename F> void resolveWith(F&& valueFactory) {
auto context = _context.lock();
if (!context || !context->isValid()) {
return;
}
auto* runtime = context->getRuntime();
_resolver.call(*runtime, valueFactory(*runtime));
}

private:
std::weak_ptr<RuntimeContext> _context;
jsi::Function _resolver;
jsi::Function _rejecter;

public:
using RunPromise = std::function<void(jsi::Runtime& runtime, std::shared_ptr<Promise> promise)>;
using RunPromise =
std::function<void(jsi::Runtime& runtime, std::shared_ptr<Promise> promise)>;
/**
Create a new Promise and runs the given `run` function.
*/
static jsi::Value createPromise(jsi::Runtime& runtime, RunPromise run);
static jsi::Value createPromise(std::shared_ptr<RuntimeContext> context,
RunPromise run);
};

} // namespace rnwgpu
17 changes: 14 additions & 3 deletions packages/webgpu/cpp/rnwgpu/RNWebGPUManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,18 @@ RNWebGPUManager::RNWebGPUManager(
std::shared_ptr<facebook::react::CallInvoker> jsCallInvoker,
std::shared_ptr<PlatformContext> platformContext)
: _jsRuntime(jsRuntime), _jsCallInvoker(jsCallInvoker),
_platformContext(platformContext) {
_platformContext(platformContext),
_runtimeContext(std::make_shared<RuntimeContext>(*jsRuntime)) {

// Register main runtime for RuntimeAwareCache
BaseRuntimeAwareCache::setMainJsRuntime(_jsRuntime);

// Set the main runtime context for Promise creation
RuntimeContext::setMainContext(_runtimeContext);

auto gpu = std::make_shared<GPU>(*_jsRuntime);
auto rnWebGPU =
std::make_shared<RNWebGPU>(gpu, _platformContext, _jsCallInvoker);
auto rnWebGPU = std::make_shared<RNWebGPU>(gpu, _platformContext,
_jsCallInvoker, _runtimeContext);
_gpu = gpu->get();
_jsRuntime->global().setProperty(*_jsRuntime, "RNWebGPU",
RNWebGPU::create(*_jsRuntime, rnWebGPU));
Expand Down Expand Up @@ -218,6 +222,13 @@ void RNWebGPUManager::installWebGPUWorkletHelpers(jsi::Runtime &runtime) {
}

RNWebGPUManager::~RNWebGPUManager() {
// Invalidate the runtime context first to prevent any pending promises
// from accessing the torn-down runtime (e.g., during hot reload)
if (_runtimeContext) {
_runtimeContext->invalidate();
}
// Clear the global context reference
RuntimeContext::setMainContext(nullptr);
_jsRuntime = nullptr;
_jsCallInvoker = nullptr;
}
Expand Down
2 changes: 2 additions & 0 deletions packages/webgpu/cpp/rnwgpu/RNWebGPUManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "GPU.h"
#include "PlatformContext.h"
#include "Promise.h"
#include "SurfaceRegistry.h"

namespace facebook {
Expand Down Expand Up @@ -37,6 +38,7 @@ class RNWebGPUManager {
private:
jsi::Runtime *_jsRuntime;
std::shared_ptr<facebook::react::CallInvoker> _jsCallInvoker;
std::shared_ptr<RuntimeContext> _runtimeContext;

public:
wgpu::Instance _gpu;
Expand Down
31 changes: 18 additions & 13 deletions packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ class RNWebGPU : public NativeObject<RNWebGPU> {

explicit RNWebGPU(std::shared_ptr<GPU> gpu,
std::shared_ptr<PlatformContext> platformContext,
std::shared_ptr<facebook::react::CallInvoker> callInvoker)
std::shared_ptr<facebook::react::CallInvoker> callInvoker,
std::shared_ptr<RuntimeContext> runtimeContext)
: NativeObject(CLASS_NAME), _gpu(gpu), _platformContext(platformContext),
_callInvoker(callInvoker) {}
_callInvoker(callInvoker), _runtimeContext(runtimeContext) {}

std::shared_ptr<GPU> getGPU() { return _gpu; }

Expand Down Expand Up @@ -117,9 +118,10 @@ class RNWebGPU : public NativeObject<RNWebGPU> {
// Copy bytes on the JS thread — the ArrayBuffer pointer is into
// JS-owned memory that can be GC'd
std::vector<uint8_t> dataCopy(data.begin(), data.end());
auto runtimeContext = _runtimeContext;

return Promise::createPromise(
runtime,
runtimeContext,
[platformContext, callInvoker,
dataCopy = std::move(dataCopy)](
jsi::Runtime & /*runtime*/,
Expand All @@ -130,9 +132,10 @@ class RNWebGPU : public NativeObject<RNWebGPU> {
auto imageBitmap =
std::make_shared<ImageBitmap>(imageData);
callInvoker->invokeAsync([promise, imageBitmap]() {
promise->resolve(
JSIConverter<std::shared_ptr<ImageBitmap>>::toJSI(
promise->runtime, imageBitmap));
promise->resolveWith([&](jsi::Runtime &rt) {
return JSIConverter<std::shared_ptr<ImageBitmap>>::toJSI(
rt, imageBitmap);
});
});
},
[callInvoker, promise](std::string error) {
Expand All @@ -149,21 +152,22 @@ class RNWebGPU : public NativeObject<RNWebGPU> {
std::string blobId = blob->blobId;
double offset = blob->offset;
double size = blob->size;
auto runtimeContext = _runtimeContext;

return Promise::createPromise(
runtime,
runtimeContext,
[platformContext, callInvoker, blobId, offset,
size](jsi::Runtime & /*runtime*/, std::shared_ptr<Promise> promise) {
platformContext->createImageBitmapAsync(
blobId, offset, size,
[callInvoker, promise](ImageData imageData) {
auto imageBitmap = std::make_shared<ImageBitmap>(imageData);
callInvoker->invokeAsync(
[promise, imageBitmap]() {
promise->resolve(
JSIConverter<std::shared_ptr<ImageBitmap>>::toJSI(
promise->runtime, imageBitmap));
});
callInvoker->invokeAsync([promise, imageBitmap]() {
promise->resolveWith([&](jsi::Runtime &rt) {
return JSIConverter<std::shared_ptr<ImageBitmap>>::toJSI(
rt, imageBitmap);
});
});
},
[callInvoker, promise](std::string error) {
callInvoker->invokeAsync(
Expand Down Expand Up @@ -198,6 +202,7 @@ class RNWebGPU : public NativeObject<RNWebGPU> {
std::shared_ptr<GPU> _gpu;
std::shared_ptr<PlatformContext> _platformContext;
std::shared_ptr<facebook::react::CallInvoker> _callInvoker;
std::shared_ptr<RuntimeContext> _runtimeContext;
};

} // namespace rnwgpu
Loading