From 350e3237eadb738a0d96295a62f2eed96653c315 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Fri, 20 Dec 2024 00:59:21 -0500 Subject: [PATCH 1/1] fix: avoid race condition for rocm conv algo caching --- onnxruntime/core/providers/rocm/nn/conv.cc | 8 ++++---- onnxruntime/core/providers/rocm/nn/conv.h | 14 ++++++++++++-- .../core/providers/rocm/nn/conv_transpose.cc | 8 ++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index d7f47d07a8..98b6b69212 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -127,7 +127,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_fwd_results.clear(); } ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); @@ -278,7 +277,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); } - if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) { + const std::size_t algo_key = HashConvAlgoKey(x_dims_miopen, w_dims); + if (!s_.cached_benchmark_fwd_results.contains(algo_key)) { miopenConvAlgoPerf_t perf; int algo_count = 1; const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); @@ -301,9 +301,9 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) algo_search_workspace.get(), max_ws_size, false)); // Do not do exhaustive algo search. 
- s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory}); + s_.cached_benchmark_fwd_results.insert(algo_key, {perf.fwd_algo, perf.memory}); } - const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen); + const auto perf = s_.cached_benchmark_fwd_results.at(algo_key); // Copy, don't bind a reference: at() drops the cache lock on return, after which a concurrent insert()/clear() can overwrite or destroy the referenced entry. s_.fwd_algo = perf.fwd_algo; s_.workspace_bytes = perf.memory; } else { diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h index bc9846203e..b1ca5f8e4b 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ b/onnxruntime/core/providers/rocm/nn/conv.h @@ -43,6 +43,11 @@ struct vector_hash { } }; +inline std::size_t HashConvAlgoKey(const TensorShapeVector& x_dims, const TensorShapeVector& w_dims) { + const std::size_t seed = vector_hash{}(x_dims); + return seed ^ (vector_hash{}(w_dims) + 0x9e3779b9 + (seed << 6) + (seed >> 2)); // Order-sensitive hash_combine-style mix: a plain XOR is symmetric and is 0 whenever x_dims == w_dims, so unrelated shape pairs would share one cache key. +} + template , typename KeyEqual = std::equal_to, @@ -52,6 +57,7 @@ class lru_unordered_map { lru_unordered_map(size_t max_size) : max_size_(max_size) {} void insert(const Key& key, const T& value) { + std::lock_guard guard(mutex_); auto it = items_.find(key); if (it != items_.end()) { it->second.value = value; @@ -69,6 +75,7 @@ class lru_unordered_map { } T& at(const Key& key) { + std::lock_guard guard(mutex_); auto it = items_.find(key); if (it == items_.end()) { throw std::out_of_range("There is no such key in cache"); @@ -78,6 +85,7 @@ class lru_unordered_map { } bool contains(const Key& key) const { + std::lock_guard guard(mutex_); return items_.find(key) != items_.end(); } @@ -86,6 +94,7 @@ class lru_unordered_map { } void clear() { + std::lock_guard guard(mutex_); items_.clear(); lru_list_.clear(); } @@ -106,6 +115,7 @@ class lru_unordered_map { size_t max_size_; std::unordered_map items_; list_type lru_list_; + mutable std::mutex mutex_; }; // cached miopen descriptors @@ -148,8 +158,8 @@ struct MiopenConvState { decltype(AlgoPerfType().memory) memory; }; - lru_unordered_map cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; - lru_unordered_map
cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; + lru_unordered_map cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; + lru_unordered_map cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; // Some properties needed to support asymmetric padded Conv nodes bool post_slicing_required; diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc index 7447113fdf..dea9bf2a05 100644 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc @@ -76,7 +76,6 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_bwd_results.clear(); } ConvTransposeAttributes::Prepare p; @@ -127,7 +126,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy y_data = reinterpret_cast(p.Y->MutableData()); - if (!s_.cached_benchmark_bwd_results.contains(x_dims)) { + const std::size_t algo_key = HashConvAlgoKey(x_dims, w_dims); + if (!s_.cached_benchmark_bwd_results.contains(algo_key)) { IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); miopenConvAlgoPerf_t perf; @@ -147,10 +147,10 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy algo_search_workspace.get(), AlgoSearchWorkspaceSize, false)); - s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory}); + s_.cached_benchmark_bwd_results.insert(algo_key, {perf.bwd_data_algo, perf.memory}); } - const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims); + const auto perf = s_.cached_benchmark_bwd_results.at(algo_key); // Copy, don't bind a reference: at() drops the cache lock on return, after which a concurrent insert()/clear() can overwrite or destroy the referenced entry. s_.bwd_data_algo = perf.bwd_data_algo; s_.workspace_bytes = perf.memory; } -- 2.43.0