mirror of
https://github.com/tensorflow/tensorflow.git
synced 2024-11-21 12:51:46 +00:00
PR #19528: [XLA:GPU] use separate command buffer cmd flag for conditional and loop
Imported from GitHub PR https://github.com/openxla/xla/pull/19528 We observed in a saxml workload that sharing the same command buffer cmd type (CONDITIONALS) for both WHILE and CONDITIONAL commands needlessly kills lowering opportunities. In many cases a CONDITIONAL instruction could be lowered into a command buffer while a WHILE instruction could not. This PR uses separate command buffer cmd type flags for CONDITIONAL and WHILE instructions when the user specifies which types to lower. Copybara import of the project: -- 4d62fb512995e2fc6e9077a1b3251a6754c866ca by Shawn Wang <shawnw@nvidia.com>: use separate command buffer cmd flag for conditional and loop Merging this change closes #19528 PiperOrigin-RevId: 698729891
This commit is contained in:
parent
b9f49aa824
commit
733d71db88
@ -1481,7 +1481,7 @@ class CmdBufferTest : public HloTestBase {
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(
|
||||
DebugOptions::DYNAMIC_SLICE);
|
||||
DebugOptions::DYNAMIC_SLICE_FUSION);
|
||||
return debug_options;
|
||||
}
|
||||
};
|
||||
|
@ -178,7 +178,7 @@ static bool IsCommand(const HloInstruction*, const CommandBufferConfig&);
|
||||
template <>
|
||||
bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
|
||||
const CommandBufferConfig& config) {
|
||||
return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
|
||||
return config.enabled_commands.contains(DebugOptions::WHILE) &&
|
||||
IsCommand(hlo->while_body(), config) &&
|
||||
IsCommand(hlo->while_condition(), config);
|
||||
}
|
||||
@ -188,7 +188,7 @@ bool IsCommand<HloOpcode::kWhile>(const HloInstruction* hlo,
|
||||
template <>
|
||||
bool IsCommand<HloOpcode::kConditional>(const HloInstruction* hlo,
|
||||
const CommandBufferConfig& config) {
|
||||
return config.enabled_commands.contains(DebugOptions::CONDITIONALS) &&
|
||||
return config.enabled_commands.contains(DebugOptions::CONDITIONAL) &&
|
||||
absl::c_all_of(hlo->branch_computations(),
|
||||
[&](const HloComputation* comp) {
|
||||
return IsCommand(comp, config);
|
||||
@ -261,7 +261,8 @@ static bool IsCommand(const HloInstruction* hlo,
|
||||
// DynamicSliceFusionRewriter currently only rewrites for dynamic slice
|
||||
// fusion with constant or loop iteration offset values, which are all
|
||||
// supported by command buffer.
|
||||
return (config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE) &&
|
||||
return (config.enabled_commands.contains(
|
||||
DebugOptions::DYNAMIC_SLICE_FUSION) &&
|
||||
(IsCommand(hero, config) || IsAsyncStartCommand(hero, config)));
|
||||
}
|
||||
}
|
||||
@ -371,7 +372,8 @@ CommandBufferScheduling::CollectCommandBufferSequences(
|
||||
// captured in command buffer.
|
||||
auto check_dynamic_slice_operand_not_from_seq =
|
||||
[&](const HloInstructionSequence& seq, const HloInstruction* inst) {
|
||||
if (!config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE))
|
||||
if (!config.enabled_commands.contains(
|
||||
DebugOptions::DYNAMIC_SLICE_FUSION))
|
||||
return true;
|
||||
const auto* fusion = DynCast<HloFusionInstruction>(inst);
|
||||
if (!fusion) return true;
|
||||
@ -813,7 +815,8 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
|
||||
device_description_};
|
||||
|
||||
// Erase command buffer cmd types that are not supported by the gpu runtime.
|
||||
static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONALS};
|
||||
static constexpr auto kRequireConditionals = {DebugOptions::CONDITIONAL,
|
||||
DebugOptions::WHILE};
|
||||
static constexpr auto kRequireTracing = {
|
||||
DebugOptions::CUBLAS, DebugOptions::CUBLASLT, DebugOptions::CUDNN,
|
||||
DebugOptions::CUSTOM_CALL, DebugOptions::COLLECTIVES};
|
||||
|
@ -46,7 +46,8 @@ class CommandBufferSchedulingTest : public HloTestBase {
|
||||
DebugOptions GetDebugOptionsForTest() const override {
|
||||
auto debug_options = HloTestBase::GetDebugOptionsForTest();
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONAL);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::WHILE);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN);
|
||||
debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT);
|
||||
|
9
third_party/xla/xla/xla.proto
vendored
9
third_party/xla/xla/xla.proto
vendored
@ -604,10 +604,11 @@ message DebugOptions {
|
||||
CUBLAS = 2;
|
||||
CUDNN = 3;
|
||||
COLLECTIVES = 4;
|
||||
CONDITIONALS = 5;
|
||||
CUSTOM_CALL = 6;
|
||||
CUBLASLT = 7;
|
||||
DYNAMIC_SLICE = 8;
|
||||
CONDITIONAL = 5;
|
||||
WHILE = 6;
|
||||
CUSTOM_CALL = 7;
|
||||
CUBLASLT = 8;
|
||||
DYNAMIC_SLICE_FUSION = 9;
|
||||
}
|
||||
|
||||
// Determine the types of commands that are recorded into command buffers.
|
||||
|
Loading…
Reference in New Issue
Block a user