From ace7dc99da01839e5b71cc9014a0fa33b6af2b32 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Fri, 20 Dec 2024 13:55:20 -0500 Subject: [PATCH] acquire lock before any changes can be made --- ...-condition-for-rocm-conv-algo-caching.patch | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/machine-learning/0001-fix-avoid-race-condition-for-rocm-conv-algo-caching.patch b/machine-learning/0001-fix-avoid-race-condition-for-rocm-conv-algo-caching.patch index a8fa9df0e2..728e1d0f84 100644 --- a/machine-learning/0001-fix-avoid-race-condition-for-rocm-conv-algo-caching.patch +++ b/machine-learning/0001-fix-avoid-race-condition-for-rocm-conv-algo-caching.patch @@ -1,4 +1,4 @@ -From e267bc9bab8b3873dba57323ddcd9a9d09a1211e Mon Sep 17 00:00:00 2001 +From c4e6f81b4901b2b8178b964377e8fde7118f0e2e Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Fri, 20 Dec 2024 00:59:21 -0500 Subject: [PATCH] fix: avoid race condition for rocm conv algo caching @@ -8,18 +8,18 @@ Subject: [PATCH] fix: avoid race condition for rocm conv algo caching 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc -index d7f47d07a8..ec438287ac 100644 +index d7f47d07a8..134f8a3b43 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc -@@ -278,6 +278,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) - HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); - } - +@@ -122,6 +122,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + // lock is needed to avoid race condition during algo search + std::lock_guard lock(s_.mutex); - if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) { - miopenConvAlgoPerf_t perf; - int algo_count = 1; + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); + -- 2.43.0