From 13b6a03325ff1b27bcd795a575c8429470deccf6 Mon Sep 17 00:00:00 2001 From: Lukasz Dorau Date: Wed, 17 Dec 2025 10:08:02 +0100 Subject: [PATCH] Add hDevice argument to ur_kernel_handle_t_::setArgValue() Add hDevice argument to ur_kernel_handle_t_::setArgValue() to make it possible to set an argument only on the specified device. Signed-off-by: Lukasz Dorau --- .../level_zero/v2/command_list_manager.cpp | 9 ++++---- .../source/adapters/level_zero/v2/kernel.cpp | 21 +++++++++++++------ .../source/adapters/level_zero/v2/kernel.hpp | 5 +++-- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp index 9d463f870e895..088510db822be 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp @@ -1123,19 +1123,20 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld( wait_list_view &waitListView, ur_event_handle_t phEvent) { { std::scoped_lock guard(hKernel->Mutex); + ur_device_handle_t hDevice = this->hDevice.get(); for (uint32_t argIndex = 0; argIndex < numArgs; argIndex++) { switch (pArgs[argIndex].type) { case UR_EXP_KERNEL_ARG_TYPE_LOCAL: - UR_CALL(hKernel->setArgValue(pArgs[argIndex].index, + UR_CALL(hKernel->setArgValue(hDevice, pArgs[argIndex].index, pArgs[argIndex].size, nullptr, nullptr)); break; case UR_EXP_KERNEL_ARG_TYPE_VALUE: - UR_CALL(hKernel->setArgValue(pArgs[argIndex].index, + UR_CALL(hKernel->setArgValue(hDevice, pArgs[argIndex].index, pArgs[argIndex].size, nullptr, pArgs[argIndex].value.value)); break; case UR_EXP_KERNEL_ARG_TYPE_POINTER: - UR_CALL(hKernel->setArgPointer(pArgs[argIndex].index, nullptr, + UR_CALL(hKernel->setArgPointer(hDevice, pArgs[argIndex].index, nullptr, pArgs[argIndex].value.pointer)); break; case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ: @@ -1147,7 +1148,7 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExpOld( break; case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: { UR_CALL( - hKernel->setArgValue(argIndex, sizeof(void *), nullptr, + hKernel->setArgValue(hDevice, argIndex, sizeof(void *), nullptr, &pArgs[argIndex].value.sampler->ZeSampler)); break; } diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp index 90054a286ec13..66361aa8b2ad4 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp @@ -194,13 +194,20 @@ ur_kernel_handle_t_::getProperties(ur_device_handle_t hDevice) const { } ur_result_t ur_kernel_handle_t_::setArgValue( - uint32_t argIndex, size_t argSize, + ur_device_handle_t hDevice, uint32_t argIndex, size_t argSize, const ur_kernel_arg_value_properties_t * /*pProperties*/, const void *pArgValue) { if (argIndex > zeCommonProperties.numKernelArgs - 1) { return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX; } + if (hDevice) { // Set argument only on the specified device + auto &deviceKernel = deviceKernels[deviceIndex(hDevice)].value(); + UR_CALL(setArgValueOnZeKernel(deviceKernel.hKernel.get(), argIndex, argSize, + pArgValue)); + return UR_RESULT_SUCCESS; + } + for (auto &singleDeviceKernel : deviceKernels) { if (!singleDeviceKernel.has_value()) { continue; @@ -213,12 +220,13 @@ ur_result_t ur_kernel_handle_t_::setArgValue( } ur_result_t ur_kernel_handle_t_::setArgPointer( - uint32_t argIndex, + ur_device_handle_t hDevice, uint32_t argIndex, const ur_kernel_arg_pointer_properties_t * /*pProperties*/, const void *pArgValue) { // KernelSetArgValue is expecting a pointer to the argument - return setArgValue(argIndex, sizeof(const void *), nullptr, &pArgValue); + return setArgValue(hDevice, argIndex, sizeof(const void *), nullptr, + &pArgValue); } ur_program_handle_t ur_kernel_handle_t_::getProgramHandle() const { @@ -429,7 +437,8 @@ ur_result_t urKernelSetArgValue( TRACK_SCOPE_LATENCY("urKernelSetArgValue"); std::scoped_lock guard(hKernel->Mutex); - return hKernel->setArgValue(argIndex, argSize, pProperties, pArgValue); + return hKernel->setArgValue(nullptr, argIndex, argSize, pProperties, + pArgValue); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -492,7 +501,7 @@ ur_result_t urKernelSetArgLocal( std::scoped_lock guard(hKernel->Mutex); - return hKernel->setArgValue(argIndex, argSize, nullptr, nullptr); + return hKernel->setArgValue(nullptr, argIndex, argSize, nullptr, nullptr); } catch (...) { return exceptionToResult(std::current_exception()); } @@ -736,7 +745,7 @@ ur_result_t urKernelSetArgSampler( ur_sampler_handle_t hArgValue) try { TRACK_SCOPE_LATENCY("urKernelSetArgSampler"); std::scoped_lock guard(hKernel->Mutex); - return hKernel->setArgValue(argIndex, sizeof(void *), nullptr, + return hKernel->setArgValue(nullptr, argIndex, sizeof(void *), nullptr, &hArgValue->ZeSampler); } catch (...) { return exceptionToResult(std::current_exception()); diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp index 2250e8a6603fb..b9f559c5c80b0 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp @@ -67,13 +67,14 @@ struct ur_kernel_handle_t_ : ur_object { const ze_kernel_properties_t &getProperties(ur_device_handle_t hDevice) const; // Implementation of urKernelSetArgValue. - ur_result_t setArgValue(uint32_t argIndex, size_t argSize, + ur_result_t setArgValue(ur_device_handle_t hDevice, uint32_t argIndex, + size_t argSize, const ur_kernel_arg_value_properties_t *pProperties, const void *pArgValue); // Implementation of urKernelSetArgPointer. ur_result_t - setArgPointer(uint32_t argIndex, + setArgPointer(ur_device_handle_t hDevice, uint32_t argIndex, const ur_kernel_arg_pointer_properties_t *pProperties, const void *pArgValue);