Merge sparsity_layout.patch into sparse_dot.patch

PiperOrigin-RevId: 698389323
This commit is contained in:
A. Unique TensorFlower 2024-11-20 07:50:07 -08:00 committed by TensorFlower Gardener
parent ed9291f0ce
commit ad81c08990
6 changed files with 102 additions and 104 deletions

View File

@@ -8,6 +8,5 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to
extensions_files_patch_list = [
"//third_party/triton:xla_extensions/sparse_dot.patch", # Sparsity internal patch
"//third_party/triton:xla_extensions/sparsity_layout.patch", # Sparsity internal patch
# Add new patches just above this line
]

View File

@@ -171,3 +171,54 @@ index baed96a29..e9d7f5859 100644
// Create descriptor based on the format described in the spec:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
union WGMMADescriptor {
diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
index 34fb89954..a0172e107 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining arguments that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining uses of values that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index df3d3b042..e38c184f6 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
- if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+ if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();
// for hopper MMAv3
- if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+ if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {

View File

@@ -1,51 +0,0 @@
diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
index 34fb89954..a0172e107 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining arguments that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining uses of values that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index df3d3b042..e38c184f6 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
- if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+ if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();
// for hopper MMAv3
- if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+ if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {

View File

@@ -8,6 +8,5 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to
extensions_files_patch_list = [
"//third_party/triton:xla_extensions/sparse_dot.patch", # Sparsity internal patch
"//third_party/triton:xla_extensions/sparsity_layout.patch", # Sparsity internal patch
# Add new patches just above this line
]

View File

@@ -171,3 +171,54 @@ index baed96a29..e9d7f5859 100644
// Create descriptor based on the format described in the spec:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
union WGMMADescriptor {
diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
index 34fb89954..a0172e107 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining arguments that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining uses of values that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index df3d3b042..e38c184f6 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
- if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+ if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();
// for hopper MMAv3
- if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+ if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {

View File

@@ -1,51 +0,0 @@
diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
index 34fb89954..a0172e107 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining arguments that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
@@ -67,5 +73,11 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
+ // Allows partial TTIR to TTGIR conversion by materializing a conversion for
+ // remaining uses of values that have been converted to a new type.
+ // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+ // 'convert-triton-to-tritongpu'.
+ return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+ inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index df3d3b042..e38c184f6 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
- if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+ if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();
// for hopper MMAv3
- if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
- mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+ if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {