mirror of
https://github.com/tensorflow/tensorflow.git
synced 2024-11-21 12:51:46 +00:00
[xla:cpu] Resolve arguments/results/temp mapping from buffer assignment
PiperOrigin-RevId: 698610190
This commit is contained in:
parent
9a46acc707
commit
6fb4e335c8
15
third_party/xla/xla/backends/cpu/nanort/BUILD
vendored
15
third_party/xla/xla/backends/cpu/nanort/BUILD
vendored
@ -32,9 +32,7 @@ cc_library(
|
||||
"//xla/service:executable",
|
||||
"//xla/service:hlo_module_config",
|
||||
"//xla/service/cpu:cpu_compiler_pure",
|
||||
"//xla/service/cpu:cpu_executable",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@local_tsl//tsl/platform:casts",
|
||||
"@local_tsl//tsl/platform:env",
|
||||
"@local_tsl//tsl/platform:logging",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
@ -53,6 +51,8 @@ xla_cc_test(
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/hlo/builder:xla_builder",
|
||||
"//xla/hlo/builder:xla_computation",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/parser:hlo_parser",
|
||||
"//xla/pjrt:pjrt_client",
|
||||
"//xla/pjrt:pjrt_executable",
|
||||
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
|
||||
@ -75,22 +75,33 @@ cc_library(
|
||||
"//xla/backends/cpu/nanort:nanort_users",
|
||||
]),
|
||||
deps = [
|
||||
"//xla:shape_util",
|
||||
"//xla:util",
|
||||
"//xla/backends/cpu/runtime:buffer_allocations",
|
||||
"//xla/backends/cpu/runtime:thunk",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/service:buffer_assignment",
|
||||
"//xla/service:computation_layout",
|
||||
"//xla/service:executable",
|
||||
"//xla/service:hlo_value",
|
||||
"//xla/service:maybe_owning_device_memory",
|
||||
"//xla/service/cpu:cpu_executable",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@local_tsl//tsl/platform:casts",
|
||||
"@local_tsl//tsl/platform:env",
|
||||
"@local_tsl//tsl/platform:errors",
|
||||
"@local_tsl//tsl/platform:logging",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
"@local_tsl//tsl/profiler/lib:traceme",
|
||||
"@local_tsl//tsl/profiler/lib:traceme_encode",
|
||||
],
|
||||
|
@ -26,13 +26,11 @@ limitations under the License.
|
||||
#include "xla/pjrt/utils.h"
|
||||
#include "xla/service/compiler.h"
|
||||
#include "xla/service/cpu/cpu_compiler.h"
|
||||
#include "xla/service/cpu/cpu_executable.h"
|
||||
#include "xla/service/dump.h"
|
||||
#include "xla/service/executable.h"
|
||||
#include "xla/service/hlo_module_config.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/util.h"
|
||||
#include "tsl/platform/casts.h"
|
||||
#include "tsl/platform/env.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
@ -85,13 +83,6 @@ absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtClient::Compile(
|
||||
compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
|
||||
compile_options));
|
||||
|
||||
// Downcast executable to CpuExecutable to sanity check compilation result.
|
||||
cpu::CpuExecutable* cpu_executable =
|
||||
tsl::down_cast<cpu::CpuExecutable*>(executable.get());
|
||||
if (cpu_executable == nullptr) {
|
||||
return Internal("Failed to downcast executable to CpuExecutable");
|
||||
}
|
||||
|
||||
return NanoRtExecutable::Create(std::move(executable), intra_op_thread_pool_);
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
@ -24,6 +25,8 @@ limitations under the License.
|
||||
#include "xla/backends/cpu/nanort/nanort_executable.h"
|
||||
#include "xla/hlo/builder/xla_builder.h"
|
||||
#include "xla/hlo/builder/xla_computation.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/hlo/parser/hlo_parser.h"
|
||||
#include "xla/pjrt/pjrt_client.h"
|
||||
#include "xla/pjrt/pjrt_executable.h"
|
||||
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
|
||||
@ -38,17 +41,99 @@ limitations under the License.
|
||||
namespace xla::cpu {
|
||||
namespace {
|
||||
|
||||
absl::StatusOr<XlaComputation> CreateAddScalarsComputation() {
|
||||
using Arguments = absl::InlinedVector<NanoRtExecutable::Argument, 8>;
|
||||
using Results = absl::InlinedVector<NanoRtExecutable::Result, 8>;
|
||||
|
||||
TEST(NanoRtClientTest, CompileAndRunScalarComputation) {
|
||||
constexpr std::string_view hlo = R"(
|
||||
HloModule add
|
||||
|
||||
ENTRY e {
|
||||
p0 = f32[] parameter(0)
|
||||
p1 = f32[] parameter(1)
|
||||
ROOT add = f32[] add(p0, p1)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
|
||||
XlaComputation computation(module->ToProto());
|
||||
|
||||
NanoRtClient client;
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
|
||||
client.Compile(computation));
|
||||
|
||||
// Storage for executable parameters and results.
|
||||
alignas(32) float p0_value = 1.0f;
|
||||
alignas(32) float p1_value = 2.0f;
|
||||
alignas(32) float r0_value = 0.0f;
|
||||
|
||||
// Prepare executable parameters, results and temp storage.
|
||||
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||
Results results = {{&r0_value, 1}};
|
||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||
|
||||
auto event = executable->Execute(arguments, results, temp);
|
||||
tsl::BlockUntilReady(event);
|
||||
|
||||
ASSERT_TRUE(event.IsConcrete());
|
||||
EXPECT_EQ(r0_value, 3.0f);
|
||||
}
|
||||
|
||||
TEST(NanoRtClientTest, CompileAndRunTupledComputation) {
|
||||
constexpr std::string_view hlo = R"(
|
||||
HloModule add_and_mul
|
||||
|
||||
ENTRY e {
|
||||
p = (f32[], f32[]) parameter(0)
|
||||
p0 = f32[] get-tuple-element(p), index=0
|
||||
p1 = f32[] get-tuple-element(p), index=1
|
||||
add = f32[] add(p0, p1)
|
||||
mul = f32[] multiply(p0, p1)
|
||||
ROOT add_and_mul = (f32[], f32[]) tuple(add, mul)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
|
||||
XlaComputation computation(module->ToProto());
|
||||
|
||||
NanoRtClient client;
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
|
||||
client.Compile(computation));
|
||||
|
||||
// Storage for executable parameters and results.
|
||||
alignas(32) float p0_value = 2.0f;
|
||||
alignas(32) float p1_value = 3.0f;
|
||||
alignas(32) float r0_value = 0.0f;
|
||||
alignas(32) float r1_value = 0.0f;
|
||||
|
||||
// Prepare executable parameters, results and temp storage.
|
||||
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||
Results results = {{&r0_value, 1}, {&r1_value, 1}};
|
||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||
|
||||
auto event = executable->Execute(arguments, results, temp);
|
||||
tsl::BlockUntilReady(event);
|
||||
|
||||
ASSERT_TRUE(event.IsConcrete());
|
||||
EXPECT_EQ(r0_value, 5.0f);
|
||||
EXPECT_EQ(r1_value, 6.0f);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Performance benchmarks below
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static absl::StatusOr<XlaComputation> CreateAddScalarsComputation() {
|
||||
XlaBuilder b("add");
|
||||
|
||||
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
|
||||
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "p1");
|
||||
Add(Add(p0, p1), Add(p0, p1));
|
||||
Add(p0, p1);
|
||||
|
||||
return b.Build();
|
||||
}
|
||||
|
||||
absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
|
||||
static absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
|
||||
XlaBuilder b("fib");
|
||||
|
||||
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
|
||||
@ -64,38 +149,6 @@ absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
|
||||
return b.Build();
|
||||
}
|
||||
|
||||
TEST(NanoRtClientTest, CompileAndRun) {
|
||||
NanoRtClient client;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto computation, CreateAddScalarsComputation());
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
|
||||
client.Compile(computation));
|
||||
|
||||
// Storage for executable parameters and results.
|
||||
alignas(32) float p0_value = 1.0f;
|
||||
alignas(32) float p1_value = 2.0f;
|
||||
alignas(32) float result = 0.0f;
|
||||
|
||||
// Prepare executable parameters, results and temp storage.
|
||||
NanoRtExecutable::Argument p0(&p0_value, 1);
|
||||
NanoRtExecutable::Argument p1(&p1_value, 1);
|
||||
NanoRtExecutable::Result r0(&result, 1);
|
||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||
|
||||
std::vector<NanoRtExecutable::Argument> arguments = {p0, p1};
|
||||
std::vector<NanoRtExecutable::Result> results = {r0};
|
||||
|
||||
auto event = executable->Execute(arguments, results, temp);
|
||||
tsl::BlockUntilReady(event);
|
||||
|
||||
ASSERT_TRUE(event.IsConcrete());
|
||||
EXPECT_EQ(result, 6.0f);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Performance benchmarks below
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void BM_NanoRtAddScalars(benchmark::State& state) {
|
||||
NanoRtClient client;
|
||||
|
||||
@ -105,17 +158,13 @@ static void BM_NanoRtAddScalars(benchmark::State& state) {
|
||||
// Storage for executable arguments and results.
|
||||
alignas(32) float p0_value = 1.0f;
|
||||
alignas(32) float p1_value = 2.0f;
|
||||
alignas(32) float result = 0.0f;
|
||||
alignas(32) float r0_value = 0.0f;
|
||||
|
||||
for (auto _ : state) {
|
||||
NanoRtExecutable::Argument p0(&p0_value, 1);
|
||||
NanoRtExecutable::Argument p1(&p1_value, 1);
|
||||
NanoRtExecutable::Result r0(&result, 1);
|
||||
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||
Results results = {{&r0_value, 1}};
|
||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||
|
||||
absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
|
||||
absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
|
||||
|
||||
auto event = (*executable)->Execute(arguments, results, temp);
|
||||
tsl::BlockUntilReady(event);
|
||||
}
|
||||
@ -132,17 +181,13 @@ static void BM_NanoRtFibonacci(benchmark::State& state) {
|
||||
// Storage for executable arguments and results.
|
||||
alignas(32) float p0_value = 1.0f;
|
||||
alignas(32) float p1_value = 2.0f;
|
||||
alignas(32) float result = 0.0f;
|
||||
alignas(32) float r0_value = 0.0f;
|
||||
|
||||
for (auto _ : state) {
|
||||
NanoRtExecutable::Argument p0(&p0_value, 1);
|
||||
NanoRtExecutable::Argument p1(&p1_value, 1);
|
||||
NanoRtExecutable::Result r0(&result, 1);
|
||||
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||
Results results = {{&r0_value, 1}};
|
||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||
|
||||
absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
|
||||
absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
|
||||
|
||||
auto event = (*executable)->Execute(arguments, results, temp);
|
||||
tsl::BlockUntilReady(event);
|
||||
}
|
||||
|
@ -15,22 +15,37 @@ limitations under the License.
|
||||
|
||||
#include "xla/backends/cpu/nanort/nanort_executable.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/runtime/buffer_allocations.h"
|
||||
#include "xla/backends/cpu/runtime/thunk.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/computation_layout.h"
|
||||
#include "xla/service/cpu/cpu_executable.h"
|
||||
#include "xla/service/executable.h"
|
||||
#include "xla/service/hlo_value.h"
|
||||
#include "xla/service/maybe_owning_device_memory.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
#include "xla/util.h"
|
||||
#include "tsl/platform/casts.h"
|
||||
#include "tsl/platform/errors.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
#include "tsl/platform/threadpool.h"
|
||||
#include "tsl/profiler/lib/traceme.h"
|
||||
#include "tsl/profiler/lib/traceme_encode.h"
|
||||
@ -40,41 +55,195 @@ namespace xla::cpu {
|
||||
using ::tsl::profiler::TraceMe;
|
||||
using ::tsl::profiler::TraceMeEncode;
|
||||
|
||||
using ArgumentIndex = std::pair<size_t, ShapeIndex>;
|
||||
|
||||
// Resolves the mapping from argument index to allocation index.
|
||||
static absl::StatusOr<std::vector<size_t>> ResolveArgumentsMapping(
|
||||
const HloModule& module, const BufferAssignment& buffer_assignment) {
|
||||
const ComputationLayout& entry_layout = module.entry_computation_layout();
|
||||
|
||||
VLOG(3) << "Resolve executable arguments mapping:";
|
||||
|
||||
// Mapping from argument index to flattened executable argument index.
|
||||
absl::flat_hash_map<ArgumentIndex, size_t> executable_arg_index;
|
||||
for (size_t i = 0; i < entry_layout.parameter_count(); ++i) {
|
||||
ShapeUtil::ForEachLeafShape(
|
||||
entry_layout.parameter_shape(i),
|
||||
[&](const Shape&, const ShapeIndex& index) {
|
||||
size_t arg_index = executable_arg_index.size();
|
||||
executable_arg_index[ArgumentIndex{i, index}] = arg_index;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<size_t> argument_to_allocation_index(executable_arg_index.size());
|
||||
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
ArgumentIndex idx{allocation.parameter_number(),
|
||||
allocation.param_shape_index()};
|
||||
|
||||
// Skip buffer allocations assigned to non-leaf parameters (tuples).
|
||||
auto arg_idx = executable_arg_index.find(idx);
|
||||
if (arg_idx == executable_arg_index.end()) continue;
|
||||
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
" - parameter %d at shape index %s:"
|
||||
" argument index = %d allocation index = %d",
|
||||
allocation.parameter_number(),
|
||||
allocation.param_shape_index().ToString(), arg_idx->second,
|
||||
allocation.index());
|
||||
|
||||
argument_to_allocation_index[arg_idx->second] = allocation.index();
|
||||
}
|
||||
}
|
||||
|
||||
return argument_to_allocation_index;
|
||||
}
|
||||
|
||||
// Resolves the mapping from result index to allocation index.
|
||||
static absl::StatusOr<std::vector<size_t>> ResolveResultMapping(
|
||||
const HloModule& module, const BufferAssignment& buffer_assignment) {
|
||||
const ComputationLayout& entry_layout = module.entry_computation_layout();
|
||||
|
||||
VLOG(3) << "Resolve executable results mapping:";
|
||||
|
||||
// Mapping from result index to flattened executable result index.
|
||||
absl::flat_hash_map<ShapeIndex, size_t> executable_res_index;
|
||||
ShapeUtil::ForEachLeafShape(entry_layout.result_shape(),
|
||||
[&](const Shape&, const ShapeIndex& index) {
|
||||
size_t res_index = executable_res_index.size();
|
||||
executable_res_index[index] = res_index;
|
||||
});
|
||||
|
||||
const InstructionValueSet& root_value_set =
|
||||
buffer_assignment.dataflow_analysis().GetInstructionValueSet(
|
||||
module.entry_computation()->root_instruction());
|
||||
|
||||
std::vector<size_t> result_to_allocation_index(executable_res_index.size());
|
||||
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ForEachLeafShapeWithStatus(
|
||||
entry_layout.result_shape(),
|
||||
[&](const Shape&, const ShapeIndex& index) -> absl::Status {
|
||||
// Skip buffer allocations assigned to non-leaf results (tuples).
|
||||
auto res_idx = executable_res_index.find(index);
|
||||
if (res_idx == executable_res_index.end()) return absl::OkStatus();
|
||||
|
||||
const HloValueSet& sources = root_value_set.element(index);
|
||||
|
||||
if (sources.values().size() != 1) {
|
||||
return Internal(
|
||||
"Expected a single value for result at shape index %s",
|
||||
index.ToString());
|
||||
}
|
||||
|
||||
const HloValue* value = sources.values().front();
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
|
||||
buffer_assignment.GetUniqueSlice(
|
||||
value->instruction(), value->index()));
|
||||
|
||||
DCHECK_EQ(slice.size(), slice.allocation()->size())
|
||||
<< "Result slice size must match result allocation size";
|
||||
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
" - result at shape index %s:"
|
||||
" result index = %d allocation index = %d",
|
||||
index.ToString(), res_idx->second, slice.index());
|
||||
|
||||
result_to_allocation_index[res_idx->second] = slice.index();
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
|
||||
return result_to_allocation_index;
|
||||
}
|
||||
|
||||
static absl::StatusOr<std::optional<size_t>> ResolveTempAllocationIndex(
|
||||
const BufferAssignment& buffer_assignment) {
|
||||
VLOG(3) << "Resolve temp allocation index:";
|
||||
|
||||
std::optional<size_t> temp_allocation_index;
|
||||
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
|
||||
if (allocation.IsPreallocatedTempBuffer()) {
|
||||
if (temp_allocation_index.has_value()) {
|
||||
return Internal("Multiple temp buffer allocations found");
|
||||
}
|
||||
VLOG(3) << " - temp buffer allocation index = " << allocation.index();
|
||||
temp_allocation_index = allocation.index();
|
||||
}
|
||||
}
|
||||
|
||||
if (!temp_allocation_index.has_value()) {
|
||||
VLOG(3) << " - no temp buffer allocation found";
|
||||
}
|
||||
|
||||
return temp_allocation_index;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtExecutable::Create(
|
||||
std::unique_ptr<Executable> executable,
|
||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool) {
|
||||
const HloModule& module = executable->module();
|
||||
|
||||
VLOG(3) << "Create NanoRtExecutable: name = " << module.name();
|
||||
|
||||
// NanoRtExecutable requires a CPU executable with thunks.
|
||||
auto* cpu_executable = tsl::down_cast<cpu::CpuExecutable*>(executable.get());
|
||||
if (cpu_executable == nullptr) {
|
||||
return Internal("NanoRtExecutable requires CPU executable");
|
||||
}
|
||||
if (!cpu_executable->has_thunks()) {
|
||||
return Internal("NanoRtExecutable requires CPU executable to use thunks");
|
||||
}
|
||||
|
||||
return absl::WrapUnique(
|
||||
new NanoRtExecutable(std::move(executable), std::move(thread_pool)));
|
||||
// Mappings from argument/result index to buffer allocation index.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<size_t> argument_to_allocation_index,
|
||||
ResolveArgumentsMapping(module, cpu_executable->buffer_assignment()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<size_t> result_to_allocation_index,
|
||||
ResolveResultMapping(module, cpu_executable->buffer_assignment()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::optional<size_t> temp_allocation_index,
|
||||
ResolveTempAllocationIndex(cpu_executable->buffer_assignment()));
|
||||
|
||||
const auto& buffer_assignment = cpu_executable->buffer_assignment();
|
||||
size_t num_allocations = buffer_assignment.Allocations().size();
|
||||
|
||||
return absl::WrapUnique(new NanoRtExecutable(
|
||||
std::move(executable), std::move(thread_pool), num_allocations,
|
||||
std::move(argument_to_allocation_index),
|
||||
std::move(result_to_allocation_index), temp_allocation_index));
|
||||
}
|
||||
|
||||
NanoRtExecutable::NanoRtExecutable(
|
||||
std::unique_ptr<Executable> executable,
|
||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool)
|
||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
|
||||
size_t num_allocations, std::vector<size_t> argument_to_allocation_index,
|
||||
std::vector<size_t> result_to_allocation_index,
|
||||
std::optional<size_t> temp_allocation_index)
|
||||
: executable_(std::move(executable)),
|
||||
thread_pool_(std::move(thread_pool)) {}
|
||||
thread_pool_(std::move(thread_pool)),
|
||||
num_allocations_(num_allocations),
|
||||
argument_to_allocation_index_(std::move(argument_to_allocation_index)),
|
||||
result_to_allocation_index_(std::move(result_to_allocation_index)),
|
||||
temp_allocation_index_(temp_allocation_index) {}
|
||||
|
||||
static se::DeviceMemoryBase ToDeviceMemory(
|
||||
const NanoRtExecutable::Argument& argument) {
|
||||
return stream_executor::DeviceMemoryBase(
|
||||
return se::DeviceMemoryBase(
|
||||
const_cast<void*>(reinterpret_cast<const void*>(argument.data().data())),
|
||||
argument.data().size());
|
||||
}
|
||||
|
||||
static se::DeviceMemoryBase ToDeviceMemory(
|
||||
const NanoRtExecutable::Result& result) {
|
||||
return stream_executor::DeviceMemoryBase(
|
||||
reinterpret_cast<void*>(result.data().data()), result.data().size());
|
||||
return se::DeviceMemoryBase(reinterpret_cast<void*>(result.data().data()),
|
||||
result.data().size());
|
||||
}
|
||||
|
||||
static se::DeviceMemoryBase ToDeviceMemory(
|
||||
const NanoRtExecutable::PreallocatedTemp& temp) {
|
||||
return stream_executor::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()),
|
||||
temp.size());
|
||||
return se::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()),
|
||||
temp.size());
|
||||
}
|
||||
|
||||
tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
|
||||
@ -87,21 +256,38 @@ tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
|
||||
|
||||
auto* executable = tsl::down_cast<cpu::CpuExecutable*>(executable_.get());
|
||||
|
||||
// Convert arguments, results, and temp to device memory.
|
||||
absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffer_device_mem;
|
||||
buffer_device_mem.reserve(arguments.size() + results.size() + 1);
|
||||
size_t num_arguments = argument_to_allocation_index_.size();
|
||||
size_t num_results = result_to_allocation_index_.size();
|
||||
|
||||
for (const Result& result : results) {
|
||||
buffer_device_mem.emplace_back(ToDeviceMemory(result));
|
||||
if (ABSL_PREDICT_FALSE(arguments.size() != num_arguments)) {
|
||||
return InvalidArgument("Expected %d arguments, got %d", num_arguments,
|
||||
arguments.size());
|
||||
}
|
||||
for (const Argument& argument : arguments) {
|
||||
buffer_device_mem.emplace_back(ToDeviceMemory(argument));
|
||||
|
||||
if (ABSL_PREDICT_FALSE(results.size() != num_results)) {
|
||||
return InvalidArgument("Expected %d results, got %d", num_results,
|
||||
results.size());
|
||||
}
|
||||
buffer_device_mem.emplace_back(ToDeviceMemory(temp));
|
||||
|
||||
// Prepare buffer allocations for arguments, results, and temp.
|
||||
cpu::BufferAllocations allocations(buffer_device_mem);
|
||||
absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffers(num_allocations_);
|
||||
|
||||
for (size_t i = 0; i < num_arguments; ++i) {
|
||||
buffers[argument_to_allocation_index_[i]] =
|
||||
MaybeOwningDeviceMemory(ToDeviceMemory(arguments[i]));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_results; ++i) {
|
||||
buffers[result_to_allocation_index_[i]] =
|
||||
MaybeOwningDeviceMemory(ToDeviceMemory(results[i]));
|
||||
}
|
||||
|
||||
if (temp_allocation_index_) {
|
||||
buffers[*temp_allocation_index_] =
|
||||
MaybeOwningDeviceMemory(ToDeviceMemory(temp));
|
||||
}
|
||||
|
||||
cpu::BufferAllocations allocations(buffers);
|
||||
cpu::Thunk::ExecuteParams execute_params = {
|
||||
&executable->function_registry(),
|
||||
&allocations,
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/types/span.h"
|
||||
@ -83,10 +85,24 @@ class NanoRtExecutable {
|
||||
|
||||
private:
|
||||
NanoRtExecutable(std::unique_ptr<Executable> executable,
|
||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool);
|
||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
|
||||
size_t num_allocations,
|
||||
std::vector<size_t> argument_to_allocation_index,
|
||||
std::vector<size_t> result_to_allocation_index,
|
||||
std::optional<size_t> temp_allocation_index);
|
||||
|
||||
std::unique_ptr<Executable> executable_;
|
||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool_;
|
||||
|
||||
size_t num_allocations_;
|
||||
|
||||
// A mapping from the argument/result index to the index of the corresponding
|
||||
// allocation (defined by the executable's buffer assignment).
|
||||
std::vector<size_t> argument_to_allocation_index_;
|
||||
std::vector<size_t> result_to_allocation_index_;
|
||||
|
||||
// Index of the temp allocation.
|
||||
std::optional<size_t> temp_allocation_index_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
Loading…
Reference in New Issue
Block a user