diff --git a/third_party/triton/xla_extensions/series.bzl b/third_party/triton/xla_extensions/series.bzl
index be33c18e17f..ac8bec0d659 100644
--- a/third_party/triton/xla_extensions/series.bzl
+++ b/third_party/triton/xla_extensions/series.bzl
@@ -8,6 +8,5 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to
 
 extensions_files_patch_list = [
     "//third_party/triton:xla_extensions/sparse_dot.patch",  # Sparsity internal patch
-    "//third_party/triton:xla_extensions/sparsity_layout.patch",  # Sparsity internal patch
     # Add new patches just above this line
 ]
diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch
index 4d4e008acac..c8e22b4a68b 100644
--- a/third_party/triton/xla_extensions/sparse_dot.patch
+++ b/third_party/triton/xla_extensions/sparse_dot.patch
@@ -171,3 +171,54 @@ index baed96a29..e9d7f5859 100644
  // Create descriptor based on the format described in the spec:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
  union WGMMADescriptor {
+diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+index 34fb89954..a0172e107 100644
+--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
++++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+   addArgumentMaterialization([&](OpBuilder &builder,
+                                  RankedTensorType tensorType, ValueRange inputs,
+                                  Location loc) -> Value {
++    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++    // remaining arguments that have been converted to a new type.
++    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
++    // 'convert-triton-to-tritongpu'.
++    return builder.create(loc, tensorType,
++                          inputs);
+     llvm_unreachable("Argument rematerialization should not happen in Triton "
+                      "-> TritonGPU conversion");
+     return {};
+@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
+                                ValueRange inputs, Location loc) -> Value {
++    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++    // remaining uses of values that have been converted to a new type.
++    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
++    // 'convert-triton-to-tritongpu'.
++    return builder.create(loc, tensorType,
++                          inputs);
+     llvm_unreachable("Source rematerialization should not happen in Triton -> "
+                      "TritonGPU Conversion");
+     return {};
+diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
+index df3d3b042..e38c184f6 100644
+--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
++++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
+@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
+     // heuristic to accommodate fused attention.
+     auto srcType = op.getSrc().getType();
+     auto dstType = op.getType();
+-    if (mlir::isa(dstType.getEncoding()) &&
+-        mlir::isa(srcType.getEncoding()))
++    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
++        mlir::isa_and_nonnull(srcType.getEncoding()))
+       return failure();
+ 
+     // for hopper MMAv3
+-    if (mlir::isa(dstType.getEncoding()) &&
+-        mlir::isa(srcType.getEncoding()) &&
++    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
++        mlir::isa_and_nonnull(srcType.getEncoding()) &&
+         llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
+           return dot->hasTrait();
+         })) {
diff --git a/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/triton/xla_extensions/sparsity_layout.patch
deleted file mode 100644
index 15021fdfa47..00000000000
--- a/third_party/triton/xla_extensions/sparsity_layout.patch
+++ /dev/null
@@ -1,51 +0,0 @@
-diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-index 34fb89954..a0172e107 100644
---- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
-   addArgumentMaterialization([&](OpBuilder &builder,
-                                  RankedTensorType tensorType, ValueRange inputs,
-                                  Location loc) -> Value {
-+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
-+    // remaining arguments that have been converted to a new type.
-+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
-+    // 'convert-triton-to-tritongpu'.
-+    return builder.create(loc, tensorType,
-+                          inputs);
-     llvm_unreachable("Argument rematerialization should not happen in Triton "
-                      "-> TritonGPU conversion");
-     return {};
-@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
-   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
-                                ValueRange inputs, Location loc) -> Value {
-+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
-+    // remaining uses of values that have been converted to a new type.
-+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
-+    // 'convert-triton-to-tritongpu'.
-+    return builder.create(loc, tensorType,
-+                          inputs);
-     llvm_unreachable("Source rematerialization should not happen in Triton -> "
-                      "TritonGPU Conversion");
-     return {};
-diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-index df3d3b042..e38c184f6 100644
---- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
-+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
-     // heuristic to accommodate fused attention.
-     auto srcType = op.getSrc().getType();
-     auto dstType = op.getType();
--    if (mlir::isa(dstType.getEncoding()) &&
--        mlir::isa(srcType.getEncoding()))
-+    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
-+        mlir::isa_and_nonnull(srcType.getEncoding()))
-       return failure();
- 
-     // for hopper MMAv3
--    if (mlir::isa(dstType.getEncoding()) &&
--        mlir::isa(srcType.getEncoding()) &&
-+    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
-+        mlir::isa_and_nonnull(srcType.getEncoding()) &&
-         llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
-           return dot->hasTrait();
-         })) {
diff --git a/third_party/xla/third_party/triton/xla_extensions/series.bzl b/third_party/xla/third_party/triton/xla_extensions/series.bzl
index be33c18e17f..ac8bec0d659 100644
--- a/third_party/xla/third_party/triton/xla_extensions/series.bzl
+++ b/third_party/xla/third_party/triton/xla_extensions/series.bzl
@@ -8,6 +8,5 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to
 
 extensions_files_patch_list = [
     "//third_party/triton:xla_extensions/sparse_dot.patch",  # Sparsity internal patch
-    "//third_party/triton:xla_extensions/sparsity_layout.patch",  # Sparsity internal patch
     # Add new patches just above this line
 ]
diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
index 4d4e008acac..c8e22b4a68b 100644
--- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
+++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
@@ -171,3 +171,54 @@ index baed96a29..e9d7f5859 100644
  // Create descriptor based on the format described in the spec:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
  union WGMMADescriptor {
+diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+index 34fb89954..a0172e107 100644
+--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
++++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+   addArgumentMaterialization([&](OpBuilder &builder,
+                                  RankedTensorType tensorType, ValueRange inputs,
+                                  Location loc) -> Value {
++    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++    // remaining arguments that have been converted to a new type.
++    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
++    // 'convert-triton-to-tritongpu'.
++    return builder.create(loc, tensorType,
++                          inputs);
+     llvm_unreachable("Argument rematerialization should not happen in Triton "
+                      "-> TritonGPU conversion");
+     return {};
+@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
+                                ValueRange inputs, Location loc) -> Value {
++    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
++    // remaining uses of values that have been converted to a new type.
++    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
++    // 'convert-triton-to-tritongpu'.
++    return builder.create(loc, tensorType,
++                          inputs);
+     llvm_unreachable("Source rematerialization should not happen in Triton -> "
+                      "TritonGPU Conversion");
+     return {};
+diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
+index df3d3b042..e38c184f6 100644
+--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
++++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
+@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
+     // heuristic to accommodate fused attention.
+     auto srcType = op.getSrc().getType();
+     auto dstType = op.getType();
+-    if (mlir::isa(dstType.getEncoding()) &&
+-        mlir::isa(srcType.getEncoding()))
++    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
++        mlir::isa_and_nonnull(srcType.getEncoding()))
+       return failure();
+ 
+     // for hopper MMAv3
+-    if (mlir::isa(dstType.getEncoding()) &&
+-        mlir::isa(srcType.getEncoding()) &&
++    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
++        mlir::isa_and_nonnull(srcType.getEncoding()) &&
+         llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
+           return dot->hasTrait();
+         })) {
diff --git a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
deleted file mode 100644
index 15021fdfa47..00000000000
--- a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch
+++ /dev/null
@@ -1,51 +0,0 @@
-diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-index 34fb89954..a0172e107 100644
---- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
-@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
-   addArgumentMaterialization([&](OpBuilder &builder,
-                                  RankedTensorType tensorType, ValueRange inputs,
-                                  Location loc) -> Value {
-+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
-+    // remaining arguments that have been converted to a new type.
-+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
-+    // 'convert-triton-to-tritongpu'.
-+    return builder.create(loc, tensorType,
-+                          inputs);
-     llvm_unreachable("Argument rematerialization should not happen in Triton "
-                      "-> TritonGPU conversion");
-     return {};
-@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
-   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
-                                ValueRange inputs, Location loc) -> Value {
-+    // Allows partial TTIR to TTGIR conversion by materializing a conversion for
-+    // remaining uses of values that have been converted to a new type.
-+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
-+    // 'convert-triton-to-tritongpu'.
-+    return builder.create(loc, tensorType,
-+                          inputs);
-     llvm_unreachable("Source rematerialization should not happen in Triton -> "
-                      "TritonGPU Conversion");
-     return {};
-diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-index df3d3b042..e38c184f6 100644
---- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
-+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
-@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
-     // heuristic to accommodate fused attention.
-     auto srcType = op.getSrc().getType();
-     auto dstType = op.getType();
--    if (mlir::isa(dstType.getEncoding()) &&
--        mlir::isa(srcType.getEncoding()))
-+    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
-+        mlir::isa_and_nonnull(srcType.getEncoding()))
-       return failure();
- 
-     // for hopper MMAv3
--    if (mlir::isa(dstType.getEncoding()) &&
--        mlir::isa(srcType.getEncoding()) &&
-+    if (mlir::isa_and_nonnull(dstType.getEncoding()) &&
-+        mlir::isa_and_nonnull(srcType.getEncoding()) &&
-         llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
-           return dot->hasTrait();
-         })) {