Mirror of https://github.com/immich-app/immich.git (last synced 2025-01-22 11:42:46 +01:00)
40 lines · 1.8 KiB · Diff
|
From f1ba51901c8af0f426572e217c0db2e752e3e0de Mon Sep 17 00:00:00 2001
|
||
|
From: mertalev <101130780+mertalev@users.noreply.github.com>
|
||
|
Date: Thu, 19 Dec 2024 20:26:40 -0500
|
||
|
Subject: [PATCH] fix: avoid race condition for rocm conv algo caching
|
||
|
|
||
|
---
|
||
|
onnxruntime/core/providers/rocm/nn/conv.cc | 2 ++
|
||
|
onnxruntime/core/providers/rocm/nn/conv_transpose.cc | 2 ++
|
||
|
2 files changed, 4 insertions(+)
|
||
|
|
||
|
diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc
|
||
|
index f99885634b..c9b6c0e93b 100644
|
||
|
--- a/onnxruntime/core/providers/rocm/nn/conv.cc
|
||
|
+++ b/onnxruntime/core/providers/rocm/nn/conv.cc
|
||
|
@@ -278,6 +278,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
|
||
|
HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context)));
|
||
|
}
|
||
|
|
||
|
+ // Hold s_.mutex across the algo search so concurrent calls cannot race on s_.cached_benchmark_fwd_results.
|
||
|
+ std::lock_guard<std::mutex> lock(s_.mutex);
|
||
|
if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) {
|
||
|
miopenConvAlgoPerf_t perf;
|
||
|
int algo_count = 1;
|
||
|
diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
|
||
|
index a6848e90b4..59426cf777 100644
|
||
|
--- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
|
||
|
+++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
|
||
|
@@ -127,6 +127,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
|
||
|
|
||
|
y_data = reinterpret_cast<HipT*>(p.Y->MutableData<T>());
|
||
|
|
||
|
+ // Hold s_.mutex across the algo search so concurrent calls cannot race on s_.cached_benchmark_bwd_results.
|
||
|
+ std::lock_guard<std::mutex> lock(s_.mutex);
|
||
|
if (!s_.cached_benchmark_bwd_results.contains(x_dims)) {
|
||
|
IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
|
||
|
|
||
|
--
|
||
|
2.43.0
|
||
|
|