[xla:cpu] Resolve constant buffers
Some checks are pending
ARM CI / build (3.10) (push) Waiting to run
Creates a GitHub Issue when a PR Rolled back via Commit to Master / create-issue-on-pr-rollback (push) Waiting to run
Scorecards supply-chain security / Scorecards analysis (push) Waiting to run

PiperOrigin-RevId: 698625663
This commit is contained in:
Eugene Zhulenev 2024-11-20 20:31:23 -08:00 committed by TensorFlower Gardener
parent 6fb4e335c8
commit 39fb1ff1a8
2 changed files with 41 additions and 0 deletions

View File

@ -119,6 +119,37 @@ TEST(NanoRtClientTest, CompileAndRunTupledComputation) {
EXPECT_EQ(r1_value, 6.0f);
}
TEST(NanoRtClientTest, CompileAndRunConstantComputation) {
std::string_view hlo = R"(
HloModule cst
ENTRY e {
ROOT cst = f32[] constant(42.0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
XlaComputation computation(module->ToProto());
NanoRtClient client;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> executable,
client.Compile(computation));
// Storage for executable results.
alignas(32) float r0_value = 0.0f;
// Prepare executable parameters, results and temp storage.
Arguments arguments;
Results results = {{&r0_value, 1}};
NanoRtExecutable::PreallocatedTemp temp = {};
auto event = executable->Execute(arguments, results, temp);
tsl::BlockUntilReady(event);
ASSERT_TRUE(event.IsConcrete());
EXPECT_EQ(r0_value, 42.0f);
}
//===----------------------------------------------------------------------===//
// Performance benchmarks below
//===----------------------------------------------------------------------===//

View File

@ -287,6 +287,16 @@ tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
MaybeOwningDeviceMemory(ToDeviceMemory(temp));
}
for (const auto& constant : executable->constants()) {
// Constants are re-indexed by the buffer allocation index at CpuExecutable
// construction time, and `executable->constants()` actually returns the
// vector of buffer allocations, and only allocations corresponding to
// constants have a valid index.
if (constant.index >= 0) {
buffers[constant.index] = constant.AsDeviceMemoryBase();
}
}
cpu::BufferAllocations allocations(buffers);
cpu::Thunk::ExecuteParams execute_params = {
&executable->function_registry(),