Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/nexus-api/_nxs_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ NEXUS_API_FUNC(nxs_status, CopyBuffer,
void* host_ptr,
nxs_uint buffer_settings
)

/************************************************************************
* @def FillBuffer
* @brief Fill buffer on the device with a value
* @param buffer_id Id of the buffer to fill.
* @param value Host pointer to the bytes of the fill value (pattern).
* @param size Size of the fill value in bytes (not the buffer size).
* @return Negative value is an error status.
* NOTE(review): the CUDA plugin returns NXS_Success on success rather
* than a buffer id - confirm the intended contract for other plugins.
***********************************************************************/
NEXUS_API_FUNC(nxs_status, FillBuffer,
nxs_int buffer_id,
void *value,
size_t size
)

/************************************************************************
* @def ReleaseBuffer
* @brief Release the buffer on the device
Expand Down
2 changes: 2 additions & 0 deletions include/nexus/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Buffer : public Object<detail::BufferImpl> {
Buffer getLocal() const;

nxs_status copy(void *_hostBuf, nxs_uint direction = NXS_BufferDeviceToHost);

nxs_status fill(void *value, size_t size);
};

typedef Objects<Buffer> Buffers;
Expand Down
2 changes: 2 additions & 0 deletions include/nexus/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class Device : public Object<detail::DeviceImpl> {
Buffer createBuffer(size_t size, const void *data = nullptr,
nxs_uint settings = 0);
Buffer copyBuffer(Buffer buf, nxs_uint settings = 0);

Buffer fillBuffer(void *value, size_t size);
};

typedef Objects<Device> Devices;
Expand Down
37 changes: 37 additions & 0 deletions plugins/cuda/cuda_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,43 @@ extern "C" nxs_status NXS_API_CALL nxsCopyBuffer(nxs_int buffer_id,
return NXS_Success;
}

/*
 * Fill a buffer on the device with a repeating byte pattern.
 *
 *   buffer_id  - id used to look the buffer up in the runtime.
 *   value      - host pointer to the pattern bytes.
 *   value_size - length of the pattern in bytes.
 *
 * The pattern is tiled across the whole buffer; when the buffer size is not
 * a multiple of value_size the final copy of the pattern is truncated.
 *
 * Returns NXS_Success on success. Returns NXS_InvalidBuffer when the buffer
 * is not found, the pattern is null or empty, or a CUDA call fails.
 */
extern "C" nxs_status NXS_API_CALL nxsFillBuffer(nxs_int buffer_id, void *value, size_t value_size) {
  auto rt = getRuntime();
  auto buffer = rt->get<rt::Buffer>(buffer_id);
  if (!buffer || value == nullptr || value_size == 0) return NXS_InvalidBuffer;

  const size_t total_size = buffer->size();
  if (total_size == 0) return NXS_Success;  // nothing to write

  const uint8_t *pattern = static_cast<const uint8_t *>(value);

  // A pattern whose bytes are all identical (all-zero included) can be
  // written by cudaMemset directly, with no host staging buffer.
  bool is_uniform = true;
  for (size_t i = 1; i < value_size; ++i) {
    if (pattern[i] != pattern[0]) {
      is_uniform = false;
      break;
    }
  }

  if (is_uniform) {
    CUDA_CHECK(NXS_InvalidBuffer, cudaMemset,
               buffer->get(),
               static_cast<int>(pattern[0]),
               total_size);
    return NXS_Success;
  }

  // General case: tile the pattern into a host buffer, then upload it in a
  // single transfer.
  std::vector<uint8_t> host_buffer(total_size);
  for (size_t offset = 0; offset < total_size; offset += value_size) {
    // The last tile may be shorter than the pattern.
    const size_t to_copy = std::min(value_size, total_size - offset);
    std::memcpy(host_buffer.data() + offset, pattern, to_copy);
  }

  CUDA_CHECK(NXS_InvalidBuffer, cudaMemcpy,
             buffer->get(),
             host_buffer.data(),
             total_size,
             cudaMemcpyHostToDevice);
  return NXS_Success;
}

/*
* Release a buffer on the device.
*/
Expand Down
13 changes: 7 additions & 6 deletions scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@ main() {
make -j$(nproc)

printf "Running CPU tests"
./test/cpp/test_basic_kernel cpu kernel_libs/cpu_kernel.so add_vectors
#./test/cpp/test_basic_kernel cpu kernel_libs/cpu_kernel.so add_vectors

if [[ "$os_type" == "macos" ]]; then
printf "Running macOS test"
#./test/cpp/gpu/nexus_gpu_integration_test metal metal_kernels/kernel.metallib add_vectors

elif [[ "$os_type" == "linux" ]]; then
printf "Running Linux test"
./test/cpp/test_basic_kernel cuda kernel_libs/add_vectors.ptx add_vectors
./test/cpp/test_kernel_catalog cuda kernel_libs/add_vectors.kc add_vectors
./test/cpp/test_smi cuda kernel_libs/add_vectors.ptx add_vectors
./test/cpp/test_multi_stream_sync cuda kernel_libs/add_vectors.ptx add_vectors
./test/cpp/test_graph cuda kernel_libs/add_vectors.ptx add_vectors
./test/cpp/test_buffers cuda
# ./test/cpp/test_basic_kernel cuda kernel_libs/add_vectors.ptx add_vectors
# ./test/cpp/test_kernel_catalog cuda kernel_libs/add_vectors.kc add_vectors
# ./test/cpp/test_smi cuda kernel_libs/add_vectors.ptx add_vectors
# ./test/cpp/test_multi_stream_sync cuda kernel_libs/add_vectors.ptx add_vectors
# ./test/cpp/test_graph cuda kernel_libs/add_vectors.ptx add_vectors
else
printf "Unsupported OS: $os_type"
exit 1
Expand Down
2 changes: 1 addition & 1 deletion src/_buffer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BufferImpl : public Impl {

Buffer getLocal();
nxs_status copyData(void *_hostBuf, nxs_uint direction) const;

nxs_status fillData(void *value, size_t size) const;
std::string print() const;

private:
Expand Down
2 changes: 2 additions & 0 deletions src/_device_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class DeviceImpl : public Impl {
Buffer createBuffer(size_t size, const void *data = nullptr,
nxs_uint settings = 0);
Buffer copyBuffer(Buffer buf, nxs_uint settings = 0);
Buffer fillBuffer(void *value, size_t size);

};

} // namespace detail
Expand Down
12 changes: 12 additions & 0 deletions src/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,17 @@ nxs_status detail::BufferImpl::copyData(void *_hostBuf, nxs_uint direction) cons
return NXS_Success;
}

/// Fill this buffer with the pattern of `size` bytes pointed to by `value`.
///
/// When the buffer belongs to a device (valid device id) the request is
/// forwarded to the runtime's nxsFillBuffer entry point and its status is
/// returned. A host-side fill is not implemented here, so that path reports
/// NXS_InvalidBuffer instead of returning an uninitialised status.
nxs_status detail::BufferImpl::fillData(void *value, size_t size) const {
  if (nxs_valid_id(getDeviceId())) {
    NEXUS_LOG(NXS_LOG_NOTE, "fillData: on device: ", getSize());
    auto *rt = getParentOfType<RuntimeImpl>();
    return static_cast<nxs_status>(
        rt->runAPIFunction<NF_nxsFillBuffer>(getId(), value, size));
  }

  // Only reached for buffers without a device; nothing is written.
  NEXUS_LOG(NXS_LOG_NOTE, "fillData: on host: ", getSize());
  return NXS_InvalidBuffer;
}

///////////////////////////////////////////////////////////////////////////////
Buffer::Buffer(detail::Impl base, size_t _sz, const void *_hostData)
: Object(base, _sz, (const char *)_hostData) {}
Expand All @@ -177,3 +188,4 @@ Buffer Buffer::getLocal() const {
}

// Forwards to the implementation's copyData(_hostBuf, direction) through
// NEXUS_OBJ_MCALL. NXS_InvalidBuffer is handed to the macro as its first
// argument - presumably the status used when the object has no
// implementation; confirm against the macro definition.
nxs_status Buffer::copy(void *_hostBuf, nxs_uint direction) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, copyData, _hostBuf, direction); }
// Forwards to the implementation's fillData(value, size) through
// NEXUS_OBJ_MCALL. `value` points at the fill pattern and `size` is passed
// along unchanged. NXS_InvalidBuffer is handed to the macro as its first
// argument - presumably the status used when the object has no
// implementation; confirm against the macro definition.
nxs_status Buffer::fill(void *value, size_t size) { NEXUS_OBJ_MCALL(NXS_InvalidBuffer, fillData, value, size); }
54 changes: 54 additions & 0 deletions test/cpp/test_buffers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include <gtest/gtest.h>
#include <nexus.h>
#include <vector>

// Result codes returned by the test helper functions in this file.
// Typed constants instead of macros so they obey scope and type rules.
constexpr int SUCCESS = 0;
constexpr int FAILURE = 1;

// Command-line arguments, stored by main() so that test bodies can read the
// runtime name passed on the command line.
int g_argc;
char** g_argv;

// Fills a device buffer with a repeating byte pattern, copies it back to the
// host and checks every byte.
//
// runtime_name: name of the runtime to look up (e.g. "cuda").
// Returns SUCCESS when the read-back data matches the tiled pattern, and
// FAILURE when the runtime or device is unavailable, when fill/copy report
// an error status, or when any byte differs.
int test_direct_buffer_fill(std::string runtime_name) {
  auto sys = nexus::getSystem();
  auto runtime = sys.getRuntime(runtime_name);
  if (!runtime || runtime.getDevices().empty()) return FAILURE;

  auto dev = runtime.getDevice(0);

  // A 6-byte pattern does not evenly divide the 1024-byte buffer, so the
  // truncated final repetition is exercised as well.
  std::vector<uint8_t> pattern = {0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE};
  const size_t buffer_size = 1024;

  auto buf = dev.createBuffer(buffer_size, nullptr);

  // The API documents negative status values as errors; do not go on to
  // verify data if the fill itself was rejected.
  if (static_cast<int>(buf.fill(pattern.data(), pattern.size())) < 0) {
    return FAILURE;
  }

  // Read the device contents back for verification.
  std::vector<uint8_t> host_out(buffer_size);
  if (static_cast<int>(buf.copy(host_out.data(), NXS_BufferDeviceToHost)) < 0) {
    return FAILURE;
  }

  for (size_t i = 0; i < buffer_size; ++i) {
    if (host_out[i] != pattern[i % pattern.size()]) {
      return FAILURE;
    }
  }

  return SUCCESS;
}

// Empty GoogleTest fixture used to group the buffer tests; it adds no
// set-up or tear-down of its own.
class BufferTest : public ::testing::Test {};

// Runs the direct-fill check against the runtime named by the first
// command-line argument, falling back to "cuda" when none was supplied.
TEST_F(BufferTest, DIRECT_FILL) {
  std::string runtime_name;
  if (g_argc > 1) {
    runtime_name = g_argv[1];
  } else {
    runtime_name = "cuda";
  }
  EXPECT_EQ(test_direct_buffer_fill(runtime_name), SUCCESS);
}

// Ensure your main function actually sets these globals
// Test entry point. GoogleTest is initialised first because
// InitGoogleTest() removes the flags it recognises from argv and updates
// argc; capturing the arguments afterwards guarantees that g_argv[1] is the
// runtime name and never a --gtest_* option.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  g_argc = argc;
  g_argv = argv;
  return RUN_ALL_TESTS();
}
Loading