From 733d71db8844a1701d133b09e6d9b62638db017f Mon Sep 17 00:00:00 2001 From: Shaogang Wang Date: Thu, 21 Nov 2024 04:25:27 -0800 Subject: [PATCH] PR #19528: [XLA:GPU] use separte command buffer cmd flag for conditional and loop Imported from GitHub PR https://github.com/openxla/xla/pull/19528 Observed in saxml workload that sharing the same command buffer cmd type (CONDITIONALS) for WHILE and CONDITIONAL command over kill the lowering opportunities. Many cases could allow CONDITIONAL instruction to lower into command buffer, while WHILE is not possible. This PR uses separate command buffer cmd type flag for CONDITIONAL and WHILE instructions when user specifies the type to lowering. Copybara import of the project: -- 4d62fb512995e2fc6e9077a1b3251a6754c866ca by Shawn Wang : use separte command buffer cmd flag for conditional and loop Merging this change closes #19528 PiperOrigin-RevId: 698729891 --- .../gpu/runtime/command_buffer_thunk_test.cc | 2 +- .../gpu/transforms/command_buffer_scheduling.cc | 13 ++++++++----- .../transforms/command_buffer_scheduling_test.cc | 3 ++- third_party/xla/xla/xla.proto | 9 +++++---- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 385e762fabe..c007a614976 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -1481,7 +1481,7 @@ class CmdBufferTest : public HloTestBase { debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); debug_options.add_xla_gpu_enable_command_buffer( - DebugOptions::DYNAMIC_SLICE); + DebugOptions::DYNAMIC_SLICE_FUSION); return debug_options; } }; diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index 9900deb10e8..d7c95b08984 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -178,7 +178,7 @@ static bool IsCommand(const HloInstruction*, const CommandBufferConfig&); template <> bool IsCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { - return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && + return config.enabled_commands.contains(DebugOptions::WHILE) && IsCommand(hlo->while_body(), config) && IsCommand(hlo->while_condition(), config); } @@ -188,7 +188,7 @@ bool IsCommand(const HloInstruction* hlo, template <> bool IsCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { - return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && + return config.enabled_commands.contains(DebugOptions::CONDITIONAL) && absl::c_all_of(hlo->branch_computations(), [&](const HloComputation* comp) { return IsCommand(comp, config); @@ -261,7 +261,8 @@ static bool IsCommand(const HloInstruction* hlo, // DynamicSliceFusionRewriter currently only rewrites for dynamic slice // fusion with constant or loop iteration offset values, which are all // supported by command buffer. - return (config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE) && + return (config.enabled_commands.contains( + DebugOptions::DYNAMIC_SLICE_FUSION) && (IsCommand(hero, config) || IsAsyncStartCommand(hero, config))); } } @@ -371,7 +372,8 @@ CommandBufferScheduling::CollectCommandBufferSequences( // captured in command buffer. auto check_dynamic_slice_operand_not_from_seq = [&](const HloInstructionSequence& seq, const HloInstruction* inst) { - if (!config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE)) + if (!config.enabled_commands.contains( + DebugOptions::DYNAMIC_SLICE_FUSION)) return true; const auto* fusion = DynCast(inst); if (!fusion) return true; @@ -813,7 +815,8 @@ absl::StatusOr CommandBufferScheduling::Run( device_description_}; // Erase command buffer cmd types that are not supported by the gpu runtime. - static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS}; + static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONAL, + DebugOptions::WHILE}; static constexpr auto kRequireTracing = { DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN, DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES}; diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 58bb419e4e0..2c8155a1d7a 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -46,7 +46,8 @@ class CommandBufferSchedulingTest : public HloTestBase { DebugOptions GetDebugOptionsForTest() const override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); - debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONAL); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 5cc825069ac..b0bf1b93d7b 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -604,10 +604,11 @@ message DebugOptions { CUBLAS = 2; CUDNN = 3; COLLECTIVES = 4; - CONDITIONALS = 5; - CUSTOM_CALL = 6; - CUBLASLT = 7; - DYNAMIC_SLICE = 8; + CONDITIONAL = 5; + WHILE = 6; + CUSTOM_CALL = 7; + CUBLASLT = 8; + DYNAMIC_SLICE_FUSION = 9; } // Determine the types of commands that are recorded into command buffers.