mirror of
https://github.com/tensorflow/tensorflow.git
synced 2024-11-21 21:05:19 +00:00
[xla:cpu] Resolve arguments/results/temp mapping from buffer assignment
PiperOrigin-RevId: 698610190
This commit is contained in:
parent
9a46acc707
commit
6fb4e335c8
15
third_party/xla/xla/backends/cpu/nanort/BUILD
vendored
15
third_party/xla/xla/backends/cpu/nanort/BUILD
vendored
@ -32,9 +32,7 @@ cc_library(
|
|||||||
"//xla/service:executable",
|
"//xla/service:executable",
|
||||||
"//xla/service:hlo_module_config",
|
"//xla/service:hlo_module_config",
|
||||||
"//xla/service/cpu:cpu_compiler_pure",
|
"//xla/service/cpu:cpu_compiler_pure",
|
||||||
"//xla/service/cpu:cpu_executable",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@local_tsl//tsl/platform:casts",
|
|
||||||
"@local_tsl//tsl/platform:env",
|
"@local_tsl//tsl/platform:env",
|
||||||
"@local_tsl//tsl/platform:logging",
|
"@local_tsl//tsl/platform:logging",
|
||||||
"@local_tsl//tsl/platform:statusor",
|
"@local_tsl//tsl/platform:statusor",
|
||||||
@ -53,6 +51,8 @@ xla_cc_test(
|
|||||||
"//xla:xla_data_proto_cc",
|
"//xla:xla_data_proto_cc",
|
||||||
"//xla/hlo/builder:xla_builder",
|
"//xla/hlo/builder:xla_builder",
|
||||||
"//xla/hlo/builder:xla_computation",
|
"//xla/hlo/builder:xla_computation",
|
||||||
|
"//xla/hlo/ir:hlo",
|
||||||
|
"//xla/hlo/parser:hlo_parser",
|
||||||
"//xla/pjrt:pjrt_client",
|
"//xla/pjrt:pjrt_client",
|
||||||
"//xla/pjrt:pjrt_executable",
|
"//xla/pjrt:pjrt_executable",
|
||||||
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
|
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
|
||||||
@ -75,22 +75,33 @@ cc_library(
|
|||||||
"//xla/backends/cpu/nanort:nanort_users",
|
"//xla/backends/cpu/nanort:nanort_users",
|
||||||
]),
|
]),
|
||||||
deps = [
|
deps = [
|
||||||
|
"//xla:shape_util",
|
||||||
"//xla:util",
|
"//xla:util",
|
||||||
"//xla/backends/cpu/runtime:buffer_allocations",
|
"//xla/backends/cpu/runtime:buffer_allocations",
|
||||||
"//xla/backends/cpu/runtime:thunk",
|
"//xla/backends/cpu/runtime:thunk",
|
||||||
|
"//xla/hlo/ir:hlo",
|
||||||
|
"//xla/service:buffer_assignment",
|
||||||
|
"//xla/service:computation_layout",
|
||||||
"//xla/service:executable",
|
"//xla/service:executable",
|
||||||
|
"//xla/service:hlo_value",
|
||||||
"//xla/service:maybe_owning_device_memory",
|
"//xla/service:maybe_owning_device_memory",
|
||||||
"//xla/service/cpu:cpu_executable",
|
"//xla/service/cpu:cpu_executable",
|
||||||
"//xla/stream_executor:device_memory",
|
"//xla/stream_executor:device_memory",
|
||||||
"//xla/tsl/concurrency:async_value",
|
"//xla/tsl/concurrency:async_value",
|
||||||
|
"@com_google_absl//absl/base:core_headers",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@local_tsl//tsl/platform:casts",
|
"@local_tsl//tsl/platform:casts",
|
||||||
"@local_tsl//tsl/platform:env",
|
"@local_tsl//tsl/platform:env",
|
||||||
|
"@local_tsl//tsl/platform:errors",
|
||||||
"@local_tsl//tsl/platform:logging",
|
"@local_tsl//tsl/platform:logging",
|
||||||
|
"@local_tsl//tsl/platform:statusor",
|
||||||
"@local_tsl//tsl/profiler/lib:traceme",
|
"@local_tsl//tsl/profiler/lib:traceme",
|
||||||
"@local_tsl//tsl/profiler/lib:traceme_encode",
|
"@local_tsl//tsl/profiler/lib:traceme_encode",
|
||||||
],
|
],
|
||||||
|
@ -26,13 +26,11 @@ limitations under the License.
|
|||||||
#include "xla/pjrt/utils.h"
|
#include "xla/pjrt/utils.h"
|
||||||
#include "xla/service/compiler.h"
|
#include "xla/service/compiler.h"
|
||||||
#include "xla/service/cpu/cpu_compiler.h"
|
#include "xla/service/cpu/cpu_compiler.h"
|
||||||
#include "xla/service/cpu/cpu_executable.h"
|
|
||||||
#include "xla/service/dump.h"
|
#include "xla/service/dump.h"
|
||||||
#include "xla/service/executable.h"
|
#include "xla/service/executable.h"
|
||||||
#include "xla/service/hlo_module_config.h"
|
#include "xla/service/hlo_module_config.h"
|
||||||
#include "xla/shape.h"
|
#include "xla/shape.h"
|
||||||
#include "xla/util.h"
|
#include "xla/util.h"
|
||||||
#include "tsl/platform/casts.h"
|
|
||||||
#include "tsl/platform/env.h"
|
#include "tsl/platform/env.h"
|
||||||
#include "tsl/platform/logging.h"
|
#include "tsl/platform/logging.h"
|
||||||
#include "tsl/platform/statusor.h"
|
#include "tsl/platform/statusor.h"
|
||||||
@ -85,13 +83,6 @@ absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtClient::Compile(
|
|||||||
compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
|
compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
|
||||||
compile_options));
|
compile_options));
|
||||||
|
|
||||||
// Downcast executable to CpuExecutable to sanity check compilation result.
|
|
||||||
cpu::CpuExecutable* cpu_executable =
|
|
||||||
tsl::down_cast<cpu::CpuExecutable*>(executable.get());
|
|
||||||
if (cpu_executable == nullptr) {
|
|
||||||
return Internal("Failed to downcast executable to CpuExecutable");
|
|
||||||
}
|
|
||||||
|
|
||||||
return NanoRtExecutable::Create(std::move(executable), intra_op_thread_pool_);
|
return NanoRtExecutable::Create(std::move(executable), intra_op_thread_pool_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
@ -24,6 +25,8 @@ limitations under the License.
|
|||||||
#include "xla/backends/cpu/nanort/nanort_executable.h"
|
#include "xla/backends/cpu/nanort/nanort_executable.h"
|
||||||
#include "xla/hlo/builder/xla_builder.h"
|
#include "xla/hlo/builder/xla_builder.h"
|
||||||
#include "xla/hlo/builder/xla_computation.h"
|
#include "xla/hlo/builder/xla_computation.h"
|
||||||
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
|
#include "xla/hlo/parser/hlo_parser.h"
|
||||||
#include "xla/pjrt/pjrt_client.h"
|
#include "xla/pjrt/pjrt_client.h"
|
||||||
#include "xla/pjrt/pjrt_executable.h"
|
#include "xla/pjrt/pjrt_executable.h"
|
||||||
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
|
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
|
||||||
@ -38,17 +41,99 @@ limitations under the License.
|
|||||||
namespace xla::cpu {
|
namespace xla::cpu {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
absl::StatusOr<XlaComputation> CreateAddScalarsComputation() {
|
using Arguments = absl::InlinedVector<NanoRtExecutable::Argument, 8>;
|
||||||
|
using Results = absl::InlinedVector<NanoRtExecutable::Result, 8>;
|
||||||
|
|
||||||
|
TEST(NanoRtClientTest, CompileAndRunScalarComputation) {
|
||||||
|
constexpr std::string_view hlo = R"(
|
||||||
|
HloModule add
|
||||||
|
|
||||||
|
ENTRY e {
|
||||||
|
p0 = f32[] parameter(0)
|
||||||
|
p1 = f32[] parameter(1)
|
||||||
|
ROOT add = f32[] add(p0, p1)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
|
||||||
|
XlaComputation computation(module->ToProto());
|
||||||
|
|
||||||
|
NanoRtClient client;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
|
||||||
|
client.Compile(computation));
|
||||||
|
|
||||||
|
// Storage for executable parameters and results.
|
||||||
|
alignas(32) float p0_value = 1.0f;
|
||||||
|
alignas(32) float p1_value = 2.0f;
|
||||||
|
alignas(32) float r0_value = 0.0f;
|
||||||
|
|
||||||
|
// Prepare executable parameters, results and temp storage.
|
||||||
|
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||||
|
Results results = {{&r0_value, 1}};
|
||||||
|
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||||
|
|
||||||
|
auto event = executable->Execute(arguments, results, temp);
|
||||||
|
tsl::BlockUntilReady(event);
|
||||||
|
|
||||||
|
ASSERT_TRUE(event.IsConcrete());
|
||||||
|
EXPECT_EQ(r0_value, 3.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NanoRtClientTest, CompileAndRunTupledComputation) {
|
||||||
|
constexpr std::string_view hlo = R"(
|
||||||
|
HloModule add_and_mul
|
||||||
|
|
||||||
|
ENTRY e {
|
||||||
|
p = (f32[], f32[]) parameter(0)
|
||||||
|
p0 = f32[] get-tuple-element(p), index=0
|
||||||
|
p1 = f32[] get-tuple-element(p), index=1
|
||||||
|
add = f32[] add(p0, p1)
|
||||||
|
mul = f32[] multiply(p0, p1)
|
||||||
|
ROOT add_and_mul = (f32[], f32[]) tuple(add, mul)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
|
||||||
|
XlaComputation computation(module->ToProto());
|
||||||
|
|
||||||
|
NanoRtClient client;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
|
||||||
|
client.Compile(computation));
|
||||||
|
|
||||||
|
// Storage for executable parameters and results.
|
||||||
|
alignas(32) float p0_value = 2.0f;
|
||||||
|
alignas(32) float p1_value = 3.0f;
|
||||||
|
alignas(32) float r0_value = 0.0f;
|
||||||
|
alignas(32) float r1_value = 0.0f;
|
||||||
|
|
||||||
|
// Prepare executable parameters, results and temp storage.
|
||||||
|
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||||
|
Results results = {{&r0_value, 1}, {&r1_value, 1}};
|
||||||
|
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||||
|
|
||||||
|
auto event = executable->Execute(arguments, results, temp);
|
||||||
|
tsl::BlockUntilReady(event);
|
||||||
|
|
||||||
|
ASSERT_TRUE(event.IsConcrete());
|
||||||
|
EXPECT_EQ(r0_value, 5.0f);
|
||||||
|
EXPECT_EQ(r1_value, 6.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Performance benchmarks below
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static absl::StatusOr<XlaComputation> CreateAddScalarsComputation() {
|
||||||
XlaBuilder b("add");
|
XlaBuilder b("add");
|
||||||
|
|
||||||
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
|
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
|
||||||
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "p1");
|
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "p1");
|
||||||
Add(Add(p0, p1), Add(p0, p1));
|
Add(p0, p1);
|
||||||
|
|
||||||
return b.Build();
|
return b.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
|
static absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
|
||||||
XlaBuilder b("fib");
|
XlaBuilder b("fib");
|
||||||
|
|
||||||
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
|
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
|
||||||
@ -64,38 +149,6 @@ absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
|
|||||||
return b.Build();
|
return b.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(NanoRtClientTest, CompileAndRun) {
|
|
||||||
NanoRtClient client;
|
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto computation, CreateAddScalarsComputation());
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
|
|
||||||
client.Compile(computation));
|
|
||||||
|
|
||||||
// Storage for executable parameters and results.
|
|
||||||
alignas(32) float p0_value = 1.0f;
|
|
||||||
alignas(32) float p1_value = 2.0f;
|
|
||||||
alignas(32) float result = 0.0f;
|
|
||||||
|
|
||||||
// Prepare executable parameters, results and temp storage.
|
|
||||||
NanoRtExecutable::Argument p0(&p0_value, 1);
|
|
||||||
NanoRtExecutable::Argument p1(&p1_value, 1);
|
|
||||||
NanoRtExecutable::Result r0(&result, 1);
|
|
||||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
|
||||||
|
|
||||||
std::vector<NanoRtExecutable::Argument> arguments = {p0, p1};
|
|
||||||
std::vector<NanoRtExecutable::Result> results = {r0};
|
|
||||||
|
|
||||||
auto event = executable->Execute(arguments, results, temp);
|
|
||||||
tsl::BlockUntilReady(event);
|
|
||||||
|
|
||||||
ASSERT_TRUE(event.IsConcrete());
|
|
||||||
EXPECT_EQ(result, 6.0f);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Performance benchmarks below
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
static void BM_NanoRtAddScalars(benchmark::State& state) {
|
static void BM_NanoRtAddScalars(benchmark::State& state) {
|
||||||
NanoRtClient client;
|
NanoRtClient client;
|
||||||
|
|
||||||
@ -105,17 +158,13 @@ static void BM_NanoRtAddScalars(benchmark::State& state) {
|
|||||||
// Storage for executable arguments and results.
|
// Storage for executable arguments and results.
|
||||||
alignas(32) float p0_value = 1.0f;
|
alignas(32) float p0_value = 1.0f;
|
||||||
alignas(32) float p1_value = 2.0f;
|
alignas(32) float p1_value = 2.0f;
|
||||||
alignas(32) float result = 0.0f;
|
alignas(32) float r0_value = 0.0f;
|
||||||
|
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
NanoRtExecutable::Argument p0(&p0_value, 1);
|
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||||
NanoRtExecutable::Argument p1(&p1_value, 1);
|
Results results = {{&r0_value, 1}};
|
||||||
NanoRtExecutable::Result r0(&result, 1);
|
|
||||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||||
|
|
||||||
absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
|
|
||||||
absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
|
|
||||||
|
|
||||||
auto event = (*executable)->Execute(arguments, results, temp);
|
auto event = (*executable)->Execute(arguments, results, temp);
|
||||||
tsl::BlockUntilReady(event);
|
tsl::BlockUntilReady(event);
|
||||||
}
|
}
|
||||||
@ -132,17 +181,13 @@ static void BM_NanoRtFibonacci(benchmark::State& state) {
|
|||||||
// Storage for executable arguments and results.
|
// Storage for executable arguments and results.
|
||||||
alignas(32) float p0_value = 1.0f;
|
alignas(32) float p0_value = 1.0f;
|
||||||
alignas(32) float p1_value = 2.0f;
|
alignas(32) float p1_value = 2.0f;
|
||||||
alignas(32) float result = 0.0f;
|
alignas(32) float r0_value = 0.0f;
|
||||||
|
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
NanoRtExecutable::Argument p0(&p0_value, 1);
|
Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
|
||||||
NanoRtExecutable::Argument p1(&p1_value, 1);
|
Results results = {{&r0_value, 1}};
|
||||||
NanoRtExecutable::Result r0(&result, 1);
|
|
||||||
NanoRtExecutable::PreallocatedTemp temp = {};
|
NanoRtExecutable::PreallocatedTemp temp = {};
|
||||||
|
|
||||||
absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
|
|
||||||
absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
|
|
||||||
|
|
||||||
auto event = (*executable)->Execute(arguments, results, temp);
|
auto event = (*executable)->Execute(arguments, results, temp);
|
||||||
tsl::BlockUntilReady(event);
|
tsl::BlockUntilReady(event);
|
||||||
}
|
}
|
||||||
|
@ -15,22 +15,37 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "xla/backends/cpu/nanort/nanort_executable.h"
|
#include "xla/backends/cpu/nanort/nanort_executable.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/base/optimization.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "xla/backends/cpu/runtime/buffer_allocations.h"
|
#include "xla/backends/cpu/runtime/buffer_allocations.h"
|
||||||
#include "xla/backends/cpu/runtime/thunk.h"
|
#include "xla/backends/cpu/runtime/thunk.h"
|
||||||
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
|
#include "xla/service/buffer_assignment.h"
|
||||||
|
#include "xla/service/computation_layout.h"
|
||||||
#include "xla/service/cpu/cpu_executable.h"
|
#include "xla/service/cpu/cpu_executable.h"
|
||||||
#include "xla/service/executable.h"
|
#include "xla/service/executable.h"
|
||||||
|
#include "xla/service/hlo_value.h"
|
||||||
#include "xla/service/maybe_owning_device_memory.h"
|
#include "xla/service/maybe_owning_device_memory.h"
|
||||||
|
#include "xla/shape.h"
|
||||||
|
#include "xla/shape_util.h"
|
||||||
#include "xla/stream_executor/device_memory.h"
|
#include "xla/stream_executor/device_memory.h"
|
||||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||||
#include "xla/util.h"
|
#include "xla/util.h"
|
||||||
#include "tsl/platform/casts.h"
|
#include "tsl/platform/casts.h"
|
||||||
|
#include "tsl/platform/errors.h"
|
||||||
#include "tsl/platform/logging.h"
|
#include "tsl/platform/logging.h"
|
||||||
|
#include "tsl/platform/statusor.h"
|
||||||
#include "tsl/platform/threadpool.h"
|
#include "tsl/platform/threadpool.h"
|
||||||
#include "tsl/profiler/lib/traceme.h"
|
#include "tsl/profiler/lib/traceme.h"
|
||||||
#include "tsl/profiler/lib/traceme_encode.h"
|
#include "tsl/profiler/lib/traceme_encode.h"
|
||||||
@ -40,41 +55,195 @@ namespace xla::cpu {
|
|||||||
using ::tsl::profiler::TraceMe;
|
using ::tsl::profiler::TraceMe;
|
||||||
using ::tsl::profiler::TraceMeEncode;
|
using ::tsl::profiler::TraceMeEncode;
|
||||||
|
|
||||||
|
using ArgumentIndex = std::pair<size_t, ShapeIndex>;
|
||||||
|
|
||||||
|
// Resolves the mapping from argument index to allocation index.
|
||||||
|
static absl::StatusOr<std::vector<size_t>> ResolveArgumentsMapping(
|
||||||
|
const HloModule& module, const BufferAssignment& buffer_assignment) {
|
||||||
|
const ComputationLayout& entry_layout = module.entry_computation_layout();
|
||||||
|
|
||||||
|
VLOG(3) << "Resolve executable arguments mapping:";
|
||||||
|
|
||||||
|
// Mapping from argument index to flattened executable argument index.
|
||||||
|
absl::flat_hash_map<ArgumentIndex, size_t> executable_arg_index;
|
||||||
|
for (size_t i = 0; i < entry_layout.parameter_count(); ++i) {
|
||||||
|
ShapeUtil::ForEachLeafShape(
|
||||||
|
entry_layout.parameter_shape(i),
|
||||||
|
[&](const Shape&, const ShapeIndex& index) {
|
||||||
|
size_t arg_index = executable_arg_index.size();
|
||||||
|
executable_arg_index[ArgumentIndex{i, index}] = arg_index;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> argument_to_allocation_index(executable_arg_index.size());
|
||||||
|
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
|
||||||
|
if (allocation.is_entry_computation_parameter()) {
|
||||||
|
ArgumentIndex idx{allocation.parameter_number(),
|
||||||
|
allocation.param_shape_index()};
|
||||||
|
|
||||||
|
// Skip buffer allocations assigned to non-leaf parameters (tuples).
|
||||||
|
auto arg_idx = executable_arg_index.find(idx);
|
||||||
|
if (arg_idx == executable_arg_index.end()) continue;
|
||||||
|
|
||||||
|
VLOG(3) << absl::StreamFormat(
|
||||||
|
" - parameter %d at shape index %s:"
|
||||||
|
" argument index = %d allocation index = %d",
|
||||||
|
allocation.parameter_number(),
|
||||||
|
allocation.param_shape_index().ToString(), arg_idx->second,
|
||||||
|
allocation.index());
|
||||||
|
|
||||||
|
argument_to_allocation_index[arg_idx->second] = allocation.index();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return argument_to_allocation_index;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolves the mapping from result index to allocation index.
|
||||||
|
static absl::StatusOr<std::vector<size_t>> ResolveResultMapping(
|
||||||
|
const HloModule& module, const BufferAssignment& buffer_assignment) {
|
||||||
|
const ComputationLayout& entry_layout = module.entry_computation_layout();
|
||||||
|
|
||||||
|
VLOG(3) << "Resolve executable results mapping:";
|
||||||
|
|
||||||
|
// Mapping from result index to flattened executable result index.
|
||||||
|
absl::flat_hash_map<ShapeIndex, size_t> executable_res_index;
|
||||||
|
ShapeUtil::ForEachLeafShape(entry_layout.result_shape(),
|
||||||
|
[&](const Shape&, const ShapeIndex& index) {
|
||||||
|
size_t res_index = executable_res_index.size();
|
||||||
|
executable_res_index[index] = res_index;
|
||||||
|
});
|
||||||
|
|
||||||
|
const InstructionValueSet& root_value_set =
|
||||||
|
buffer_assignment.dataflow_analysis().GetInstructionValueSet(
|
||||||
|
module.entry_computation()->root_instruction());
|
||||||
|
|
||||||
|
std::vector<size_t> result_to_allocation_index(executable_res_index.size());
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(ShapeUtil::ForEachLeafShapeWithStatus(
|
||||||
|
entry_layout.result_shape(),
|
||||||
|
[&](const Shape&, const ShapeIndex& index) -> absl::Status {
|
||||||
|
// Skip buffer allocations assigned to non-leaf results (tuples).
|
||||||
|
auto res_idx = executable_res_index.find(index);
|
||||||
|
if (res_idx == executable_res_index.end()) return absl::OkStatus();
|
||||||
|
|
||||||
|
const HloValueSet& sources = root_value_set.element(index);
|
||||||
|
|
||||||
|
if (sources.values().size() != 1) {
|
||||||
|
return Internal(
|
||||||
|
"Expected a single value for result at shape index %s",
|
||||||
|
index.ToString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const HloValue* value = sources.values().front();
|
||||||
|
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
|
||||||
|
buffer_assignment.GetUniqueSlice(
|
||||||
|
value->instruction(), value->index()));
|
||||||
|
|
||||||
|
DCHECK_EQ(slice.size(), slice.allocation()->size())
|
||||||
|
<< "Result slice size must match result allocation size";
|
||||||
|
|
||||||
|
VLOG(3) << absl::StreamFormat(
|
||||||
|
" - result at shape index %s:"
|
||||||
|
" result index = %d allocation index = %d",
|
||||||
|
index.ToString(), res_idx->second, slice.index());
|
||||||
|
|
||||||
|
result_to_allocation_index[res_idx->second] = slice.index();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}));
|
||||||
|
|
||||||
|
return result_to_allocation_index;
|
||||||
|
}
|
||||||
|
|
||||||
|
static absl::StatusOr<std::optional<size_t>> ResolveTempAllocationIndex(
|
||||||
|
const BufferAssignment& buffer_assignment) {
|
||||||
|
VLOG(3) << "Resolve temp allocation index:";
|
||||||
|
|
||||||
|
std::optional<size_t> temp_allocation_index;
|
||||||
|
for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
|
||||||
|
if (allocation.IsPreallocatedTempBuffer()) {
|
||||||
|
if (temp_allocation_index.has_value()) {
|
||||||
|
return Internal("Multiple temp buffer allocations found");
|
||||||
|
}
|
||||||
|
VLOG(3) << " - temp buffer allocation index = " << allocation.index();
|
||||||
|
temp_allocation_index = allocation.index();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!temp_allocation_index.has_value()) {
|
||||||
|
VLOG(3) << " - no temp buffer allocation found";
|
||||||
|
}
|
||||||
|
|
||||||
|
return temp_allocation_index;
|
||||||
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtExecutable::Create(
|
absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtExecutable::Create(
|
||||||
std::unique_ptr<Executable> executable,
|
std::unique_ptr<Executable> executable,
|
||||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool) {
|
std::shared_ptr<tsl::thread::ThreadPool> thread_pool) {
|
||||||
|
const HloModule& module = executable->module();
|
||||||
|
|
||||||
|
VLOG(3) << "Create NanoRtExecutable: name = " << module.name();
|
||||||
|
|
||||||
|
// NanoRtExecutable requires a CPU executable with thunks.
|
||||||
auto* cpu_executable = tsl::down_cast<cpu::CpuExecutable*>(executable.get());
|
auto* cpu_executable = tsl::down_cast<cpu::CpuExecutable*>(executable.get());
|
||||||
|
if (cpu_executable == nullptr) {
|
||||||
|
return Internal("NanoRtExecutable requires CPU executable");
|
||||||
|
}
|
||||||
if (!cpu_executable->has_thunks()) {
|
if (!cpu_executable->has_thunks()) {
|
||||||
return Internal("NanoRtExecutable requires CPU executable to use thunks");
|
return Internal("NanoRtExecutable requires CPU executable to use thunks");
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::WrapUnique(
|
// Mappings from argument/result index to buffer allocation index.
|
||||||
new NanoRtExecutable(std::move(executable), std::move(thread_pool)));
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::vector<size_t> argument_to_allocation_index,
|
||||||
|
ResolveArgumentsMapping(module, cpu_executable->buffer_assignment()));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::vector<size_t> result_to_allocation_index,
|
||||||
|
ResolveResultMapping(module, cpu_executable->buffer_assignment()));
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::optional<size_t> temp_allocation_index,
|
||||||
|
ResolveTempAllocationIndex(cpu_executable->buffer_assignment()));
|
||||||
|
|
||||||
|
const auto& buffer_assignment = cpu_executable->buffer_assignment();
|
||||||
|
size_t num_allocations = buffer_assignment.Allocations().size();
|
||||||
|
|
||||||
|
return absl::WrapUnique(new NanoRtExecutable(
|
||||||
|
std::move(executable), std::move(thread_pool), num_allocations,
|
||||||
|
std::move(argument_to_allocation_index),
|
||||||
|
std::move(result_to_allocation_index), temp_allocation_index));
|
||||||
}
|
}
|
||||||
|
|
||||||
NanoRtExecutable::NanoRtExecutable(
|
NanoRtExecutable::NanoRtExecutable(
|
||||||
std::unique_ptr<Executable> executable,
|
std::unique_ptr<Executable> executable,
|
||||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool)
|
std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
|
||||||
|
size_t num_allocations, std::vector<size_t> argument_to_allocation_index,
|
||||||
|
std::vector<size_t> result_to_allocation_index,
|
||||||
|
std::optional<size_t> temp_allocation_index)
|
||||||
: executable_(std::move(executable)),
|
: executable_(std::move(executable)),
|
||||||
thread_pool_(std::move(thread_pool)) {}
|
thread_pool_(std::move(thread_pool)),
|
||||||
|
num_allocations_(num_allocations),
|
||||||
|
argument_to_allocation_index_(std::move(argument_to_allocation_index)),
|
||||||
|
result_to_allocation_index_(std::move(result_to_allocation_index)),
|
||||||
|
temp_allocation_index_(temp_allocation_index) {}
|
||||||
|
|
||||||
static se::DeviceMemoryBase ToDeviceMemory(
|
static se::DeviceMemoryBase ToDeviceMemory(
|
||||||
const NanoRtExecutable::Argument& argument) {
|
const NanoRtExecutable::Argument& argument) {
|
||||||
return stream_executor::DeviceMemoryBase(
|
return se::DeviceMemoryBase(
|
||||||
const_cast<void*>(reinterpret_cast<const void*>(argument.data().data())),
|
const_cast<void*>(reinterpret_cast<const void*>(argument.data().data())),
|
||||||
argument.data().size());
|
argument.data().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
static se::DeviceMemoryBase ToDeviceMemory(
|
static se::DeviceMemoryBase ToDeviceMemory(
|
||||||
const NanoRtExecutable::Result& result) {
|
const NanoRtExecutable::Result& result) {
|
||||||
return stream_executor::DeviceMemoryBase(
|
return se::DeviceMemoryBase(reinterpret_cast<void*>(result.data().data()),
|
||||||
reinterpret_cast<void*>(result.data().data()), result.data().size());
|
result.data().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
static se::DeviceMemoryBase ToDeviceMemory(
|
static se::DeviceMemoryBase ToDeviceMemory(
|
||||||
const NanoRtExecutable::PreallocatedTemp& temp) {
|
const NanoRtExecutable::PreallocatedTemp& temp) {
|
||||||
return stream_executor::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()),
|
return se::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()),
|
||||||
temp.size());
|
temp.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
|
tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
|
||||||
@ -87,21 +256,38 @@ tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
|
|||||||
|
|
||||||
auto* executable = tsl::down_cast<cpu::CpuExecutable*>(executable_.get());
|
auto* executable = tsl::down_cast<cpu::CpuExecutable*>(executable_.get());
|
||||||
|
|
||||||
// Convert arguments, results, and temp to device memory.
|
size_t num_arguments = argument_to_allocation_index_.size();
|
||||||
absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffer_device_mem;
|
size_t num_results = result_to_allocation_index_.size();
|
||||||
buffer_device_mem.reserve(arguments.size() + results.size() + 1);
|
|
||||||
|
|
||||||
for (const Result& result : results) {
|
if (ABSL_PREDICT_FALSE(arguments.size() != num_arguments)) {
|
||||||
buffer_device_mem.emplace_back(ToDeviceMemory(result));
|
return InvalidArgument("Expected %d arguments, got %d", num_arguments,
|
||||||
|
arguments.size());
|
||||||
}
|
}
|
||||||
for (const Argument& argument : arguments) {
|
|
||||||
buffer_device_mem.emplace_back(ToDeviceMemory(argument));
|
if (ABSL_PREDICT_FALSE(results.size() != num_results)) {
|
||||||
|
return InvalidArgument("Expected %d results, got %d", num_results,
|
||||||
|
results.size());
|
||||||
}
|
}
|
||||||
buffer_device_mem.emplace_back(ToDeviceMemory(temp));
|
|
||||||
|
|
||||||
// Prepare buffer allocations for arguments, results, and temp.
|
// Prepare buffer allocations for arguments, results, and temp.
|
||||||
cpu::BufferAllocations allocations(buffer_device_mem);
|
absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffers(num_allocations_);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_arguments; ++i) {
|
||||||
|
buffers[argument_to_allocation_index_[i]] =
|
||||||
|
MaybeOwningDeviceMemory(ToDeviceMemory(arguments[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_results; ++i) {
|
||||||
|
buffers[result_to_allocation_index_[i]] =
|
||||||
|
MaybeOwningDeviceMemory(ToDeviceMemory(results[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (temp_allocation_index_) {
|
||||||
|
buffers[*temp_allocation_index_] =
|
||||||
|
MaybeOwningDeviceMemory(ToDeviceMemory(temp));
|
||||||
|
}
|
||||||
|
|
||||||
|
cpu::BufferAllocations allocations(buffers);
|
||||||
cpu::Thunk::ExecuteParams execute_params = {
|
cpu::Thunk::ExecuteParams execute_params = {
|
||||||
&executable->function_registry(),
|
&executable->function_registry(),
|
||||||
&allocations,
|
&allocations,
|
||||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
@ -83,10 +85,24 @@ class NanoRtExecutable {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
NanoRtExecutable(std::unique_ptr<Executable> executable,
|
NanoRtExecutable(std::unique_ptr<Executable> executable,
|
||||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool);
|
std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
|
||||||
|
size_t num_allocations,
|
||||||
|
std::vector<size_t> argument_to_allocation_index,
|
||||||
|
std::vector<size_t> result_to_allocation_index,
|
||||||
|
std::optional<size_t> temp_allocation_index);
|
||||||
|
|
||||||
std::unique_ptr<Executable> executable_;
|
std::unique_ptr<Executable> executable_;
|
||||||
std::shared_ptr<tsl::thread::ThreadPool> thread_pool_;
|
std::shared_ptr<tsl::thread::ThreadPool> thread_pool_;
|
||||||
|
|
||||||
|
size_t num_allocations_;
|
||||||
|
|
||||||
|
// A mapping from the argument/result index to the index of the corresponding
|
||||||
|
// allocation (defined by the executable's buffer assignment).
|
||||||
|
std::vector<size_t> argument_to_allocation_index_;
|
||||||
|
std::vector<size_t> result_to_allocation_index_;
|
||||||
|
|
||||||
|
// Index of the temp allocation.
|
||||||
|
std::optional<size_t> temp_allocation_index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
Loading…
Reference in New Issue
Block a user