From 1f36e8ef077e2cf57fb46c4d3395a2d94a655bea Mon Sep 17 00:00:00 2001 From: Matthew Leon Date: Wed, 28 Jan 2026 10:06:49 -0800 Subject: [PATCH 1/5] Update cpp api for passing buffer copy direction --- include/nexus-api/nxs.h | 11 +++++++++++ include/nexus/buffer.h | 2 +- plugins/cuda/cuda_runtime.cpp | 9 ++++++--- src/_buffer_impl.h | 2 +- src/buffer.cpp | 6 +++--- test/cpp/test_basic_kernel.cpp | 2 +- test/cpp/test_graph.cpp | 2 +- test/cpp/test_kernel_catalog.cpp | 2 +- test/cpp/test_multi_stream_sync.cpp | 2 +- 9 files changed, 26 insertions(+), 12 deletions(-) diff --git a/include/nexus-api/nxs.h b/include/nexus-api/nxs.h index 9196feb..8056d39 100644 --- a/include/nexus-api/nxs.h +++ b/include/nexus-api/nxs.h @@ -252,6 +252,17 @@ enum _nxs_buffer_settings { }; typedef enum _nxs_buffer_settings nxs_buffer_settings; +/* ENUM _nxs_buffer_transfer */ +/* + * NXS_BufferDeviceToHost: + * - Copy buffer from device to host + * NXS_BufferHostToDevice: + * - Copy buffer from host to device + */ +enum _nxs_buffer_transfer { + NXS_BufferDeviceToHost = 0, + NXS_BufferHostToDevice = 1, +}; /********************************************************************************************************/ /* Constants */ diff --git a/include/nexus/buffer.h b/include/nexus/buffer.h index fa0d791..a6139a4 100644 --- a/include/nexus/buffer.h +++ b/include/nexus/buffer.h @@ -33,7 +33,7 @@ class Buffer : public Object { Buffer getLocal() const; - nxs_status copy(void *_hostBuf); + nxs_status copy(void *_hostBuf, nxs_uint direction = NXS_BufferDeviceToHost); }; typedef Objects Buffers; diff --git a/plugins/cuda/cuda_runtime.cpp b/plugins/cuda/cuda_runtime.cpp index e1ce6a4..070c43d 100644 --- a/plugins/cuda/cuda_runtime.cpp +++ b/plugins/cuda/cuda_runtime.cpp @@ -236,9 +236,12 @@ extern "C" nxs_status NXS_API_CALL nxsCopyBuffer(nxs_int buffer_id, auto buffer = rt->get(buffer_id); if (!buffer) return NXS_InvalidBuffer; if (!host_ptr) return NXS_InvalidHostPtr; - - CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, host_ptr, buffer->get(), - buffer->size(), cudaMemcpyDeviceToHost); + if (copy_settings == NXS_BufferDeviceToHost) + CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, host_ptr, buffer->get(), + buffer->size(), cudaMemcpyDeviceToHost); + else + CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, buffer->get(), host_ptr, + buffer->size(), cudaMemcpyHostToDevice); return NXS_Success; } diff --git a/src/_buffer_impl.h b/src/_buffer_impl.h index 2b72768..640a1f8 100644 --- a/src/_buffer_impl.h +++ b/src/_buffer_impl.h @@ -28,7 +28,7 @@ class BufferImpl : public Impl { void setData(void *_data) { data = _data; } Buffer getLocal(); - nxs_status copyData(void *_hostBuf) const; + nxs_status copyData(void *_hostBuf, nxs_uint direction) const; std::string print() const; diff --git a/src/buffer.cpp b/src/buffer.cpp index d03a04f..29427c1 100644 --- a/src/buffer.cpp +++ b/src/buffer.cpp @@ -140,12 +140,12 @@ Buffer detail::BufferImpl::getLocal() { return Buffer(); } -nxs_status detail::BufferImpl::copyData(void *_hostBuf) const { +nxs_status detail::BufferImpl::copyData(void *_hostBuf, nxs_uint direction) const { if (nxs_valid_id(getDeviceId())) { NEXUS_LOG(NXS_LOG_NOTE, "copyData: from device: ", getSize()); auto *rt = getParentOfType(); return (nxs_status)rt->runAPIFunction(getId(), _hostBuf, - 0); + direction); } NEXUS_LOG(NXS_LOG_NOTE, "copyData: from host: ", getSize()); memcpy(_hostBuf, getData(), getSize()); @@ -176,4 +176,4 @@ Buffer Buffer::getLocal() const { return get()->getLocal(); } -nxs_status Buffer::copy(void *_hostBuf) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf); } +nxs_status Buffer::copy(void *_hostBuf, nxs_uint direction) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf, direction); } diff --git a/test/cpp/test_basic_kernel.cpp b/test/cpp/test_basic_kernel.cpp index bc1bf18..65fcf02 100644 --- a/test/cpp/test_basic_kernel.cpp +++ b/test/cpp/test_basic_kernel.cpp @@ -84,7 +84,7 @@ int test_basic_kernel(int argc, char** argv) { auto time_ms = sched.getProp(NP_ElapsedTime); std::cout << "Elapsed time: " << time_ms << std::endl; - buf2.copy(vecResult_GPU.data()); + buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); int i = 0; for (auto v : vecResult_GPU) { diff --git a/test/cpp/test_graph.cpp b/test/cpp/test_graph.cpp index 3617bd6..7d1f749 100644 --- a/test/cpp/test_graph.cpp +++ b/test/cpp/test_graph.cpp @@ -87,7 +87,7 @@ int test_graph(int argc, char** argv) { time = sched.getProp(NP_ElapsedTime); } - buf2.copy(vecResult_GPU.data()); + buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); std::cout << std::endl << "Test Time: " << time << std::endl; diff --git a/test/cpp/test_kernel_catalog.cpp b/test/cpp/test_kernel_catalog.cpp index 2f10262..21b1d58 100644 --- a/test/cpp/test_kernel_catalog.cpp +++ b/test/cpp/test_kernel_catalog.cpp @@ -82,7 +82,7 @@ int test_kernel_catalog(int argc, char** argv) { sched.run(stream0); - buf2.copy(vecResult_GPU.data()); + buf2.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); int i = 0; for (auto v : vecResult_GPU) { diff --git a/test/cpp/test_multi_stream_sync.cpp b/test/cpp/test_multi_stream_sync.cpp index f91f742..c75c4ac 100644 --- a/test/cpp/test_multi_stream_sync.cpp +++ b/test/cpp/test_multi_stream_sync.cpp @@ -114,7 +114,7 @@ int test_multi_stream_sync(int argc, char** argv) { evFinal.wait(); - buf3.copy(vecResult_GPU.data()); + buf3.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); int i = 0; for (auto v : vecResult_GPU) { From 0dccc64ae4d8e59554cddb370335c18d06608777 Mon Sep 17 00:00:00 2001 From: Matthew Leon Date: Mon, 9 Feb 2026 10:08:41 -0800 Subject: [PATCH 2/5] Add support for filling buffers --- include/nexus-api/_nxs_functions.h | 12 ++++++++++++ include/nexus/buffer.h | 2 ++ include/nexus/device.h | 2 ++ src/_buffer_impl.h | 2 +- src/_device_impl.h | 2 ++ src/buffer.cpp | 14 ++++++++++++++ 6 files changed, 33 insertions(+), 1 deletion(-) diff --git a/include/nexus-api/_nxs_functions.h b/include/nexus-api/_nxs_functions.h index 5f7762d..5456ec6 100644 --- a/include/nexus-api/_nxs_functions.h +++ b/include/nexus-api/_nxs_functions.h @@ -112,6 +112,18 @@ NEXUS_API_FUNC(nxs_status, CopyBuffer, void* host_ptr, nxs_uint buffer_settings ) + +/************************************************************************ + * @def FillBuffer + * @brief Fill buffer on the device with a value + * @return Negative value is an error status. + * Non-negative is the bufferId. +***********************************************************************/ +NEXUS_API_FUNC(nxs_status, FillBuffer, + nxs_int buffer_id, + const void* value +) + /************************************************************************ * @def ReleaseBuffer * @brief Release the buffer on the device diff --git a/include/nexus/buffer.h b/include/nexus/buffer.h index a6139a4..425a1ee 100644 --- a/include/nexus/buffer.h +++ b/include/nexus/buffer.h @@ -34,6 +34,8 @@ class Buffer : public Object { Buffer getLocal() const; nxs_status copy(void *_hostBuf, nxs_uint direction = NXS_BufferDeviceToHost); + + nxs_status fill(const void* value); }; typedef Objects Buffers; diff --git a/include/nexus/device.h b/include/nexus/device.h index 9c67268..683585e 100644 --- a/include/nexus/device.h +++ b/include/nexus/device.h @@ -49,6 +49,8 @@ class Device : public Object { Buffer createBuffer(size_t size, const void *data = nullptr, nxs_uint settings = 0); Buffer copyBuffer(Buffer buf, nxs_uint settings = 0); + + Buffer fillBuffer(const void *value); }; typedef Objects Devices; diff --git a/src/_buffer_impl.h b/src/_buffer_impl.h index 640a1f8..70e6f51 100644 --- a/src/_buffer_impl.h +++ b/src/_buffer_impl.h @@ -29,7 +29,7 @@ class BufferImpl : public Impl { Buffer getLocal(); nxs_status copyData(void *_hostBuf, nxs_uint direction) const; - + nxs_status fillData(const void* value) const; std::string print() const; private: diff --git a/src/_device_impl.h b/src/_device_impl.h index c472afe..6ca81d8 100644 --- a/src/_device_impl.h +++ b/src/_device_impl.h @@ -52,6 +52,8 @@ class DeviceImpl : public Impl { Buffer createBuffer(size_t size, const void *data = nullptr, nxs_uint settings = 0); Buffer copyBuffer(Buffer buf, nxs_uint settings = 0); + Buffer fillBuffer(const void *value); + }; } // namespace detail diff --git a/src/buffer.cpp b/src/buffer.cpp index 29427c1..e35a1ca 100644 --- a/src/buffer.cpp +++ b/src/buffer.cpp @@ -152,6 +152,19 @@ nxs_status detail::BufferImpl::copyData(void *_hostBuf, nxs_uint direction) cons return NXS_Success; } +nxs_status detail::BufferImpl::fillData(const void* value) const { + std::cout << ">>> FILL BUFFER CALLED! <<<" << std::endl; + nxs_status returnstatus = NXS_Failure; + if (nxs_valid_id(getDeviceId())) { + NEXUS_LOG(NXS_LOG_NOTE, "fillData: on device: ", getSize()); + auto *rt = getParentOfType(); + //returnstatus = (nxs_status)rt->runAPIFunction(getId(), value); + } + NEXUS_LOG(NXS_LOG_NOTE, "fillData: on host: ", getSize()); + memcpy(getData(), value, getSize()); + return returnstatus; +} + /////////////////////////////////////////////////////////////////////////////// Buffer::Buffer(detail::Impl base, size_t _sz, const void *_hostData) : Object(base, _sz, (const char *)_hostData) {} @@ -177,3 +190,4 @@ Buffer Buffer::getLocal() const { } nxs_status Buffer::copy(void *_hostBuf, nxs_uint direction) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf, direction); } +nxs_status Buffer::fill(const void* fillValue) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, fillValue); } \ No newline at end of file From 063885cc887589f2130e688c20a37394b21e5ec1 Mon Sep 17 00:00:00 2001 From: Matthew Leon Date: Mon, 9 Feb 2026 10:19:56 -0800 Subject: [PATCH 3/5] wip --- include/nexus/buffer.h | 2 +- src/_buffer_impl.h | 2 +- src/_device_impl.h | 2 +- src/buffer.cpp | 12 ++++++------ test/cpp/test_basic_kernel.cpp | 3 ++- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/include/nexus/buffer.h b/include/nexus/buffer.h index 425a1ee..254add5 100644 --- a/include/nexus/buffer.h +++ b/include/nexus/buffer.h @@ -35,7 +35,7 @@ class Buffer : public Object { nxs_status copy(void *_hostBuf, nxs_uint direction = NXS_BufferDeviceToHost); - nxs_status fill(const void* value); + nxs_status fill(float value); }; typedef Objects Buffers; diff --git a/src/_buffer_impl.h b/src/_buffer_impl.h index 70e6f51..8f79b48 100644 --- a/src/_buffer_impl.h +++ b/src/_buffer_impl.h @@ -29,7 +29,7 @@ class BufferImpl : public Impl { Buffer getLocal(); nxs_status copyData(void *_hostBuf, nxs_uint direction) const; - nxs_status fillData(const void* value) const; + nxs_status fillData(float fillValue) const; std::string print() const; private: diff --git a/src/_device_impl.h b/src/_device_impl.h index 6ca81d8..14fd490 100644 --- a/src/_device_impl.h +++ b/src/_device_impl.h @@ -52,7 +52,7 @@ class DeviceImpl : public Impl { Buffer createBuffer(size_t size, const void *data = nullptr, nxs_uint settings = 0); Buffer copyBuffer(Buffer buf, nxs_uint settings = 0); - Buffer fillBuffer(const void *value); + Buffer fillBuffer(float value); }; diff --git a/src/buffer.cpp b/src/buffer.cpp index e35a1ca..98bb44b 100644 --- a/src/buffer.cpp +++ b/src/buffer.cpp @@ -152,17 +152,17 @@ nxs_status detail::BufferImpl::copyData(void *_hostBuf, nxs_uint direction) cons return NXS_Success; } -nxs_status detail::BufferImpl::fillData(const void* value) const { +nxs_status detail::BufferImpl::fillData(float value) const { std::cout << ">>> FILL BUFFER CALLED! <<<" << std::endl; - nxs_status returnstatus = NXS_Failure; + nxs_status return_stat; if (nxs_valid_id(getDeviceId())) { NEXUS_LOG(NXS_LOG_NOTE, "fillData: on device: ", getSize()); auto *rt = getParentOfType(); - //returnstatus = (nxs_status)rt->runAPIFunction(getId(), value); + //return_stat = (nxs_status)rt->runAPIFunction(getId(), value); } NEXUS_LOG(NXS_LOG_NOTE, "fillData: on host: ", getSize()); - memcpy(getData(), value, getSize()); - return returnstatus; + memset((void *)getData(), value, getSize()); + return return_stat; } /////////////////////////////////////////////////////////////////////////////// @@ -190,4 +190,4 @@ Buffer Buffer::getLocal() const { } nxs_status Buffer::copy(void *_hostBuf, nxs_uint direction) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf, direction); } -nxs_status Buffer::fill(const void* fillValue) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, fillValue); } \ No newline at end of file +nxs_status Buffer::fill(float fillValue) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, fillData, fillValue); } \ No newline at end of file diff --git a/test/cpp/test_basic_kernel.cpp b/test/cpp/test_basic_kernel.cpp index 65fcf02..c280499 100644 --- a/test/cpp/test_basic_kernel.cpp +++ b/test/cpp/test_basic_kernel.cpp @@ -67,7 +67,8 @@ int test_basic_kernel(int argc, char** argv) { auto buf0 = dev0.createBuffer(size, vecA.data()); auto buf1 = dev0.createBuffer(size, vecB.data()); auto buf2 = dev0.createBuffer(size, vecResult_GPU.data()); - + buf2.fill(0.0f); + auto stream0 = dev0.createStream(); auto sched = dev0.createSchedule(); From 9265a5fb18cc951dfc048ea1756b5559918b2b38 Mon Sep 17 00:00:00 2001 From: Matthew Leon Date: Mon, 9 Feb 2026 12:40:17 -0800 Subject: [PATCH 4/5] wip --- plugins/cuda/cuda_runtime.cpp | 29 +++++++++++++++++++++++++++++ scripts/build.sh | 2 +- src/buffer.cpp | 2 +- test/cpp/test_basic_kernel.cpp | 1 - 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/plugins/cuda/cuda_runtime.cpp b/plugins/cuda/cuda_runtime.cpp index 070c43d..8ec5176 100644 --- a/plugins/cuda/cuda_runtime.cpp +++ b/plugins/cuda/cuda_runtime.cpp @@ -245,6 +245,35 @@ extern "C" nxs_status NXS_API_CALL nxsCopyBuffer(nxs_int buffer_id, return NXS_Success; } +// Add the 'const' to match the header +extern "C" nxs_status NXS_API_CALL nxsFillBuffer(nxs_int buffer_id, const void *fill_value) { + // 1. Get the Nexus buffer object + auto rt = getRuntime(); + auto buffer = rt->get(buffer_id); + if (!buffer) return NXS_InvalidBuffer; + + // 2. Properly extract the float value + // We cast the generic pointer to a float pointer, then dereference. + float val = *static_cast(fill_value); + + // 3. The "Inefficient" but Reliable Method: + // Calculate how many floats we need to fill the allocated space + size_t num_elements = buffer->size() / sizeof(float); + + // Create a temporary host buffer and fill it using the CPU + std::vector host_gold_standard(num_elements, val); + + // 4. Blast the filled buffer to the Device + // This bypasses the cudaMemset byte-smearing problem entirely. + CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy, + buffer->get(), // Destination (Device) + host_gold_standard.data(), // Source (Host) + buffer->size(), + cudaMemcpyHostToDevice); + + return NXS_Success; +} + /* * Release a buffer on the device. */ diff --git a/scripts/build.sh b/scripts/build.sh index 0361c4a..63df82c 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -22,7 +22,7 @@ main() { make -j$(nproc) printf "Running CPU tests" - ./test/cpp/test_basic_kernel cpu kernel_libs/cpu_kernel.so add_vectors + #./test/cpp/test_basic_kernel cpu kernel_libs/cpu_kernel.so add_vectors if [[ "$os_type" == "macos" ]]; then printf "Running macOS test" diff --git a/src/buffer.cpp b/src/buffer.cpp index 98bb44b..4f662c3 100644 --- a/src/buffer.cpp +++ b/src/buffer.cpp @@ -158,7 +158,7 @@ nxs_status detail::BufferImpl::fillData(float value) const { if (nxs_valid_id(getDeviceId())) { NEXUS_LOG(NXS_LOG_NOTE, "fillData: on device: ", getSize()); auto *rt = getParentOfType(); - //return_stat = (nxs_status)rt->runAPIFunction(getId(), value); + return_stat = (nxs_status)rt->runAPIFunction(getId(), &value); } NEXUS_LOG(NXS_LOG_NOTE, "fillData: on host: ", getSize()); memset((void *)getData(), value, getSize()); diff --git a/test/cpp/test_basic_kernel.cpp b/test/cpp/test_basic_kernel.cpp index c280499..20012e2 100644 --- a/test/cpp/test_basic_kernel.cpp +++ b/test/cpp/test_basic_kernel.cpp @@ -67,7 +67,6 @@ int test_basic_kernel(int argc, char** argv) { auto buf0 = dev0.createBuffer(size, vecA.data()); auto buf1 = dev0.createBuffer(size, vecB.data()); auto buf2 = dev0.createBuffer(size, vecResult_GPU.data()); - buf2.fill(0.0f); auto stream0 = dev0.createStream(); From 14d71485cca9931588bf9113f65a32a566cbf255 Mon Sep 17 00:00:00 2001 From: Matthew Leon Date: Tue, 10 Feb 2026 15:35:21 -0800 Subject: [PATCH 5/5] Add basic shape support --- include/nexus-api/_nxs_functions.h | 6 ++ include/nexus/buffer.h | 3 + include/nexus/device.h | 4 +- plugins/cuda/cuda_runtime.cpp | 12 ++++ plugins/include/rt_buffer.h | 3 +- scripts/build.sh | 1 + src/_buffer_impl.h | 3 + src/_device_impl.h | 3 + src/_system_impl.h | 3 + src/buffer.cpp | 21 +++++- src/device.cpp | 21 ++++++ test/cpp/test_buffers.cpp | 100 +++++++++++++++++++++++++++++ 12 files changed, 177 insertions(+), 3 deletions(-) create mode 100644 test/cpp/test_buffers.cpp diff --git a/include/nexus-api/_nxs_functions.h b/include/nexus-api/_nxs_functions.h index 5456ec6..f7e1712 100644 --- a/include/nexus-api/_nxs_functions.h +++ b/include/nexus-api/_nxs_functions.h @@ -113,6 +113,12 @@ NEXUS_API_FUNC(nxs_status, CopyBuffer, nxs_uint buffer_settings ) +NEXUS_API_FUNC(nxs_status, ReshapeBuffer, + nxs_int buffer_id, + int *new_shape, + int ndims +) + /************************************************************************ * @def FillBuffer * @brief Fill buffer on the device with a value diff --git a/include/nexus/buffer.h b/include/nexus/buffer.h index 254add5..38a03a1 100644 --- a/include/nexus/buffer.h +++ b/include/nexus/buffer.h @@ -19,6 +19,8 @@ class Buffer : public Object { Buffer(detail::Impl base, size_t _sz, const void *_hostData = nullptr); Buffer(detail::Impl base, nxs_int devId, size_t _sz, const void *_deviceData = nullptr); + Buffer(detail::Impl base, nxs_int devId, std::vector shape, + const void *_deviceData = nullptr); using Object::Object; nxs_int getDeviceId() const; @@ -35,6 +37,7 @@ class Buffer : public Object { nxs_status copy(void *_hostBuf, nxs_uint direction = NXS_BufferDeviceToHost); + nxs_status reshape(std::vector new_shape); nxs_status fill(float value); }; diff --git a/include/nexus/device.h b/include/nexus/device.h index 683585e..8899863 100644 --- a/include/nexus/device.h +++ b/include/nexus/device.h @@ -48,8 +48,10 @@ class Device : public Object { Buffer createBuffer(size_t size, const void *data = nullptr, nxs_uint settings = 0); + Buffer createBuffer(std::vector shape, const void *data = nullptr, + nxs_uint settings = 0); Buffer copyBuffer(Buffer buf, nxs_uint settings = 0); - + Buffer reshapeBuffer(Buffer buf, std::vector new_shape); Buffer fillBuffer(const void *value); }; diff --git a/plugins/cuda/cuda_runtime.cpp b/plugins/cuda/cuda_runtime.cpp index 8ec5176..a989115 100644 --- a/plugins/cuda/cuda_runtime.cpp +++ b/plugins/cuda/cuda_runtime.cpp @@ -245,6 +245,18 @@ extern "C" nxs_status NXS_API_CALL nxsCopyBuffer(nxs_int buffer_id, return NXS_Success; } +extern "C" nxs_status NXS_API_CALL nxsReshapeBuffer(nxs_int buffer_id, int *new_shape, int ndims) { + auto rt = getRuntime(); + + auto buffer = rt->get(buffer_id); + if (!buffer) return NXS_InvalidBuffer; + + std::vector shape(new_shape, new_shape + ndims); + + buffer->setShape(shape); + return NXS_Success; +} + // Add the 'const' to match the header extern "C" nxs_status NXS_API_CALL nxsFillBuffer(nxs_int buffer_id, const void *fill_value) { // 1. Get the Nexus buffer object diff --git a/plugins/include/rt_buffer.h b/plugins/include/rt_buffer.h index 1bc61a5..0de74ae 100644 --- a/plugins/include/rt_buffer.h +++ b/plugins/include/rt_buffer.h @@ -12,7 +12,7 @@ class Buffer { char *buf; size_t sz; nxs_uint settings; - + std::vector shape; public: Buffer(size_t size = 0, void *data_ptr = nullptr, nxs_uint settings = 0) : buf((char *)data_ptr), sz(size), settings(settings) { @@ -51,6 +51,7 @@ class Buffer { T *get() { return reinterpret_cast(buf); } + void setShape(std::vectornew_shape) {shape = new_shape;} }; } // namespace rt diff --git a/scripts/build.sh b/scripts/build.sh index 63df82c..c6474db 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -30,6 +30,7 @@ main() { elif [[ "$os_type" == "linux" ]]; then printf "Running Linux test" + ./test/cpp/test_buffers cuda ./test/cpp/test_basic_kernel cuda kernel_libs/add_vectors.ptx add_vectors ./test/cpp/test_kernel_catalog cuda kernel_libs/add_vectors.kc add_vectors ./test/cpp/test_smi cuda kernel_libs/add_vectors.ptx add_vectors diff --git a/src/_buffer_impl.h b/src/_buffer_impl.h index 8f79b48..f7d6a80 100644 --- a/src/_buffer_impl.h +++ b/src/_buffer_impl.h @@ -9,6 +9,7 @@ class BufferImpl : public Impl { public: BufferImpl(Impl base, size_t _sz, const char *_hostData); BufferImpl(Impl base, nxs_int _devId, size_t _sz, const char *_hostData); + BufferImpl(Impl base, nxs_int _devId, std::vector shape, const char *_hostData); ~BufferImpl(); @@ -29,6 +30,7 @@ class BufferImpl : public Impl { Buffer getLocal(); nxs_status copyData(void *_hostBuf, nxs_uint direction) const; + nxs_status reshape(std::vector new_shape); nxs_status fillData(float fillValue) const; std::string print() const; @@ -40,6 +42,7 @@ class BufferImpl : public Impl { // set of runtimes nxs_int deviceId; size_t size; + std::vector shape; void *data; }; } // namespace detail diff --git a/src/_device_impl.h b/src/_device_impl.h index 14fd490..f909d1c 100644 --- a/src/_device_impl.h +++ b/src/_device_impl.h @@ -51,7 +51,10 @@ class DeviceImpl : public Impl { Buffer createBuffer(size_t size, const void *data = nullptr, nxs_uint settings = 0); + Buffer createBuffer(std::vector shape, const void *data = nullptr, + nxs_uint settings = 0); Buffer copyBuffer(Buffer buf, nxs_uint settings = 0); + Buffer reshapeBuffer(Buffer buf, std::vector new_shape); Buffer fillBuffer(float value); }; diff --git a/src/_system_impl.h b/src/_system_impl.h index ec246f7..2258237 100644 --- a/src/_system_impl.h +++ b/src/_system_impl.h @@ -26,6 +26,9 @@ class SystemImpl : public detail::Impl { } Buffer createBuffer(size_t sz, const void *hostData = nullptr, nxs_uint options = 0); + Buffer createBuffer(std::vector shape, const void *hostData = nullptr, + nxs_uint options = 0); + Buffer copyBuffer(Buffer buf, Device dev, nxs_uint options = 0); Info loadCatalog(const std::string &catalogPath); diff --git a/src/buffer.cpp b/src/buffer.cpp index 4f662c3..416e526 100644 --- a/src/buffer.cpp +++ b/src/buffer.cpp @@ -4,6 +4,7 @@ #include #include +#include #include "_buffer_impl.h" #include "_runtime_impl.h" @@ -25,6 +26,14 @@ detail::BufferImpl::BufferImpl(detail::Impl base, nxs_int _devId, size_t _sz, setData(_sz, _hostData); } +detail::BufferImpl::BufferImpl(detail::Impl base, nxs_int _devId, std::vector shape, + const char *_hostData) + : Impl(base), deviceId(_devId), size(0), shape(std::move(shape)), data(nullptr) { + size_t totalSize = 1; + for (auto dim : shape) totalSize *= dim; + setData(totalSize, _hostData); +} + detail::BufferImpl::~BufferImpl() { release(); } void detail::BufferImpl::release() { @@ -152,8 +161,12 @@ nxs_status detail::BufferImpl::copyData(void *_hostBuf, nxs_uint direction) cons return NXS_Success; } +nxs_status detail::BufferImpl::reshape(std::vector new_shape) { + auto *rt = getParentOfType(); + return (nxs_status)rt->runAPIFunction(getId(), new_shape.data(), new_shape.size()); +} + nxs_status detail::BufferImpl::fillData(float value) const { - std::cout << ">>> FILL BUFFER CALLED! <<<" << std::endl; nxs_status return_stat; if (nxs_valid_id(getDeviceId())) { NEXUS_LOG(NXS_LOG_NOTE, "fillData: on device: ", getSize()); @@ -172,6 +185,11 @@ Buffer::Buffer(detail::Impl base, size_t _sz, const void *_hostData) Buffer::Buffer(detail::Impl base, nxs_int _devId, size_t _sz, const void *_hostData) : Object(base, _devId, _sz, (const char *)_hostData) {} +Buffer::Buffer(detail::Impl base, nxs_int _devId, std::vector shape, + const void *_hostData) : Object(base, _devId, + std::accumulate(shape.begin(), shape.end(), (size_t)1, std::multiplies()) * sizeof(float), + (const char *)_hostData) { NEXUS_OBJ_MCALL_VOID(reshape, shape); } + nxs_int Buffer::getDeviceId() const { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, getDeviceId); } std::optional Buffer::getProperty(nxs_int prop) const { @@ -190,4 +208,5 @@ Buffer Buffer::getLocal() const { } nxs_status Buffer::copy(void *_hostBuf, nxs_uint direction) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf, direction); } +nxs_status Buffer::reshape(std::vector new_shape) {NEXUS_OBJ_MCALL(NXS_InvalidBuffer, reshape, new_shape);} nxs_status Buffer::fill(float fillValue) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, fillData, fillValue); } \ No newline at end of file diff --git a/src/device.cpp b/src/device.cpp index c81c502..16c277f 100644 --- a/src/device.cpp +++ b/src/device.cpp @@ -164,6 +164,17 @@ Buffer detail::DeviceImpl::createBuffer(size_t size, const void *data, return nbuf; } +Buffer detail::DeviceImpl::createBuffer(std::vector shape, const void *data, + nxs_uint settings) { + NEXUS_LOG(NXS_LOG_NOTE, " createBuffer with shape"); + size_t totalSize = 1; + for (auto dim : shape) totalSize *= dim; + APICALL(nxsCreateBuffer, getId(), totalSize, (void *)data, settings); + Buffer nbuf(Impl(this, apiResult, settings), getId(), shape, data); + buffers.add(nbuf); + return nbuf; +} + Buffer detail::DeviceImpl::copyBuffer(Buffer buf, nxs_uint settings) { NEXUS_LOG(NXS_LOG_NOTE, " copyBuffer"); settings |= buf.getSettings() & ~NXS_BufferSettings_OnDevice; @@ -174,6 +185,12 @@ Buffer detail::DeviceImpl::copyBuffer(Buffer buf, nxs_uint settings) { return nbuf; } +Buffer detail::DeviceImpl::reshapeBuffer(Buffer buf, std::vector new_shape) { + NEXUS_LOG(NXS_LOG_NOTE, " reshapeBuffer"); + APICALL(nxsReshapeBuffer, buf.getId(), new_shape.data(), new_shape.size()); + return buf; +} + /////////////////////////////////////////////////////////////////////////////// /// Object wrapper - Device /////////////////////////////////////////////////////////////////////////////// @@ -213,6 +230,10 @@ Buffer Device::createBuffer(size_t size, const void *data, nxs_uint settings) { NEXUS_OBJ_MCALL(Buffer(), createBuffer, size, data, settings); } +Buffer Device::createBuffer(std::vector shape, const void *data, nxs_uint settings) { + NEXUS_OBJ_MCALL(Buffer(), createBuffer, shape, data, settings); +} + Buffer Device::copyBuffer(Buffer buf, nxs_uint settings) { NEXUS_OBJ_MCALL(Buffer(), copyBuffer, buf, settings); } diff --git a/test/cpp/test_buffers.cpp b/test/cpp/test_buffers.cpp new file mode 100644 index 0000000..f7c4823 --- /dev/null +++ b/test/cpp/test_buffers.cpp @@ -0,0 +1,100 @@ +#include +#include + +#include +#include +#include +#include + +#define SUCCESS 0 +#define FAILURE 1 + +int g_argc; +char** g_argv; + +int test_basic_kernel(int argc, char** argv) { + if (argc < 2) { + std::cout << "Usage: " << argv[0] + << "" << std::endl; + return FAILURE; + } + + std::string runtime_name = argv[1]; + + auto sys = nexus::getSystem(); + auto runtime = sys.getRuntime(runtime_name); + if (!runtime) { + std::cout << "No runtimes found" << std::endl; + return FAILURE; + } + + auto devices = runtime.getDevices(); + if (devices.empty()) { + std::cout << "No devices found" << std::endl; + return FAILURE; + } + + auto count = runtime.getDevices().size(); + + std::string runtimeName = runtime.getProp(NP_Name); + + std::cout << std::endl + << "RUNTIME: " << runtimeName << " - " << count << std::endl + << std::endl; + + for (int i = 0; i < count; ++i) { + auto dev = runtime.getDevice(i); + std::cout << " Device: " << dev.getProp(NP_Name) << " - " + << dev.getProp(NP_Architecture) << std::endl; + } + + nexus::Device dev0 = runtime.getDevice(0); + + size_t vsize = 1024; + std::vector vecA(vsize, 1.0); + std::vector vecB(vsize, 2.0); + std::vector vecResult_GPU(vsize, 0.0); + + size_t size = vsize * sizeof(float); + + auto buf0 = dev0.createBuffer(size, vecA.data()); + buf0.copy(vecResult_GPU.data(), NXS_BufferDeviceToHost); + + std::vector shape = {(nxs_int)vsize}; + auto buf1 = dev0.createBuffer(shape, vecB.data()); + + int i = 0; + for (auto v : vecResult_GPU) { + if (v != 1.0) { + std::cout << "Fail: result[" << i << "] = " << v << std::endl; + return FAILURE; + } + ++i; + } + + buf0.reshape({256, 4}); + + std::cout << std::endl << "Test PASSED" << std::endl << std::endl; + + return SUCCESS; +} + +// Create the NexusIntegration test fixture class +class NexusIntegration : public ::testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(NexusIntegration, BASIC_KERNEL) { + int result = test_basic_kernel(g_argc, g_argv); + EXPECT_EQ(result, SUCCESS); +} + +int main(int argc, char** argv) { + g_argc = argc; + g_argv = argv; + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}