From 39fb1ff1a8ebc47710e5c20917ecd793ac47dfff Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 20 Nov 2024 20:31:23 -0800 Subject: [PATCH] [xla:cpu] Resolve constant buffers PiperOrigin-RevId: 698625663 --- .../backends/cpu/nanort/nanort_client_test.cc | 31 +++++++++++++++++++ .../backends/cpu/nanort/nanort_executable.cc | 10 ++++++ 2 files changed, 41 insertions(+) diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc index 1b817869910..94586aec454 100644 --- a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc +++ b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc @@ -119,6 +119,37 @@ TEST(NanoRtClientTest, CompileAndRunTupledComputation) { EXPECT_EQ(r1_value, 6.0f); } +TEST(NanoRtClientTest, CompileAndRunConstantComputation) { + std::string_view hlo = R"( + HloModule cst + + ENTRY e { + ROOT cst = f32[] constant(42.0) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + XlaComputation computation(module->ToProto()); + + NanoRtClient client; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + client.Compile(computation)); + + // Storage for executable results. + alignas(32) float r0_value = 0.0f; + + // Prepare executable parameters, results and temp storage. + Arguments arguments; + Results results = {{&r0_value, 1}}; + NanoRtExecutable::PreallocatedTemp temp = {}; + + auto event = executable->Execute(arguments, results, temp); + tsl::BlockUntilReady(event); + + ASSERT_TRUE(event.IsConcrete()); + EXPECT_EQ(r0_value, 42.0f); +} + //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc index fb4b3ceeb77..123de632081 100644 --- a/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc +++ b/third_party/xla/xla/backends/cpu/nanort/nanort_executable.cc @@ -287,6 +287,16 @@ tsl::AsyncValueRef NanoRtExecutable::Execute( MaybeOwningDeviceMemory(ToDeviceMemory(temp)); } + for (const auto& constant : executable->constants()) { + // Constants are re-indexed by the buffer allocation index at CpuExecutable + // construction time, and `executable->constants()` actually returns the + // vector of buffer allocations, and only allocations corresponding to + // constants have a valid index. + if (constant.index >= 0) { + buffers[constant.index] = constant.AsDeviceMemoryBase(); + } + } + cpu::BufferAllocations allocations(buffers); cpu::Thunk::ExecuteParams execute_params = { &executable->function_registry(),