From f1ba51901c8af0f426572e217c0db2e752e3e0de Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 19 Dec 2024 20:26:40 -0500 Subject: [PATCH] fix: avoid race condition for rocm conv algo caching --- onnxruntime/core/providers/rocm/nn/conv.cc | 2 ++ onnxruntime/core/providers/rocm/nn/conv_transpose.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index f99885634b..c9b6c0e93b 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -278,6 +278,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); } + // lock is needed to avoid race condition during algo search + std::lock_guard lock(s_.mutex); if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) { miopenConvAlgoPerf_t perf; int algo_count = 1; diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc index a6848e90b4..59426cf777 100644 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc @@ -127,6 +127,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy y_data = reinterpret_cast(p.Y->MutableData()); + // lock is needed to avoid race condition during algo search + std::lock_guard lock(s_.mutex); if (!s_.cached_benchmark_bwd_results.contains(x_dims)) { IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); -- 2.43.0