mirror of
https://github.com/tensorflow/tensorflow.git
synced 2024-11-21 12:51:46 +00:00
[xla:cpu] Resolve constant buffers
PiperOrigin-RevId: 698625663
This commit is contained in:
parent
6fb4e335c8
commit
39fb1ff1a8
@ -119,6 +119,37 @@ TEST(NanoRtClientTest, CompileAndRunTupledComputation) {
|
||||
EXPECT_EQ(r1_value, 6.0f);
|
||||
}
|
||||
|
||||
// Compiles an HLO module whose entry computation takes no parameters and
// whose root is a scalar f32 constant, executes it through NanoRtClient, and
// checks that the constant value lands in the caller-provided result buffer.
TEST(NanoRtClientTest, CompileAndRunConstantComputation) {
  std::string_view hlo_text = R"(
HloModule cst

ENTRY e {
ROOT cst = f32[] constant(42.0)
}
)";

  // Parse the HLO text and wrap its proto in an XlaComputation.
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  XlaComputation constant_computation(hlo_module->ToProto());

  NanoRtClient nanort_client;
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<NanoRtExecutable> compiled,
                          nanort_client.Compile(constant_computation));

  // Destination for the single scalar result; starts at zero so that a
  // successful run is distinguishable from an untouched buffer.
  alignas(32) float result_value = 0.0f;

  // The computation has no parameters, so the argument list stays empty.
  // The one result is backed by `result_value`; temp storage is left empty.
  Arguments no_arguments;
  Results result_buffers = {{&result_value, 1}};
  NanoRtExecutable::PreallocatedTemp scratch = {};

  // Launch the executable and wait for the returned event to resolve.
  tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> done =
      compiled->Execute(no_arguments, result_buffers, scratch);
  tsl::BlockUntilReady(done);

  ASSERT_TRUE(done.IsConcrete());
  EXPECT_EQ(result_value, 42.0f);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Performance benchmarks below
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -287,6 +287,16 @@ tsl::AsyncValueRef<NanoRtExecutable::ExecuteEvent> NanoRtExecutable::Execute(
|
||||
MaybeOwningDeviceMemory(ToDeviceMemory(temp));
|
||||
}
|
||||
|
||||
for (const auto& constant : executable->constants()) {
|
||||
// Constants are re-indexed by the buffer allocation index at CpuExecutable
|
||||
// construction time, and `executable->constants()` actually returns the
|
||||
// vector of buffer allocations, and only allocations corresponding to
|
||||
// constants have a valid index.
|
||||
if (constant.index >= 0) {
|
||||
buffers[constant.index] = constant.AsDeviceMemoryBase();
|
||||
}
|
||||
}
|
||||
|
||||
cpu::BufferAllocations allocations(buffers);
|
||||
cpu::Thunk::ExecuteParams execute_params = {
|
||||
&executable->function_registry(),
|
||||
|
Loading…
Reference in New Issue
Block a user