[xla:cpu] Resolve arguments/results/temp mapping from buffer assignment

PiperOrigin-RevId: 698610190
This commit is contained in:
Eugene Zhulenev 2024-11-20 19:32:57 -08:00 committed by TensorFlower Gardener
parent 9a46acc707
commit 6fb4e335c8
5 changed files with 328 additions and 79 deletions

View File

@ -32,9 +32,7 @@ cc_library(
"//xla/service:executable", "//xla/service:executable",
"//xla/service:hlo_module_config", "//xla/service:hlo_module_config",
"//xla/service/cpu:cpu_compiler_pure", "//xla/service/cpu:cpu_compiler_pure",
"//xla/service/cpu:cpu_executable",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:statusor",
@ -53,6 +51,8 @@ xla_cc_test(
"//xla:xla_data_proto_cc", "//xla:xla_data_proto_cc",
"//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation", "//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_executable",
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
@ -75,22 +75,33 @@ cc_library(
"//xla/backends/cpu/nanort:nanort_users", "//xla/backends/cpu/nanort:nanort_users",
]), ]),
deps = [ deps = [
"//xla:shape_util",
"//xla:util", "//xla:util",
"//xla/backends/cpu/runtime:buffer_allocations", "//xla/backends/cpu/runtime:buffer_allocations",
"//xla/backends/cpu/runtime:thunk", "//xla/backends/cpu/runtime:thunk",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:computation_layout",
"//xla/service:executable", "//xla/service:executable",
"//xla/service:hlo_value",
"//xla/service:maybe_owning_device_memory", "//xla/service:maybe_owning_device_memory",
"//xla/service/cpu:cpu_executable", "//xla/service/cpu:cpu_executable",
"//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme",
"@local_tsl//tsl/profiler/lib:traceme_encode", "@local_tsl//tsl/profiler/lib:traceme_encode",
], ],

View File

@ -26,13 +26,11 @@ limitations under the License.
#include "xla/pjrt/utils.h" #include "xla/pjrt/utils.h"
#include "xla/service/compiler.h" #include "xla/service/compiler.h"
#include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/cpu_compiler.h"
#include "xla/service/cpu/cpu_executable.h"
#include "xla/service/dump.h" #include "xla/service/dump.h"
#include "xla/service/executable.h" #include "xla/service/executable.h"
#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_module_config.h"
#include "xla/shape.h" #include "xla/shape.h"
#include "xla/util.h" #include "xla/util.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h" #include "tsl/platform/env.h"
#include "tsl/platform/logging.h" #include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h" #include "tsl/platform/statusor.h"
@ -85,13 +83,6 @@ absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtClient::Compile(
compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr, compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
compile_options)); compile_options));
// Downcast executable to CpuExecutable to sanity check compilation result.
cpu::CpuExecutable* cpu_executable =
tsl::down_cast<cpu::CpuExecutable*>(executable.get());
if (cpu_executable == nullptr) {
return Internal("Failed to downcast executable to CpuExecutable");
}
return NanoRtExecutable::Create(std::move(executable), intra_op_thread_pool_); return NanoRtExecutable::Create(std::move(executable), intra_op_thread_pool_);
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <string_view>
#include <vector> #include <vector>
#include "absl/container/inlined_vector.h" #include "absl/container/inlined_vector.h"
@ -24,6 +25,8 @@ limitations under the License.
#include "xla/backends/cpu/nanort/nanort_executable.h" #include "xla/backends/cpu/nanort/nanort_executable.h"
#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
@ -38,17 +41,99 @@ limitations under the License.
namespace xla::cpu { namespace xla::cpu {
namespace { namespace {
absl::StatusOr<XlaComputation> CreateAddScalarsComputation() { using Arguments = absl::InlinedVector<NanoRtExecutable::Argument, 8>;
using Results = absl::InlinedVector<NanoRtExecutable::Result, 8>;
TEST(NanoRtClientTest, CompileAndRunScalarComputation) {
constexpr std::string_view hlo = R"(
HloModule add
ENTRY e {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT add = f32[] add(p0, p1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
XlaComputation computation(module->ToProto());
NanoRtClient client;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
client.Compile(computation));
// Storage for executable parameters and results.
alignas(32) float p0_value = 1.0f;
alignas(32) float p1_value = 2.0f;
alignas(32) float r0_value = 0.0f;
// Prepare executable parameters, results and temp storage.
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
Results results = {{&r0_value, 1}};
NanoRtExecutable::PreallocatedTemp temp = {};
auto event = executable->Execute(arguments, results, temp);
tsl::BlockUntilReady(event);
ASSERT_TRUE(event.IsConcrete());
EXPECT_EQ(r0_value, 3.0f);
}
TEST(NanoRtClientTest, CompileAndRunTupledComputation) {
constexpr std::string_view hlo = R"(
HloModule add_and_mul
ENTRY e {
p = (f32[], f32[]) parameter(0)
p0 = f32[] get-tuple-element(p), index=0
p1 = f32[] get-tuple-element(p), index=1
add = f32[] add(p0, p1)
mul = f32[] multiply(p0, p1)
ROOT add_and_mul = (f32[], f32[]) tuple(add, mul)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
XlaComputation computation(module->ToProto());
NanoRtClient client;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
client.Compile(computation));
// Storage for executable parameters and results.
alignas(32) float p0_value = 2.0f;
alignas(32) float p1_value = 3.0f;
alignas(32) float r0_value = 0.0f;
alignas(32) float r1_value = 0.0f;
// Prepare executable parameters, results and temp storage.
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
Results results = {{&r0_value, 1}, {&r1_value, 1}};
NanoRtExecutable::PreallocatedTemp temp = {};
auto event = executable->Execute(arguments, results, temp);
tsl::BlockUntilReady(event);
ASSERT_TRUE(event.IsConcrete());
EXPECT_EQ(r0_value, 5.0f);
EXPECT_EQ(r1_value, 6.0f);
}
//===----------------------------------------------------------------------===//
// Performance benchmarks below
//===----------------------------------------------------------------------===//
static absl::StatusOr<XlaComputation> CreateAddScalarsComputation() {
XlaBuilder b("add"); XlaBuilder b("add");
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0"); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "p1"); auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "p1");
Add(Add(p0, p1), Add(p0, p1)); Add(p0, p1);
return b.Build(); return b.Build();
} }
absl::StatusOr<XlaComputation> CreateFibonacciComputation() { static absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
XlaBuilder b("fib"); XlaBuilder b("fib");
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0"); auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
@ -64,38 +149,6 @@ absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
return b.Build(); return b.Build();
} }
TEST(NanoRtClientTest, CompileAndRun) {
NanoRtClient client;
TF_ASSERT_OK_AND_ASSIGN(auto computation, CreateAddScalarsComputation());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
client.Compile(computation));
// Storage for executable parameters and results.
alignas(32) float p0_value = 1.0f;
alignas(32) float p1_value = 2.0f;
alignas(32) float result = 0.0f;
// Prepare executable parameters, results and temp storage.
NanoRtExecutable::Argument p0(&p0_value, 1);
NanoRtExecutable::Argument p1(&p1_value, 1);
NanoRtExecutable::Result r0(&result, 1);
NanoRtExecutable::PreallocatedTemp temp = {};
std::vector<NanoRtExecutable::Argument> arguments = {p0, p1};
std::vector<NanoRtExecutable::Result> results = {r0};
auto event = executable->Execute(arguments, results, temp);
tsl::BlockUntilReady(event);
ASSERT_TRUE(event.IsConcrete());
EXPECT_EQ(result, 6.0f);
}
//===----------------------------------------------------------------------===//
// Performance benchmarks below
//===----------------------------------------------------------------------===//
static void BM_NanoRtAddScalars(benchmark::State& state) { static void BM_NanoRtAddScalars(benchmark::State& state) {
NanoRtClient client; NanoRtClient client;
@ -105,17 +158,13 @@ static void BM_NanoRtAddScalars(benchmark::State& state) {
// Storage for executable arguments and results. // Storage for executable arguments and results.
alignas(32) float p0_value = 1.0f; alignas(32) float p0_value = 1.0f;
alignas(32) float p1_value = 2.0f; alignas(32) float p1_value = 2.0f;
alignas(32) float result = 0.0f; alignas(32) float r0_value = 0.0f;
for (auto _ : state) { for (auto _ : state) {
NanoRtExecutable::Argument p0(&p0_value, 1); Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
NanoRtExecutable::Argument p1(&p1_value, 1); Results results = {{&r0_value, 1}};
NanoRtExecutable::Result r0(&result, 1);
NanoRtExecutable::PreallocatedTemp temp = {}; NanoRtExecutable::PreallocatedTemp temp = {};
absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
auto event = (*executable)->Execute(arguments, results, temp); auto event = (*executable)->Execute(arguments, results, temp);
tsl::BlockUntilReady(event); tsl::BlockUntilReady(event);
} }
@ -132,17 +181,13 @@ static void BM_NanoRtFibonacci(benchmark::State& state) {
// Storage for executable arguments and results. // Storage for executable arguments and results.
alignas(32) float p0_value = 1.0f; alignas(32) float p0_value = 1.0f;
alignas(32) float p1_value = 2.0f; alignas(32) float p1_value = 2.0f;
alignas(32) float result = 0.0f; alignas(32) float r0_value = 0.0f;
for (auto _ : state) { for (auto _ : state) {
NanoRtExecutable::Argument p0(&p0_value, 1); Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
NanoRtExecutable::Argument p1(&p1_value, 1); Results results = {{&r0_value, 1}};
NanoRtExecutable::Result r0(&result, 1);
NanoRtExecutable::PreallocatedTemp temp = {}; NanoRtExecutable::PreallocatedTemp temp = {};
absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
auto event = (*executable)->Execute(arguments, results, temp); auto event = (*executable)->Execute(arguments, results, temp);
tsl::BlockUntilReady(event); tsl::BlockUntilReady(event);
} }

View File

@ -15,22 +15,37 @@ limitations under the License.
#include "xla/backends/cpu/nanort/nanort_executable.h" #include "xla/backends/cpu/nanort/nanort_executable.h"
#include <cstddef>
#include <memory> #include <memory>
#include <optional>
#include <utility> #include <utility>
#include <vector>
#include "absl/base/optimization.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h" #include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/buffer_allocations.h"
#include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/computation_layout.h"
#include "xla/service/cpu/cpu_executable.h" #include "xla/service/cpu/cpu_executable.h"
#include "xla/service/executable.h" #include "xla/service/executable.h"
#include "xla/service/hlo_value.h"
#include "xla/service/maybe_owning_device_memory.h" #include "xla/service/maybe_owning_device_memory.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/util.h" #include "xla/util.h"
#include "tsl/platform/casts.h" #include "tsl/platform/casts.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h" #include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool.h" #include "tsl/platform/threadpool.h"
#include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme.h"
#include "tsl/profiler/lib/traceme_encode.h" #include "tsl/profiler/lib/traceme_encode.h"
@ -40,41 +55,195 @@ namespace xla::cpu {
using ::tsl::profiler::TraceMe; using ::tsl::profiler::TraceMe;
using ::tsl::profiler::TraceMeEncode; using ::tsl::profiler::TraceMeEncode;
using ArgumentIndex = std::pair<size_t, ShapeIndex>;
// Resolves the mapping from argument index to allocation index.
static absl::StatusOr<std::vector<size_t>> ResolveArgumentsMapping(
const HloModule& module, const BufferAssignment& buffer_assignment) {
const ComputationLayout& entry_layout = module.entry_computation_layout();
VLOG(3) << "Resolve executable arguments mapping:";
// Mapping from argument index to flattened executable argument index.
absl::flat_hash_map<ArgumentIndex, size_t> executable_arg_index;
for (size_t i = 0; i < entry_layout.parameter_count(); ++i) {
ShapeUtil::ForEachLeafShape(
entry_layout.parameter_shape(i),
[&](const Shape&, const ShapeIndex& index) {
size_t arg_index = executable_arg_index.size();
executable_arg_index[ArgumentIndex{i, index}] = arg_index;
});
}
std::vector<size_t> argument_to_allocation_index(executable_arg_index.size());
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
if (allocation.is_entry_computation_parameter()) {
ArgumentIndex idx{allocation.parameter_number(),
allocation.param_shape_index()};
// Skip buffer allocations assigned to non-leaf parameters (tuples).
auto arg_idx = executable_arg_index.find(idx);
if (arg_idx == executable_arg_index.end()) continue;
VLOG(3) << absl::StreamFormat(
" - parameter %d at shape index %s:"
" argument index = %d allocation index = %d",
allocation.parameter_number(),
allocation.param_shape_index().ToString(), arg_idx->second,
allocation.index());
argument_to_allocation_index[arg_idx->second] = allocation.index();
}
}
return argument_to_allocation_index;
}
// Resolves the mapping from result index to allocation index.
static absl::StatusOr<std::vector<size_t>> ResolveResultMapping(
const HloModule& module, const BufferAssignment& buffer_assignment) {
const ComputationLayout& entry_layout = module.entry_computation_layout();
VLOG(3) << "Resolve executable results mapping:";
// Mapping from result index to flattened executable result index.
absl::flat_hash_map<ShapeIndex, size_t> executable_res_index;
ShapeUtil::ForEachLeafShape(entry_layout.result_shape(),
[&](const Shape&, const ShapeIndex& index) {
size_t res_index = executable_res_index.size();
executable_res_index[index] = res_index;
});
const InstructionValueSet& root_value_set =
buffer_assignment.dataflow_analysis().GetInstructionValueSet(
module.entry_computation()->root_instruction());
std::vector<size_t> result_to_allocation_index(executable_res_index.size());
TF_RETURN_IF_ERROR(ShapeUtil::ForEachLeafShapeWithStatus(
entry_layout.result_shape(),
[&](const Shape&, const ShapeIndex& index) -> absl::Status {
// Skip buffer allocations assigned to non-leaf results (tuples).
auto res_idx = executable_res_index.find(index);
if (res_idx == executable_res_index.end()) return absl::OkStatus();
const HloValueSet& sources = root_value_set.element(index);
if (sources.values().size() != 1) {
return Internal(
"Expected a single value for result at shape index %s",
index.ToString());
}
const HloValue* value = sources.values().front();
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
buffer_assignment.GetUniqueSlice(
value->instruction(), value->index()));
DCHECK_EQ(slice.size(), slice.allocation()->size())
<< "Result slice size must match result allocation size";
VLOG(3) << absl::StreamFormat(
" - result at shape index %s:"
" result index = %d allocation index = %d",
index.ToString(), res_idx->second, slice.index());
result_to_allocation_index[res_idx->second] = slice.index();
return absl::OkStatus();
}));
return result_to_allocation_index;
}
static absl::StatusOr<std::optional<size_t>> ResolveTempAllocationIndex(
const BufferAssignment& buffer_assignment) {
VLOG(3) << "Resolve temp allocation index:";
std::optional<size_t> temp_allocation_index;
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
if (allocation.IsPreallocatedTempBuffer()) {
if (temp_allocation_index.has_value()) {
return Internal("Multiple temp buffer allocations found");
}
VLOG(3) << " - temp buffer allocation index = " << allocation.index();
temp_allocation_index = allocation.index();
}
}
if (!temp_allocation_index.has_value()) {
VLOG(3) << " - no temp buffer allocation found";
}
return temp_allocation_index;
}
absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtExecutable::Create( absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtExecutable::Create(
std::unique_ptr<Executable> executable, std::unique_ptr<Executable> executable,
std::shared_ptr<tsl::thread::ThreadPool> thread_pool) { std::shared_ptr<tsl::thread::ThreadPool> thread_pool) {
const HloModule& module = executable->module();
VLOG(3) << "Create NanoRtExecutable: name = " << module.name();
// NanoRtExecutable requires a CPU executable with thunks.
auto* cpu_executable = tsl::down_cast<cpu::CpuExecutable*>(executable.get()); auto* cpu_executable = tsl::down_cast<cpu::CpuExecutable*>(executable.get());
if (cpu_executable == nullptr) {
return Internal("NanoRtExecutable requires CPU executable");
}
if (!cpu_executable->has_thunks()) { if (!cpu_executable->has_thunks()) {
return Internal("NanoRtExecutable requires CPU executable to use thunks"); return Internal("NanoRtExecutable requires CPU executable to use thunks");
} }
return absl::WrapUnique( // Mappings from argument/result index to buffer allocation index.
new NanoRtExecutable(std::move(executable), std::move(thread_pool))); TF_ASSIGN_OR_RETURN(
std::vector<size_t> argument_to_allocation_index,
ResolveArgumentsMapping(module, cpu_executable->buffer_assignment()));
TF_ASSIGN_OR_RETURN(
std::vector<size_t> result_to_allocation_index,
ResolveResultMapping(module, cpu_executable->buffer_assignment()));
TF_ASSIGN_OR_RETURN(
std::optional<size_t> temp_allocation_index,
ResolveTempAllocationIndex(cpu_executable->buffer_assignment()));
const auto& buffer_assignment = cpu_executable->buffer_assignment();
size_t num_allocations = buffer_assignment.Allocations().size();
return absl::WrapUnique(new NanoRtExecutable(
std::move(executable), std::move(thread_pool), num_allocations,
std::move(argument_to_allocation_index),
std::move(result_to_allocation_index), temp_allocation_index));
} }
NanoRtExecutable::NanoRtExecutable( NanoRtExecutable::NanoRtExecutable(
std::unique_ptr<Executable> executable, std::unique_ptr<Executable> executable,
std::shared_ptr<tsl::thread::ThreadPool> thread_pool) std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
size_t num_allocations, std::vector<size_t> argument_to_allocation_index,
std::vector<size_t> result_to_allocation_index,
std::optional<size_t> temp_allocation_index)
: executable_(std::move(executable)), : executable_(std::move(executable)),
thread_pool_(std::move(thread_pool)) {} thread_pool_(std::move(thread_pool)),
num_allocations_(num_allocations),
argument_to_allocation_index_(std::move(argument_to_allocation_index)),
result_to_allocation_index_(std::move(result_to_allocation_index)),
temp_allocation_index_(temp_allocation_index) {}
static se::DeviceMemoryBase ToDeviceMemory( static se::DeviceMemoryBase ToDeviceMemory(
const NanoRtExecutable::Argument& argument) { const NanoRtExecutable::Argument& argument) {
return stream_executor::DeviceMemoryBase( return se::DeviceMemoryBase(
const_cast<void*>(reinterpret_cast<const void*>(argument.data().data())), const_cast<void*>(reinterpret_cast<const void*>(argument.data().data())),
argument.data().size()); argument.data().size());
} }
static se::DeviceMemoryBase ToDeviceMemory( static se::DeviceMemoryBase ToDeviceMemory(
const NanoRtExecutable::Result& result) { const NanoRtExecutable::Result& result) {
return stream_executor::DeviceMemoryBase( return se::DeviceMemoryBase(reinterpret_cast<void*>(result.data().data()),
reinterpret_cast<void*>(result.data().data()), result.data().size()); result.data().size());
} }
static se::DeviceMemoryBase ToDeviceMemory( static se::DeviceMemoryBase ToDeviceMemory(
const NanoRtExecutable::PreallocatedTemp& temp) { const NanoRtExecutable::PreallocatedTemp& temp) {
return stream_executor::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()), return se::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()),
temp.size()); temp.size());
} }
tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute( tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
@ -87,21 +256,38 @@ tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
auto* executable = tsl::down_cast<cpu::CpuExecutable*>(executable_.get()); auto* executable = tsl::down_cast<cpu::CpuExecutable*>(executable_.get());
// Convert arguments, results, and temp to device memory. size_t num_arguments = argument_to_allocation_index_.size();
absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffer_device_mem; size_t num_results = result_to_allocation_index_.size();
buffer_device_mem.reserve(arguments.size() + results.size() + 1);
for (const Result& result : results) { if (ABSL_PREDICT_FALSE(arguments.size() != num_arguments)) {
buffer_device_mem.emplace_back(ToDeviceMemory(result)); return InvalidArgument("Expected %d arguments, got %d", num_arguments,
arguments.size());
} }
for (const Argument& argument : arguments) {
buffer_device_mem.emplace_back(ToDeviceMemory(argument)); if (ABSL_PREDICT_FALSE(results.size() != num_results)) {
return InvalidArgument("Expected %d results, got %d", num_results,
results.size());
} }
buffer_device_mem.emplace_back(ToDeviceMemory(temp));
// Prepare buffer allocations for arguments, results, and temp. // Prepare buffer allocations for arguments, results, and temp.
cpu::BufferAllocations allocations(buffer_device_mem); absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffers(num_allocations_);
for (size_t i = 0; i < num_arguments; ++i) {
buffers[argument_to_allocation_index_[i]] =
MaybeOwningDeviceMemory(ToDeviceMemory(arguments[i]));
}
for (size_t i = 0; i < num_results; ++i) {
buffers[result_to_allocation_index_[i]] =
MaybeOwningDeviceMemory(ToDeviceMemory(results[i]));
}
if (temp_allocation_index_) {
buffers[*temp_allocation_index_] =
MaybeOwningDeviceMemory(ToDeviceMemory(temp));
}
cpu::BufferAllocations allocations(buffers);
cpu::Thunk::ExecuteParams execute_params = { cpu::Thunk::ExecuteParams execute_params = {
&executable->function_registry(), &executable->function_registry(),
&allocations, &allocations,

View File

@ -19,6 +19,8 @@ limitations under the License.
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <optional>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/types/span.h" #include "absl/types/span.h"
@ -83,10 +85,24 @@ class NanoRtExecutable {
private: private:
NanoRtExecutable(std::unique_ptr<Executable> executable, NanoRtExecutable(std::unique_ptr<Executable> executable,
std::shared_ptr<tsl::thread::ThreadPool> thread_pool); std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
size_t num_allocations,
std::vector<size_t> argument_to_allocation_index,
std::vector<size_t> result_to_allocation_index,
std::optional<size_t> temp_allocation_index);
std::unique_ptr<Executable> executable_; std::unique_ptr<Executable> executable_;
std::shared_ptr<tsl::thread::ThreadPool> thread_pool_; std::shared_ptr<tsl::thread::ThreadPool> thread_pool_;
size_t num_allocations_;
// A mapping from the argument/result index to the index of the corresponding
// allocation (defined by the executable's buffer assignment).
std::vector<size_t> argument_to_allocation_index_;
std::vector<size_t> result_to_allocation_index_;
// Index of the temp allocation.
std::optional<size_t> temp_allocation_index_;
}; };
template <typename T> template <typename T>