diff --git a/src/blas/rocBLAS.jl b/src/blas/rocBLAS.jl index fde133ae6..a7d0d4c3d 100644 --- a/src/blas/rocBLAS.jl +++ b/src/blas/rocBLAS.jl @@ -51,7 +51,12 @@ function lib_state() (nh, s) -> rocblas_set_stream(nh, s)) end -handle() = lib_state().handle +function handle() + # Consume any sticky HIP error from prior GPU work in this context before + # any rocblas call. See rocSPARSE.handle for the rationale. + HIP.clear_last_error() + return lib_state().handle +end stream() = lib_state().stream end diff --git a/src/dnn/MIOpen.jl b/src/dnn/MIOpen.jl index b593ff0e8..eca71e6d0 100644 --- a/src/dnn/MIOpen.jl +++ b/src/dnn/MIOpen.jl @@ -83,7 +83,12 @@ lib_state() = library_state( create_handle, destroy_handle!, (nh, s) -> miopenSetStream(nh, s)) -handle() = lib_state().handle +function handle() + # Consume any sticky HIP error from prior GPU work in this context before + # any MIOpen call. See rocSPARSE.handle for the rationale. + HIP.clear_last_error() + return lib_state().handle +end stream() = lib_state().stream include("descriptors.jl") diff --git a/src/hip/error.jl b/src/hip/error.jl index d9a084720..32a679342 100644 --- a/src/hip/error.jl +++ b/src/hip/error.jl @@ -145,3 +145,21 @@ function check(err::hipError_t) throw(HIPError(err)) end end + +""" + clear_last_error() + +Consume any sticky HIP error on the current context without throwing. + +Some HIP operations (e.g. `hipDeviceSynchronize`) surface errors that were set +by previous GPU work (e.g. a kernel exception). These errors persist on the +context until consumed. Call this before creating library handles to prevent +stale errors from causing spurious failures in unrelated operations. +""" +function clear_last_error() + err = @gcsafe_ccall libhip.hipGetLastError()::hipError_t + if err != hipSuccess + @debug "Cleared sticky HIP error before library call" error=HIPError(err) + end + return +end diff --git a/src/rand/rocRAND.jl b/src/rand/rocRAND.jl index 4a665d836..89395e692 100644 --- a/src/rand/rocRAND.jl +++ b/src/rand/rocRAND.jl @@ -34,7 +34,12 @@ lib_state() = library_state( Random.seed!(nh) end) -handle() = lib_state().handle +function handle() + # Consume any sticky HIP error from prior GPU work in this context before + # any rocrand call. See rocSPARSE.handle for the rationale. + HIP.clear_last_error() + return lib_state().handle +end stream() = lib_state().stream end diff --git a/src/sparse/rocSPARSE.jl b/src/sparse/rocSPARSE.jl index 1ca043101..892a0fe60 100644 --- a/src/sparse/rocSPARSE.jl +++ b/src/sparse/rocSPARSE.jl @@ -38,7 +38,13 @@ lib_state() = library_state( :rocSPARSE, rocsparse_handle, IDLE_HANDLES, create_handle, rocsparse_destroy_handle, rocsparse_set_stream) -handle() = lib_state().handle +function handle() + # Consume any sticky HIP error from prior GPU work in this context before + # any rocsparse call. rocsparse operations internally synchronize and will + # surface a pending hipErrorLaunchFailure as rocsparse_status_internal_error. + HIP.clear_last_error() + return lib_state().handle +end stream() = lib_state().stream function version() diff --git a/test/device/hostcall.jl b/test/device/hostcall.jl index b37c1a9ce..4d9c329c7 100644 --- a/test/device/hostcall.jl +++ b/test/device/hostcall.jl @@ -39,7 +39,7 @@ end RB = ROCArray(zeros(Float32, 1)) dref = Ref{Bool}(false) - @test_logs (:error, "HostCall error") begin + @test_logs (:error, "HostCall error") match_mode=:any begin hc = HostCallHolder(Nothing, Tuple{}) do error("Some error") dref[] = true