tensorflow/third_party/mkl_dnn/onednn_acl_threadcap.patch

44 lines
2.0 KiB
Diff

*******************************************************************************
Copyright 2023 Arm Limited and affiliates.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*******************************************************************************
diff --git a/src/cpu/aarch64/acl_thread.cpp b/src/cpu/aarch64/acl_thread.cpp
index fd2c76d01..2d7c76d48 100644
--- a/src/cpu/aarch64/acl_thread.cpp
+++ b/src/cpu/aarch64/acl_thread.cpp
@@ -17,6 +17,8 @@
#include "cpu/aarch64/acl_thread.hpp"
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
#include "cpu/aarch64/acl_threadpool_scheduler.hpp"
+#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
+#include <thread>
#endif
#include "cpu/aarch64/acl_benchmark_scheduler.hpp"
@@ -30,9 +32,10 @@ namespace acl_thread_utils {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
void acl_thread_bind() {
static std::once_flag flag_once;
- // The threads in Compute Library are bound for the cores 0..max_threads-1
- // dnnl_get_max_threads() returns OMP_NUM_THREADS
- const int max_threads = dnnl_get_max_threads();
+ // Cap the number of threads at 90% (rounded down) of the hardware
+ // concurrency so that Compute Library does not consume too many resources
+ int capped_threads = (int)std::floor(0.9*std::thread::hardware_concurrency());
+ const int max_threads = std::min(capped_threads, dnnl_get_max_threads());
// arm_compute::Scheduler does not support concurrent access thus a
// workaround here restricts it to only one call
std::call_once(flag_once, [&]() {