diff --git a/third_party/xla/xla/backends/cpu/nanort/BUILD b/third_party/xla/xla/backends/cpu/nanort/BUILD
index 3ca8cb5cada..99f6b9275fb 100644
--- a/third_party/xla/xla/backends/cpu/nanort/BUILD
+++ b/third_party/xla/xla/backends/cpu/nanort/BUILD
@@ -32,9 +32,7 @@ cc_library(
         "//xla/service:executable",
         "//xla/service:hlo_module_config",
         "//xla/service/cpu:cpu_compiler_pure",
-        "//xla/service/cpu:cpu_executable",
         "@com_google_absl//absl/status:statusor",
-        "@local_tsl//tsl/platform:casts",
         "@local_tsl//tsl/platform:env",
         "@local_tsl//tsl/platform:logging",
         "@local_tsl//tsl/platform:statusor",
@@ -53,6 +51,8 @@ xla_cc_test(
         "//xla:xla_data_proto_cc",
         "//xla/hlo/builder:xla_builder",
         "//xla/hlo/builder:xla_computation",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/parser:hlo_parser",
         "//xla/pjrt:pjrt_client",
         "//xla/pjrt:pjrt_executable",
         "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
@@ -75,22 +75,33 @@ cc_library(
         "//xla/backends/cpu/nanort:nanort_users",
     ]),
     deps = [
+        "//xla:shape_util",
         "//xla:util",
         "//xla/backends/cpu/runtime:buffer_allocations",
         "//xla/backends/cpu/runtime:thunk",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:buffer_assignment",
+        "//xla/service:computation_layout",
         "//xla/service:executable",
+        "//xla/service:hlo_value",
         "//xla/service:maybe_owning_device_memory",
         "//xla/service/cpu:cpu_executable",
         "//xla/stream_executor:device_memory",
         "//xla/tsl/concurrency:async_value",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@local_tsl//tsl/platform:casts",
         "@local_tsl//tsl/platform:env",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:logging",
+        "@local_tsl//tsl/platform:statusor",
         "@local_tsl//tsl/profiler/lib:traceme",
         "@local_tsl//tsl/profiler/lib:traceme_encode",
     ],
diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_client.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_client.cc
index b0685f0f119..d318f06de6a 100644
--- a/third_party/xla/xla/backends/cpu/nanort/nanort_client.cc
+++ b/third_party/xla/xla/backends/cpu/nanort/nanort_client.cc
@@ -26,13 +26,11 @@ limitations under the License.
 #include "xla/pjrt/utils.h"
 #include "xla/service/compiler.h"
 #include "xla/service/cpu/cpu_compiler.h"
-#include "xla/service/cpu/cpu_executable.h"
 #include "xla/service/dump.h"
 #include "xla/service/executable.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/shape.h"
 #include "xla/util.h"
-#include "tsl/platform/casts.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/statusor.h"
@@ -85,13 +83,6 @@ absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtClient::Compile(
       compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
                           compile_options));
 
-  // Downcast executable to CpuExecutable to sanity check compilation result.
-  cpu::CpuExecutable* cpu_executable =
-      tsl::down_cast<cpu::CpuExecutable*>(executable.get());
-  if (cpu_executable == nullptr) {
-    return Internal("Failed to downcast executable to CpuExecutable");
-  }
-
   return NanoRtExecutable::Create(std::move(executable),
                                   intra_op_thread_pool_);
 }
diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc
index fd9536ca25f..1b817869910 100644
--- a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc
+++ b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <memory>
+#include <string_view>
 #include <utility>
 #include <vector>
 
 #include "absl/container/inlined_vector.h"
@@ -24,6 +25,8 @@ limitations under the License.
 #include "xla/backends/cpu/nanort/nanort_executable.h"
 #include "xla/hlo/builder/xla_builder.h"
 #include "xla/hlo/builder/xla_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
@@ -38,17 +41,99 @@ namespace xla::cpu {
 namespace {
 
-absl::StatusOr<XlaComputation> CreateAddScalarsComputation() {
+using Arguments = absl::InlinedVector<NanoRtExecutable::Argument, 2>;
+using Results = absl::InlinedVector<NanoRtExecutable::Result, 2>;
+
+TEST(NanoRtClientTest, CompileAndRunScalarComputation) {
+  constexpr std::string_view hlo = R"(
+    HloModule add
+
+    ENTRY e {
+      p0 = f32[] parameter(0)
+      p1 = f32[] parameter(1)
+      ROOT add = f32[] add(p0, p1)
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+  XlaComputation computation(module->ToProto());
+
+  NanoRtClient client;
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
+                          client.Compile(computation));
+
+  // Storage for executable parameters and results.
+  alignas(32) float p0_value = 1.0f;
+  alignas(32) float p1_value = 2.0f;
+  alignas(32) float r0_value = 0.0f;
+
+  // Prepare executable parameters, results and temp storage.
+  Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
+  Results results = {{&r0_value, 1}};
+  NanoRtExecutable::PreallocatedTemp temp = {};
+
+  auto event = executable->Execute(arguments, results, temp);
+  tsl::BlockUntilReady(event);
+
+  ASSERT_TRUE(event.IsConcrete());
+  EXPECT_EQ(r0_value, 3.0f);
+}
+
+TEST(NanoRtClientTest, CompileAndRunTupledComputation) {
+  constexpr std::string_view hlo = R"(
+    HloModule add_and_mul
+
+    ENTRY e {
+      p = (f32[], f32[]) parameter(0)
+      p0 = f32[] get-tuple-element(p), index=0
+      p1 = f32[] get-tuple-element(p), index=1
+      add = f32[] add(p0, p1)
+      mul = f32[] multiply(p0, p1)
+      ROOT add_and_mul = (f32[], f32[]) tuple(add, mul)
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
+  XlaComputation computation(module->ToProto());
+
+  NanoRtClient client;
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
+                          client.Compile(computation));
+
+  // Storage for executable parameters and results.
+  alignas(32) float p0_value = 2.0f;
+  alignas(32) float p1_value = 3.0f;
+  alignas(32) float r0_value = 0.0f;
+  alignas(32) float r1_value = 0.0f;
+
+  // Prepare executable parameters, results and temp storage.
+  Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
+  Results results = {{&r0_value, 1}, {&r1_value, 1}};
+  NanoRtExecutable::PreallocatedTemp temp = {};
+
+  auto event = executable->Execute(arguments, results, temp);
+  tsl::BlockUntilReady(event);
+
+  ASSERT_TRUE(event.IsConcrete());
+  EXPECT_EQ(r0_value, 5.0f);
+  EXPECT_EQ(r1_value, 6.0f);
+}
+
+//===----------------------------------------------------------------------===//
+// Performance benchmarks below
+//===----------------------------------------------------------------------===//
+
+static absl::StatusOr<XlaComputation> CreateAddScalarsComputation() {
   XlaBuilder b("add");
 
   auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
   auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "p1");
-  Add(Add(p0, p1), Add(p0, p1));
+  Add(p0, p1);
 
   return b.Build();
 }
 
-absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
+static absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
   XlaBuilder b("fib");
 
   auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p0");
@@ -64,38 +149,6 @@ absl::StatusOr<XlaComputation> CreateFibonacciComputation() {
   return b.Build();
 }
 
-TEST(NanoRtClientTest, CompileAndRun) {
-  NanoRtClient client;
-
-  TF_ASSERT_OK_AND_ASSIGN(auto computation, CreateAddScalarsComputation());
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
-                          client.Compile(computation));
-
-  // Storage for executable parameters and results.
-  alignas(32) float p0_value = 1.0f;
-  alignas(32) float p1_value = 2.0f;
-  alignas(32) float result = 0.0f;
-
-  // Prepare executable parameters, results and temp storage.
-  NanoRtExecutable::Argument p0(&p0_value, 1);
-  NanoRtExecutable::Argument p1(&p1_value, 1);
-  NanoRtExecutable::Result r0(&result, 1);
-  NanoRtExecutable::PreallocatedTemp temp = {};
-
-  std::vector<NanoRtExecutable::Argument> arguments = {p0, p1};
-  std::vector<NanoRtExecutable::Result> results = {r0};
-
-  auto event = executable->Execute(arguments, results, temp);
-  tsl::BlockUntilReady(event);
-
-  ASSERT_TRUE(event.IsConcrete());
-  EXPECT_EQ(result, 6.0f);
-}
-
-//===----------------------------------------------------------------------===//
-// Performance benchmarks below
-//===----------------------------------------------------------------------===//
-
 static void BM_NanoRtAddScalars(benchmark::State& state) {
   NanoRtClient client;
 
@@ -105,17 +158,13 @@ static void BM_NanoRtAddScalars(benchmark::State& state) {
   // Storage for executable arguments and results.
   alignas(32) float p0_value = 1.0f;
   alignas(32) float p1_value = 2.0f;
-  alignas(32) float result = 0.0f;
+  alignas(32) float r0_value = 0.0f;
 
   for (auto _ : state) {
-    NanoRtExecutable::Argument p0(&p0_value, 1);
-    NanoRtExecutable::Argument p1(&p1_value, 1);
-    NanoRtExecutable::Result r0(&result, 1);
+    Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
+    Results results = {{&r0_value, 1}};
     NanoRtExecutable::PreallocatedTemp temp = {};
 
-    absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
-    absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
-
     auto event = (*executable)->Execute(arguments, results, temp);
     tsl::BlockUntilReady(event);
   }
@@ -132,17 +181,13 @@ static void BM_NanoRtFibonacci(benchmark::State& state) {
   // Storage for executable arguments and results.
   alignas(32) float p0_value = 1.0f;
   alignas(32) float p1_value = 2.0f;
-  alignas(32) float result = 0.0f;
+  alignas(32) float r0_value = 0.0f;
 
   for (auto _ : state) {
-    NanoRtExecutable::Argument p0(&p0_value, 1);
-    NanoRtExecutable::Argument p1(&p1_value, 1);
-    NanoRtExecutable::Result r0(&result, 1);
+    Arguments arguments = {{&p0_value, 1}, {&p1_value, 1}};
+    Results results = {{&r0_value, 1}};
     NanoRtExecutable::PreallocatedTemp temp = {};
 
-    absl::InlinedVector<NanoRtExecutable::Argument, 2> arguments = {p0, p1};
-    absl::InlinedVector<NanoRtExecutable::Result, 1> results = {r0};
-
     auto event = (*executable)->Execute(arguments, results, temp);
     tsl::BlockUntilReady(event);
   }
diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc
index 15928261963..fb4b3ceeb77 100644
--- a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc
+++ b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc
@@ -15,22 +15,37 @@ limitations under the License.
 
 #include "xla/backends/cpu/nanort/nanort_executable.h"
 
+#include <cstddef>
 #include <memory>
+#include <optional>
 #include <utility>
+#include <vector>
 
+#include "absl/base/optimization.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "xla/backends/cpu/runtime/buffer_allocations.h"
 #include "xla/backends/cpu/runtime/thunk.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/buffer_assignment.h"
+#include "xla/service/computation_layout.h"
 #include "xla/service/cpu/cpu_executable.h"
 #include "xla/service/executable.h"
+#include "xla/service/hlo_value.h"
 #include "xla/service/maybe_owning_device_memory.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/util.h"
 #include "tsl/platform/casts.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
 #include "tsl/platform/threadpool.h"
 #include "tsl/profiler/lib/traceme.h"
 #include "tsl/profiler/lib/traceme_encode.h"
@@ -40,41 +55,195 @@ namespace xla::cpu {
 
 using ::tsl::profiler::TraceMe;
 using ::tsl::profiler::TraceMeEncode;
 
+using ArgumentIndex = std::pair<size_t, ShapeIndex>;
+
+// Resolves the mapping from argument index to allocation index.
+static absl::StatusOr<std::vector<size_t>> ResolveArgumentsMapping(
+    const HloModule& module, const BufferAssignment& buffer_assignment) {
+  const ComputationLayout& entry_layout = module.entry_computation_layout();
+
+  VLOG(3) << "Resolve executable arguments mapping:";
+
+  // Mapping from argument index to flattened executable argument index.
+  absl::flat_hash_map<ArgumentIndex, size_t> executable_arg_index;
+  for (size_t i = 0; i < entry_layout.parameter_count(); ++i) {
+    ShapeUtil::ForEachLeafShape(
+        entry_layout.parameter_shape(i),
+        [&](const Shape&, const ShapeIndex& index) {
+          size_t arg_index = executable_arg_index.size();
+          executable_arg_index[ArgumentIndex{i, index}] = arg_index;
+        });
+  }
+
+  std::vector<size_t> argument_to_allocation_index(executable_arg_index.size());
+  for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
+    if (allocation.is_entry_computation_parameter()) {
+      ArgumentIndex idx{allocation.parameter_number(),
+                        allocation.param_shape_index()};
+
+      // Skip buffer allocations assigned to non-leaf parameters (tuples).
+      auto arg_idx = executable_arg_index.find(idx);
+      if (arg_idx == executable_arg_index.end()) continue;
+
+      VLOG(3) << absl::StreamFormat(
+          " - parameter %d at shape index %s:"
+          " argument index = %d allocation index = %d",
+          allocation.parameter_number(),
+          allocation.param_shape_index().ToString(), arg_idx->second,
+          allocation.index());
+
+      argument_to_allocation_index[arg_idx->second] = allocation.index();
+    }
+  }
+
+  return argument_to_allocation_index;
+}
+
+// Resolves the mapping from result index to allocation index.
+static absl::StatusOr<std::vector<size_t>> ResolveResultMapping(
+    const HloModule& module, const BufferAssignment& buffer_assignment) {
+  const ComputationLayout& entry_layout = module.entry_computation_layout();
+
+  VLOG(3) << "Resolve executable results mapping:";
+
+  // Mapping from result index to flattened executable result index.
+  absl::flat_hash_map<ShapeIndex, size_t> executable_res_index;
+  ShapeUtil::ForEachLeafShape(entry_layout.result_shape(),
+                              [&](const Shape&, const ShapeIndex& index) {
+                                size_t res_index = executable_res_index.size();
+                                executable_res_index[index] = res_index;
+                              });
+
+  const InstructionValueSet& root_value_set =
+      buffer_assignment.dataflow_analysis().GetInstructionValueSet(
+          module.entry_computation()->root_instruction());
+
+  std::vector<size_t> result_to_allocation_index(executable_res_index.size());
+
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachLeafShapeWithStatus(
+      entry_layout.result_shape(),
+      [&](const Shape&, const ShapeIndex& index) -> absl::Status {
+        // Skip buffer allocations assigned to non-leaf results (tuples).
+        auto res_idx = executable_res_index.find(index);
+        if (res_idx == executable_res_index.end()) return absl::OkStatus();
+
+        const HloValueSet& sources = root_value_set.element(index);
+
+        if (sources.values().size() != 1) {
+          return Internal(
+              "Expected a single value for result at shape index %s",
+              index.ToString());
+        }
+
+        const HloValue* value = sources.values().front();
+        TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+                            buffer_assignment.GetUniqueSlice(
+                                value->instruction(), value->index()));
+
+        DCHECK_EQ(slice.size(), slice.allocation()->size())
+            << "Result slice size must match result allocation size";
+
+        VLOG(3) << absl::StreamFormat(
+            " - result at shape index %s:"
+            " result index = %d allocation index = %d",
+            index.ToString(), res_idx->second, slice.index());
+
+        result_to_allocation_index[res_idx->second] = slice.index();
+        return absl::OkStatus();
+      }));
+
+  return result_to_allocation_index;
+}
+
+static absl::StatusOr<std::optional<size_t>> ResolveTempAllocationIndex(
+    const BufferAssignment& buffer_assignment) {
+  VLOG(3) << "Resolve temp allocation index:";
+
+  std::optional<size_t> temp_allocation_index;
+  for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
+    if (allocation.IsPreallocatedTempBuffer()) {
+      if (temp_allocation_index.has_value()) {
+        return Internal("Multiple temp buffer allocations found");
+      }
+      VLOG(3) << " - temp buffer allocation index = " << allocation.index();
+      temp_allocation_index = allocation.index();
+    }
+  }
+
+  if (!temp_allocation_index.has_value()) {
+    VLOG(3) << " - no temp buffer allocation found";
+  }
+
+  return temp_allocation_index;
+}
+
 absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtExecutable::Create(
     std::unique_ptr<Executable> executable,
     std::shared_ptr<tsl::thread::ThreadPool> thread_pool) {
+  const HloModule& module = executable->module();
+
+  VLOG(3) << "Create NanoRtExecutable: name = " << module.name();
+
+  // NanoRtExecutable requires a CPU executable with thunks.
   auto* cpu_executable = tsl::down_cast<cpu::CpuExecutable*>(executable.get());
+  if (cpu_executable == nullptr) {
+    return Internal("NanoRtExecutable requires CPU executable");
+  }
   if (!cpu_executable->has_thunks()) {
     return Internal("NanoRtExecutable requires CPU executable to use thunks");
   }
 
-  return absl::WrapUnique(
-      new NanoRtExecutable(std::move(executable), std::move(thread_pool)));
+  // Mappings from argument/result index to buffer allocation index.
+  TF_ASSIGN_OR_RETURN(
+      std::vector<size_t> argument_to_allocation_index,
+      ResolveArgumentsMapping(module, cpu_executable->buffer_assignment()));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<size_t> result_to_allocation_index,
+      ResolveResultMapping(module, cpu_executable->buffer_assignment()));
+
+  TF_ASSIGN_OR_RETURN(
+      std::optional<size_t> temp_allocation_index,
+      ResolveTempAllocationIndex(cpu_executable->buffer_assignment()));
+
+  const auto& buffer_assignment = cpu_executable->buffer_assignment();
+  size_t num_allocations = buffer_assignment.Allocations().size();
+
+  return absl::WrapUnique(new NanoRtExecutable(
+      std::move(executable), std::move(thread_pool), num_allocations,
+      std::move(argument_to_allocation_index),
+      std::move(result_to_allocation_index), temp_allocation_index));
 }
 
 NanoRtExecutable::NanoRtExecutable(
     std::unique_ptr<Executable> executable,
-    std::shared_ptr<tsl::thread::ThreadPool> thread_pool)
+    std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
+    size_t num_allocations, std::vector<size_t> argument_to_allocation_index,
+    std::vector<size_t> result_to_allocation_index,
+    std::optional<size_t> temp_allocation_index)
     : executable_(std::move(executable)),
-      thread_pool_(std::move(thread_pool)) {}
+      thread_pool_(std::move(thread_pool)),
+      num_allocations_(num_allocations),
+      argument_to_allocation_index_(std::move(argument_to_allocation_index)),
+      result_to_allocation_index_(std::move(result_to_allocation_index)),
+      temp_allocation_index_(temp_allocation_index) {}
 
 static se::DeviceMemoryBase ToDeviceMemory(
     const NanoRtExecutable::Argument& argument) {
-  return stream_executor::DeviceMemoryBase(
+  return se::DeviceMemoryBase(
       const_cast<void*>(reinterpret_cast<const void*>(argument.data().data())),
       argument.data().size());
 }
 
 static se::DeviceMemoryBase ToDeviceMemory(
     const NanoRtExecutable::Result& result) {
-  return stream_executor::DeviceMemoryBase(
-      reinterpret_cast<void*>(result.data().data()), result.data().size());
+  return se::DeviceMemoryBase(reinterpret_cast<void*>(result.data().data()),
+                              result.data().size());
 }
 
 static se::DeviceMemoryBase ToDeviceMemory(
     const NanoRtExecutable::PreallocatedTemp& temp) {
-  return stream_executor::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()),
-                                           temp.size());
+  return se::DeviceMemoryBase(reinterpret_cast<void*>(temp.data()),
+                              temp.size());
 }
 
 tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
@@ -87,21 +256,38 @@ tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
 
   auto* executable = tsl::down_cast<cpu::CpuExecutable*>(executable_.get());
 
-  // Convert arguments, results, and temp to device memory.
-  absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffer_device_mem;
-  buffer_device_mem.reserve(arguments.size() + results.size() + 1);
+  size_t num_arguments = argument_to_allocation_index_.size();
+  size_t num_results = result_to_allocation_index_.size();
 
-  for (const Result& result : results) {
-    buffer_device_mem.emplace_back(ToDeviceMemory(result));
+  if (ABSL_PREDICT_FALSE(arguments.size() != num_arguments)) {
+    return InvalidArgument("Expected %d arguments, got %d", num_arguments,
+                           arguments.size());
   }
-  for (const Argument& argument : arguments) {
-    buffer_device_mem.emplace_back(ToDeviceMemory(argument));
+
+  if (ABSL_PREDICT_FALSE(results.size() != num_results)) {
+    return InvalidArgument("Expected %d results, got %d", num_results,
+                           results.size());
   }
-  buffer_device_mem.emplace_back(ToDeviceMemory(temp));
 
   // Prepare buffer allocations for arguments, results, and temp.
-  cpu::BufferAllocations allocations(buffer_device_mem);
+  absl::InlinedVector<MaybeOwningDeviceMemory, 8> buffers(num_allocations_);
+  for (size_t i = 0; i < num_arguments; ++i) {
+    buffers[argument_to_allocation_index_[i]] =
+        MaybeOwningDeviceMemory(ToDeviceMemory(arguments[i]));
+  }
+
+  for (size_t i = 0; i < num_results; ++i) {
+    buffers[result_to_allocation_index_[i]] =
+        MaybeOwningDeviceMemory(ToDeviceMemory(results[i]));
+  }
+
+  if (temp_allocation_index_) {
+    buffers[*temp_allocation_index_] =
+        MaybeOwningDeviceMemory(ToDeviceMemory(temp));
+  }
+
+  cpu::BufferAllocations allocations(buffers);
 
   cpu::Thunk::ExecuteParams execute_params = {
       &executable->function_registry(),
       &allocations,
diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h
index d8b73998808..e39a02b3a98 100644
--- a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h
+++ b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.h
@@ -19,6 +19,8 @@ limitations under the License.
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <optional>
+#include <vector>
 
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
@@ -83,10 +85,24 @@ class NanoRtExecutable {
  private:
   NanoRtExecutable(std::unique_ptr<Executable> executable,
-                   std::shared_ptr<tsl::thread::ThreadPool> thread_pool);
+                   std::shared_ptr<tsl::thread::ThreadPool> thread_pool,
+                   size_t num_allocations,
+                   std::vector<size_t> argument_to_allocation_index,
+                   std::vector<size_t> result_to_allocation_index,
+                   std::optional<size_t> temp_allocation_index);
 
   std::unique_ptr<Executable> executable_;
   std::shared_ptr<tsl::thread::ThreadPool> thread_pool_;
+
+  size_t num_allocations_;
+
+  // A mapping from the argument/result index to the index of the corresponding
+  // allocation (defined by the executable's buffer assignment).
+  std::vector<size_t> argument_to_allocation_index_;
+  std::vector<size_t> result_to_allocation_index_;
+
+  // Index of the temp allocation.
+  std::optional<size_t> temp_allocation_index_;
 };
 
 template <typename T>