diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 385e762fabe..c007a614976 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -1481,7 +1481,7 @@ class CmdBufferTest : public HloTestBase { debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); debug_options.add_xla_gpu_enable_command_buffer( - DebugOptions::DYNAMIC_SLICE); + DebugOptions::DYNAMIC_SLICE_FUSION); return debug_options; } }; diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index 9900deb10e8..d7c95b08984 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -178,7 +178,7 @@ static bool IsCommand(const HloInstruction*, const CommandBufferConfig&); template <> bool IsCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { - return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && + return config.enabled_commands.contains(DebugOptions::WHILE) && IsCommand(hlo->while_body(), config) && IsCommand(hlo->while_condition(), config); } @@ -188,7 +188,7 @@ bool IsCommand(const HloInstruction* hlo, template <> bool IsCommand(const HloInstruction* hlo, const CommandBufferConfig& config) { - return config.enabled_commands.contains(DebugOptions::CONDITIONALS) && + return config.enabled_commands.contains(DebugOptions::CONDITIONAL) && absl::c_all_of(hlo->branch_computations(), [&](const HloComputation* comp) { return IsCommand(comp, config); @@ -261,7 +261,8 @@ static bool IsCommand(const HloInstruction* hlo, // DynamicSliceFusionRewriter currently only rewrites for dynamic slice // fusion with constant or loop iteration offset values, which are all // supported by command buffer. - return (config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE) && + return (config.enabled_commands.contains( + DebugOptions::DYNAMIC_SLICE_FUSION) && (IsCommand(hero, config) || IsAsyncStartCommand(hero, config))); } } @@ -371,7 +372,8 @@ CommandBufferScheduling::CollectCommandBufferSequences( // captured in command buffer. auto check_dynamic_slice_operand_not_from_seq = [&](const HloInstructionSequence& seq, const HloInstruction* inst) { - if (!config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE)) + if (!config.enabled_commands.contains( + DebugOptions::DYNAMIC_SLICE_FUSION)) return true; const auto* fusion = DynCast(inst); if (!fusion) return true; @@ -813,7 +815,8 @@ absl::StatusOr CommandBufferScheduling::Run( device_description_}; // Erase command buffer cmd types that are not supported by the gpu runtime. - static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS}; + static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONAL, + DebugOptions::WHILE}; static constexpr auto kRequireTracing = { DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN, DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES}; diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 58bb419e4e0..2c8155a1d7a 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -46,7 +46,8 @@ class CommandBufferSchedulingTest : public HloTestBase { DebugOptions GetDebugOptionsForTest() const override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); - debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONAL); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 5cc825069ac..b0bf1b93d7b 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -604,10 +604,11 @@ message DebugOptions { CUBLAS = 2; CUDNN = 3; COLLECTIVES = 4; - CONDITIONALS = 5; - CUSTOM_CALL = 6; - CUBLASLT = 7; - DYNAMIC_SLICE = 8; + CONDITIONAL = 5; + WHILE = 6; + CUSTOM_CALL = 7; + CUBLASLT = 8; + DYNAMIC_SLICE_FUSION = 9; } // Determine the types of commands that are recorded into command buffers.