PR #19528: [XLA:GPU] use separate command buffer cmd flag for conditional and loop
Some checks are pending
ARM CI / build (3.10) (push) Waiting to run
Creates a GitHub Issue when a PR Rolled back via Commit to Master / create-issue-on-pr-rollback (push) Waiting to run
Scorecards supply-chain security / Scorecards analysis (push) Waiting to run

Imported from GitHub PR https://github.com/openxla/xla/pull/19528

Observed in a saxml workload that sharing the same command buffer cmd type (CONDITIONALS) for the WHILE and CONDITIONAL commands needlessly kills lowering opportunities.

In many cases a CONDITIONAL instruction could be lowered into a command buffer even though a WHILE instruction could not.

This PR uses separate command buffer cmd type flags for CONDITIONAL and WHILE instructions when the user specifies which types to lower.
Copybara import of the project:

--
4d62fb512995e2fc6e9077a1b3251a6754c866ca by Shawn Wang <shawnw@nvidia.com>:

use separate command buffer cmd flag for conditional and loop

Merging this change closes #19528

PiperOrigin-RevId: 698729891
This commit is contained in:
Shaogang Wang 2024-11-21 04:25:27 -08:00 committed by TensorFlower Gardener
parent b9f49aa824
commit 733d71db88
4 changed files with 16 additions and 11 deletions

View File

@ -1481,7 +1481,7 @@ class CmdBufferTest : public HloTestBase {
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
debug_options.add_xla_gpu_enable_command_buffer(
DebugOptions::DYNAMIC_SLICE);
DebugOptions::DYNAMIC_SLICE_FUSION);
return debug_options;
}
};

View File

@ -178,7 +178,7 @@ static bool IsCommand(const HloInstruction*, const CommandBufferConfig&);
template <>
bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
const CommandBufferConfig& config) {
return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
return config.enabled_commands.contains(DebugOptions::WHILE) &&
IsCommand(hlo->while_body(), config) &&
IsCommand(hlo->while_condition(), config);
}
@ -188,7 +188,7 @@ bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
template <>
bool IsCommand<HloOpcode::kConditional>(const HloInstruction* hlo,
const CommandBufferConfig& config) {
return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
return config.enabled_commands.contains(DebugOptions::CONDITIONAL) &&
absl::c_all_of(hlo->branch_computations(),
[&](const HloComputation* comp) {
return IsCommand(comp, config);
@ -261,7 +261,8 @@ static bool IsCommand(const HloInstruction* hlo,
// DynamicSliceFusionRewriter currently only rewrites for dynamic slice
// fusion with constant or loop iteration offset values, which are all
// supported by command buffer.
return (config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE) &&
return (config.enabled_commands.contains(
DebugOptions::DYNAMIC_SLICE_FUSION) &&
(IsCommand(hero, config) || IsAsyncStartCommand(hero, config)));
}
}
@ -371,7 +372,8 @@ CommandBufferScheduling::CollectCommandBufferSequences(
// captured in command buffer.
auto check_dynamic_slice_operand_not_from_seq =
[&](const HloInstructionSequence& seq, const HloInstruction* inst) {
if (!config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE))
if (!config.enabled_commands.contains(
DebugOptions::DYNAMIC_SLICE_FUSION))
return true;
const auto* fusion = DynCast<HloFusionInstruction>(inst);
if (!fusion) return true;
@ -813,7 +815,8 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
device_description_};
// Erase command buffer cmd types that are not supported by the gpu runtime.
static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS};
static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONAL,
DebugOptions::WHILE};
static constexpr auto kRequireTracing = {
DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN,
DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES};

View File

@ -46,7 +46,8 @@ class CommandBufferSchedulingTest : public HloTestBase {
DebugOptions GetDebugOptionsForTest() const override {
auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS);
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONAL);
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE);
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);

View File

@ -604,10 +604,11 @@ message DebugOptions {
CUBLAS = 2;
CUDNN = 3;
COLLECTIVES = 4;
CONDITIONALS = 5;
CUSTOM_CALL = 6;
CUBLASLT = 7;
DYNAMIC_SLICE = 8;
CONDITIONAL = 5;
WHILE = 6;
CUSTOM_CALL = 7;
CUBLASLT = 8;
DYNAMIC_SLICE_FUSION = 9;
}
// Determine the types of commands that are recorded into command buffers.