tensorflow/third_party/stablehlo/temporary.patch

diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
@@ -3,14 +3,14 @@
func.func @error_illformed(%arg0: tensor<3xf32>, %arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.abs %arg0 : (tensor<3xf32>) -> tensor<?xf32>
%1 = stablehlo.abs %arg1 : (tensor<4xf32>) -> tensor<?xf32>
- // expected-error@+1{{requires the same shape for all operands and results}}
+ // expected-error@+1{{'stablehlo.add' op requires the same shape for all operands and results}}
%2 = stablehlo.add %0, %1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func.return %2 : tensor<?xf32>
}
// -----
-// expected-error@+1{{must have exactly one block}}
+// expected-error@+1{{'func.func' op must have exactly one block}}
func.func @error_too_many_blocks(%arg0: tensor<f32>) -> tensor<f32> {
cf.br ^bb1(%arg0 : tensor<f32>)
^bb1(%arg1 : tensor<f32>):
@@ -49,6 +49,7 @@
// -----
+// CHECK-LABEL: func @error_unsupported_operation
func.func @error_unsupported_operation(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> index {
// CHECK: stablehlo.add{{.*}} -> tensor<?xf32>
%0 = stablehlo.add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<?xf32>
@@ -596,10 +597,288 @@
// -----
// CHECK-LABEL: func @refine_bitcast_convert_same_bitwidth
-func.func @refine_bitcast_convert_same_bitwidth(%arg0 : tensor<4xf32>) -> tensor<?xi32> {
- // CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<4xi32>
- %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<?xi32>
- func.return %0 : tensor<?xi32>
+func.func @refine_bitcast_convert_same_bitwidth() -> tensor<?x?x0xf32> {
+ %0 = stablehlo.constant dense<[3, 5, 0]> : tensor<3xi32>
+ %21 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<3xi32>) -> tensor<?x?x0xui32>
+ // CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<3x5x0xf32>
+ %48 = stablehlo.bitcast_convert %21 : (tensor<?x?x0xui32>) -> tensor<?x?x0xf32>
+ return %48 : tensor<?x?x0xf32>
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call
+module @refine_call {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %1 = stablehlo.constant dense<4> : tensor<i32>
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ %2 = call @refine_call_callee(%1, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // CHECK: refine_call_callee(%arg0: tensor<4xf32>) -> tensor<4xf32>
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: stablehlo.constant dense<4>
+ %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
+ %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
+ return %1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_arguments
+module @refine_call_dimension_arguments {
+ func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT:%.*]] = call @callee
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %1 = call @callee(%0, %0, %arg0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // %arg0 and %arg1 are dimension arguments
+ // CHECK: @callee([[ARG0:%.*]]: tensor<i32>) -> tensor<i32>
+ func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
+ // CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg1: tensor<i32>
+ %1 = stablehlo.add %0, %arg2: tensor<i32>
+ return %1 : tensor<i32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_prefix_token_and_dimension_arguments
+module @refine_call_prefix_token_and_dimension_arguments {
+ func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT:%.*]] = call @callee
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %token = stablehlo.create_token : !stablehlo.token
+ %1 = call @callee(%token, %0, %0, %arg0) : (!stablehlo.token, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // %arg0 and %arg1 are dimension arguments
+ // CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
+ func.func private @callee(%arg_token: !stablehlo.token, %arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
+ // CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg1: tensor<i32>
+ %1 = stablehlo.add %0, %arg2: tensor<i32>
+ return %1 : tensor<i32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_arguments_followed_by_token
+module @refine_call_dimension_arguments_followed_by_token {
+ func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT:%.*]] = call @callee
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %token = stablehlo.create_token : !stablehlo.token
+ %1 = call @callee(%0, %0, %token, %arg0) : (tensor<i32>, tensor<i32>, !stablehlo.token, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // %arg0 and %arg1 are dimension arguments
+ // CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
+ func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg_token: !stablehlo.token, %arg2: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
+ // CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg1: tensor<i32>
+ %1 = stablehlo.add %0, %arg2: tensor<i32>
+ return %1 : tensor<i32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_multiple_call_with_same_context
+module @refine_multiple_call_with_same_context {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ %2 = call @refine_call_callee(%arg0_new, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_multiple_call_constant_function
+module @refine_multiple_call_constant_function {
+ func.func @main(%arg0: tensor<5xf32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<16>
+ // CHECK: return [[RESULT0]]
+ %0 = stablehlo.constant dense<4> : tensor<i32>
+ %1 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
+ %2 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
+ %3 = stablehlo.add %1, %2: tensor<i32>
+ return %3 : tensor<i32>
+ }
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> tensor<i32> {
+ // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<8>
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg0: tensor<i32>
+ return %0 : tensor<i32>
+ }
+}
+
+// -----
+
+module @refine_call_multiple_with_different_number_dimension_arguments {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ // Ensure that the first argument is not a constant at the second call site
+ %arg0_different_f32 = stablehlo.bitcast_convert %arg0_new : (tensor<i32>) -> tensor<f32>
+ %arg0_different_i32 = stablehlo.bitcast_convert %arg0_different_f32 : (tensor<f32>) -> tensor<i32>
+ // expected-error@+1{{incorrect number of operands for callee}}
+ %2 = call @refine_call_callee(%arg0_different_i32, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+module @refine_call_multiple_different_dimension_arguments {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ %arg0_different = stablehlo.add %arg0_new, %arg0_new : tensor<i32>
+ // expected-error@+1{{incorrect number of operands for callee}}
+ %2 = call @refine_call_callee(%arg0_different, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+module @refine_call_multiple_different_non_dimension_arguments {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ %2 = stablehlo.constant dense<[1., 2.]> : tensor<2xf32>
+ %3 = stablehlo.concatenate %1, %2, dim = 0 : (tensor<?xf32>, tensor<2xf32>) -> tensor<?xf32>
+ // expected-error@+1{{incorrect number of operands for callee}}
+ %4 = call @refine_call_callee(%arg0_new, %3) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %4 : tensor<?xf32>
+ }
+ // expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+module @refine_call_recursive {
+ func.func @main() -> tensor<i32> {
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %1 = call @refine_call_callee(%0) : (tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // expected-error@+1{{Function refine_call_callee is being refined recursively}}
+ func.func @refine_call_callee(%arg0: tensor<i32>) -> tensor<i32> {
+ // expected-error@+1{{incorrect number of operands}}
+ %0 = call @refine_call_callee(%arg0) : (tensor<i32>) -> tensor<i32>
+ return %0 : tensor<i32>
+ }
+}
+
+// -----
+
+module @refine_call_main_argument_unranked {
+ // CHECK-LABEL: func.func public @main(%arg0: tensor<*xi32>) -> tensor<*xi32>
+ func.func public @main(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ %2 = call @callee(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+ return %2 : tensor<*xi32>
+ }
+ func.func private @callee(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ return %arg0 : tensor<*xi32>
+ }
+}
+
+// -----
+
+module @refine_call_main_argument_dynamic_shape {
+ // CHECK: func.func public @main(%arg0: tensor<?xi32>) -> tensor<?xi32>
+ func.func public @main(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ %2 = call @callee(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+ return %2 : tensor<?xi32>
+ }
+ func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ return %arg0 : tensor<?xi32>
+ }
+}
+
+// -----
+
+module @refine_call_callee_argument_dynamic_shape {
+ // CHECK: func.func public @main(%arg0: tensor<1xi64>) -> tensor<?xi32>
+ func.func public @main(%arg0: tensor<1xi64>) -> tensor<?xi32> {
+ %1 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xi64>) -> tensor<?xi32>
+ %2 = call @callee(%1) : (tensor<?xi32>) -> tensor<?xi32>
+ return %2 : tensor<?xi32>
+ }
+ func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ return %arg0 : tensor<?xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_argument_non_scalar
+// The non-scalar constant is not folded into the callee
+module @refine_call_dimension_argument_non_scalar {
+ func.func public @main() -> tensor<4xi32> {
+ // CHECK: dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %1 = call @callee(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ return %1 : tensor<4xi32>
+ }
+ func.func private @callee(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: return %arg0 : tensor<4xi32>
+ return %arg0 : tensor<4xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_argument_not_integer
+module @refine_call_dimension_argument_not_integer {
+ func.func public @main() -> tensor<f32> {
+ %0 = stablehlo.constant dense<3.> : tensor<f32>
+ // CHECK: call @callee({{.*}}) : (tensor<f32>) -> tensor<f32>
+ %2 = call @callee(%0) : (tensor<f32>) -> tensor<f32>
+ return %2 : tensor<f32>
+ }
+ func.func private @callee(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+ }
}
// -----
@@ -656,6 +935,17 @@
// -----
+// CHECK-LABEL: @refine_custom_call_operand_wrapper_unranked
+func.func @refine_custom_call_operand_wrapper_unranked(%arg0: tensor<4xi32>) -> tensor<*xi32> {
+ // CHECK-NOT: stablehlo.shape_refinement_operand_wrapper
+ %0 = stablehlo.constant dense<[4]> : tensor<1xi64>
+ %1 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %0) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<*xi32>
+ // CHECK: return %arg0 : tensor<4xi32>
+ func.return %1 : tensor<*xi32>
+}
+
+// -----
+
// CHECK-LABEL: @refine_dot_general
func.func @refine_dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<?x?x?xf32> {
// CHECK: stablehlo.dot_general{{.*}} -> tensor<2x4x5xf32>
@@ -755,6 +1045,8 @@
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
func.return %1 : tensor<?x?xf32>
}
+
+
// -----
@@ -908,6 +1200,7 @@
// -----
// TODO: Implement support for these ops.
+// * dynamic_conv (#867).
// * dynamic_fft (#1366).
// * dynamic_reduce_window (#1258).
// * dynamic_rng_bit_generator (#1344).
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
--- stablehlo/stablehlo/transforms/Passes.td
+++ stablehlo/stablehlo/transforms/Passes.td
@@ -356,6 +356,25 @@
%1 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
```
+
+ Modules valid for shape refinement must have the following properties:
+
+ * All dynamic shapes depend only on the input shapes (there is no shape
+ dependency on the input array contents). We refer to operations that
+ depend transitively only on the input shapes (e.g., as given by
+ `stablehlo.get_dimension_size`) or on global constants, such as the
+ resolved values of symbolic integers (e.g. tensor<Axf32> with A = 5), as
+ `dimension` operations. All dimension values can be resolved to constants
+ through inter-procedural constant folding.
+ * Intermediate functions may take a number of token arguments (of type
+ !stablehlo.token) at the start of the argument list, followed by
+ global-constant arguments, which are constant integer scalars such as the
+ resolved values of symbolic integers (e.g. tensor<Axf32> with A = 5).
+ * Some intermediate functions may return computations on global constants,
+ e.g. `floordiv` on symbolic integer values. Such functions are identified
+ by returning only constant values after refinement, and they are inlined.
+ * All calls to a single function resolve to the same argument shapes, and
+ no recursive or co-recursive function calls are made. (See the sketch
+ below.)
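+
+ For example, here is a minimal sketch of a refinable module (illustrative
+ only; the function names are hypothetical). The callee takes a leading
+ token argument followed by a dimension argument that refinement folds to
+ a constant and drops from the signature:
+
+ ```mlir
+ func.func @main(%arg0: tensor<4xf32>) -> tensor<?xf32> {
+ %token = stablehlo.create_token : !stablehlo.token
+ %size = stablehlo.constant dense<4> : tensor<i32>
+ %0 = call @callee(%token, %size, %arg0) : (!stablehlo.token, tensor<i32>, tensor<4xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+ }
+ ```
+
+ After refinement, the call site becomes
+ `call @callee(%token, %arg0) : (!stablehlo.token, tensor<4xf32>) -> tensor<4xf32>`.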
}];
}
@@ -375,4 +394,5 @@
Option<"targetVersionOption", "target", "std::string", "",
"The target version. Must be a version of the form #.#.# .">,
];
-}
+ let dependentDialects = ["mlir::vhlo::VhloDialect"];
+}
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp b/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
@@ -201,15 +202,7 @@
return signalPassFailure();
}
- // Verify that refinements are valid
- if (failed(validateRefinedTypes(func, refinedTypes)))
- return signalPassFailure();
-
- // Wrap refined operands in operand wrapper to keep IR valid for refinement
- wrapRefinedOperands(func, refinedTypes);
-
- // Actually update main's input types.
- refineOperandsAndUpdateFunctionSignature(func, refinedTypes);
+ if (failed(refineArguments(func, refinedTypes))) return signalPassFailure();
}
private:
@@ -218,6 +211,19 @@
} // namespace
+LogicalResult refineArguments(func::FuncOp func, TypeRange refinedTypes) {
+ // Verify that refinements are valid
+ if (failed(validateRefinedTypes(func, refinedTypes))) return failure();
+
+ // Wrap refined operands in operand wrapper to keep IR valid for refinement
+ wrapRefinedOperands(func, refinedTypes);
+
+ // Actually update main's input types.
+ refineOperandsAndUpdateFunctionSignature(func, refinedTypes);
+
+ return success();
+}
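+
+// For example (illustrative), refining %arg0 : tensor<?xf32> to tensor<4xf32>
+// keeps the IR valid by rerouting uses through:
+// %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %shape)
+// {indices_of_shape_operands = dense<1> : tensor<1xi64>}
+// : (tensor<4xf32>, tensor<1xi64>) -> tensor<?xf32>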
+
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
TypeRange refinedTypes) {
return std::make_unique<StablehloRefineArgumentsPass>(refinedTypes);
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -14,15 +14,20 @@
#include "stablehlo/transforms/StablehloRefineShapes.h"
+#include <cstddef>
#include <cstdint>
+#include <optional>
+#include <tuple>
#include <utility>
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -49,6 +54,8 @@
#include "stablehlo/dialect/TypeInference.h"
#include "stablehlo/transforms/Passes.h"
+#define DEBUG_TYPE "stablehlo-refine-shapes"
+
namespace mlir {
namespace stablehlo {
@@ -63,10 +70,10 @@
<< values.size() << " types, got " << types.size();
});
- // Check whether `types` contain any new information with respect to existing
- // return types. Even if just a single dimension size out of an entire tensor
- // type got updated, using `inferMostSpecificType` ensures that we don't
- // miss that.
+ // Check whether `types` contain any new information with respect to
+ // existing return types. Even if just a single dimension size out of an
+ // entire tensor type got updated, using `inferMostSpecificType` ensures
+ // that we don't miss that.
bool needsRefinement = false;
SmallVector<Type> refinedTypes;
for (auto it : llvm::zip(values.getTypes(), types)) {
@@ -76,11 +83,12 @@
auto refinement = std::get<1>(it);
auto refinedType = hlo::inferMostSpecificType(
/*location=*/{}, {currentType, refinement});
- if (failed(refinedType))
+ if (failed(refinedType)) {
return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) {
diag << "inferMostSpecificType failed for " << currentType << " and "
<< refinement;
});
+ }
refinedTypes.push_back(*refinedType);
needsRefinement |= (currentType != *refinedType);
}
@@ -106,11 +114,8 @@
// Simply changing operand type of `func.return` won't work because
// that won't update the FunctionType of the enclosing `func.func`.
- // Nonetheless, we still want to support these ops because they are widely
- // used in StableHLO programs (although the plan of record is to replace
- // `func.return` ops in StableHLO programs with `stablehlo.return`:
- // https://github.com/openxla/stablehlo/issues/425).
if (isa<func::ReturnOp>(user)) continue;
+ if (isa<func::CallOp>(user)) continue;
// Unlike in TensorFlow's type inference pass, here we work only with
// allowlisted ops to focus our support on well-defined semantics of
@@ -244,6 +249,233 @@
namespace {
+class RefinementKey {
+ public:
+ RefinementKey(func::FuncOp func, int64_t leadingTokenOperands,
+ SmallVector<APSInt> const& globalConstants,
+ SmallVector<Type> const& functionalArgumentTypes)
+ : func(func),
+ leadingTokenOperands(leadingTokenOperands),
+ globalConstants(globalConstants),
+ functionalArgumentTypes(functionalArgumentTypes) {}
+
+ static FailureOr<RefinementKey> fromCallOp(func::CallOp callOp) {
+ LLVM_DEBUG(llvm::dbgs() << "RefinementKey::fromCallOp: "
+ << callOp.getCalleeType() << "\n");
+ int64_t leadingTokenOperands = countLeadingTokenOperands(callOp);
+ SmallVector<APSInt> globalConstants =
+ getGlobalConstants(callOp, leadingTokenOperands);
+ SmallVector<Type> functionalArgumentTypes = getFunctionalArgumentTypes(
+ callOp, leadingTokenOperands, globalConstants.size());
+
+ FlatSymbolRefAttr calleeName = callOp.getCalleeAttr();
+ const SymbolTable symbolTable(callOp->getParentOfType<ModuleOp>());
+ auto callee = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
+ callOp, calleeName.getAttr());
+ if (!callee) return callOp.emitOpError() << "cannot resolve function call";
+ return RefinementKey(callee, leadingTokenOperands, globalConstants,
+ functionalArgumentTypes);
+ }
+
+ // Getters
+ func::FuncOp getFunc() const { return func; }
+ int64_t getLeadingTokenOperands() const { return leadingTokenOperands; }
+ SmallVector<APSInt> const& getGlobalConstants() const {
+ return globalConstants;
+ }
+ SmallVector<Type> const& getFunctionalArgumentTypes() const {
+ return functionalArgumentTypes;
+ }
+
+ // Get all non global-constant args, including tokens and functional args.
+ SmallVector<Type> getAllNonGlobalConstantArgumentTypes(
+ MLIRContext& context) const {
+ SmallVector<Type> types(getLeadingTokenOperands() +
+ getFunctionalArgumentTypes().size());
+ for (size_t i = 0; i < leadingTokenOperands; ++i)
+ types[i] = stablehlo::TokenType::get(&context);
+ for (auto [i, refinedType] : llvm::enumerate(getFunctionalArgumentTypes()))
+ types[i + leadingTokenOperands] = refinedType;
+ return types;
+ }
+
+ // Utilities
+ inline std::string toString() {
+ std::string buffer;
+ llvm::raw_string_ostream os(buffer);
+ os << "RefinementKey(" << func.getName()
+ << ", toks=" << leadingTokenOperands << ", dim_args=[";
+ llvm::interleaveComma(globalConstants, os);
+ os << "], fn_args=[";
+ llvm::interleaveComma(functionalArgumentTypes, os);
+ os << "])";
+ return buffer;
+ }
+
+ private:
+ static int64_t countLeadingTokenOperands(func::CallOp callOp) {
+ int64_t nrLeadingTokenOperands = 0;
+ for (auto operand : callOp.getOperands()) {
+ if (!isa<TokenType>(operand.getType())) break;
+ nrLeadingTokenOperands++;
+ }
+ return nrLeadingTokenOperands;
+ }
+
+ // Global-constant arguments follow the token arguments and are scalar
+ // integer constants. These represent the known values of symbolic shape
+ // sizes, e.g. tensor<Axf32> : A = constant(5).
+ static SmallVector<APSInt> getGlobalConstants(func::CallOp callOp,
+ int64_t leadingTokenOperands) {
+ SmallVector<APSInt> globalConstants;
+ auto operands = callOp.getOperands();
+ for (size_t i = leadingTokenOperands; i < operands.size(); ++i) {
+ auto operandType = dyn_cast<RankedTensorType>(operands[i].getType());
+ if (!operandType || operandType.getRank() != 0 ||
+ !operandType.getElementType().isInteger())
+ break;
+
+ SmallVector<APSInt> operandInt;
+ if (failed(hlo::matchInts(operands[i], operandInt))) break;
+ globalConstants.push_back(operandInt[0]);
+ }
+ return globalConstants;
+ }
+
+ // Functional operands are the arguments that are not global-constant
+ // arguments. These are the values that will remain after symbolic shape
+ // refinement.
+ static SmallVector<Type> getFunctionalArgumentTypes(
+ func::CallOp callOp, int64_t leadingTokenOperands,
+ int64_t globalConstantsSize) {
+ SmallVector<Type> functionalArgumentTypes;
+ auto operands = callOp.getOperands();
+ for (size_t i = leadingTokenOperands + globalConstantsSize;
+ i < operands.size(); ++i) {
+ functionalArgumentTypes.push_back(operands[i].getType());
+ }
+ return functionalArgumentTypes;
+ }
+
+ private:
+ func::FuncOp func;
+ int64_t leadingTokenOperands;
+ SmallVector<APSInt> globalConstants;
+ SmallVector<Type> functionalArgumentTypes;
+};
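+
+// Example (illustrative): for a call site such as
+// %r = call @f(%tok, %c4, %x) : (!stablehlo.token, tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+// where %c4 is a scalar integer constant 4, `fromCallOp` yields
+// RefinementKey(f, toks=1, dim_args=[4], fn_args=[tensor<?xf32>]).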
+
+// Per-module state for shape refinement.
+// Each entry's key is <FuncOp, SmallVector<APSInt>, SmallVector<Type>>,
+// which corresponds to <func, sym_int_values, arg_types>.
+class RefineShapeState {
+ public:
+ enum class RefinementState {
+ NOT_ALREADY_REFINED,
+ ALREADY_REFINED,
+ };
+
+ // Validates that we are not attempting to refine a function with a different
+ // context than previously, and are not attempting recursive refinement.
+ // Returns failure() if validation fails. On success, returns a refinement
+ // state that specifies whether the function has already been refined.
+ FailureOr<RefinementState> validateFunctionRefinement(RefinementKey key) {
+ func::FuncOp func = key.getFunc();
+ StringRef funcName = func.getName();
+
+ auto found = refinementContexts.find(func);
+ if (found == refinementContexts.end())
+ return RefinementState::NOT_ALREADY_REFINED;
+ RefinementKey prevKey = found->second;
+
+ // Since we refine until fixed point, we will refine a call to a function
+ // both for the original function and for the refined one. In the latter
+ // case, we should have empty globalConstants but everything else the
+ // same.
+ if (!key.getGlobalConstants().empty() &&
+ prevKey.getGlobalConstants() != key.getGlobalConstants())
+ return emitDifferentRefinementContextError(key.getFunc(), key, prevKey);
+
+ // Check that all non-global-constant arguments are the same.
+ // Must compare all non-global-constant types, since tokens may become
+ // leading:
+ // Refine iter1: `token, dim, token, arg` : 1 leading token
+ // Refine iter2: `token, token, arg` : 2 leading tokens
+ MLIRContext& context = *func.getContext();
+ if (key.getAllNonGlobalConstantArgumentTypes(context) !=
+ prevKey.getAllNonGlobalConstantArgumentTypes(context))
+ return emitDifferentRefinementContextError(key.getFunc(), key, prevKey);
+
+ // Don't allow recursive refinement.
+ if (llvm::is_contained(functionsBeingRefined, funcName))
+ return func.emitOpError()
+ << "Function " << funcName << " is being refined recursively\n";
+
+ return RefinementState::ALREADY_REFINED;
+ }
+
+ // Updates the state to signal the starting of a function refinement.
+ // Callers must call `finishFunctionRefinement` when done.
+ [[nodiscard]] auto createScopedFunctionRefinement(RefinementKey& key) {
+ func::FuncOp func = key.getFunc();
+ auto funcName = func.getName();
+ functionsBeingRefined.push_back(funcName);
+ refinementContexts.try_emplace(func, key);
+ // Return a cleanup function that pops the function from the stack when it
+ // goes out of scope. The cleanup may only capture values that outlive it;
+ // here, `this` and `key` are safe.
+ return llvm::make_scope_exit([this, &key]() {
+ if (key.getFunc().getName() != functionsBeingRefined.back())
+ llvm::report_fatal_error(
+ "Stack mismatch in createScopedFunctionRefinement");
+ functionsBeingRefined.pop_back();
+ });
+ }
+
+ private:
+ // Maps refined functions to the refinement context: the values of dimension
+ // arguments and the types of non-global-constant arguments. A function is
+ // added here when we start refining it.
+ DenseMap<func::FuncOp, RefinementKey> refinementContexts;
+
+ // A stack of functions that are in the process of being refined, the current
+ // one is last.
+ SmallVector<llvm::StringRef> functionsBeingRefined;
+
+ LogicalResult emitDifferentRefinementContextError(func::FuncOp func,
+ RefinementKey key,
+ RefinementKey prevKey) {
+ return func.emitOpError()
+ << "refined with invompatible refinement keys:" << "\n curr="
+ << key.toString() << "\n prev=" << prevKey.toString();
+ }
+};
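+
+// Example (illustrative): these guards reject a self-recursive callee such as
+// func.func @f(%arg0: tensor<i32>) -> tensor<i32> {
+// %0 = call @f(%arg0) : (tensor<i32>) -> tensor<i32>
+// return %0 : tensor<i32>
+// }
+// with the error "Function f is being refined recursively".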
+
+// Forward declaration
+LogicalResult refineFunction(MLIRContext& context, RefineShapeState& state,
+ RefinementKey& key);
+
+// Checks if a function returns only constant values; if so, returns those
+// constants.
+std::optional<SmallVector<DenseIntElementsAttr>> isConstantFunction(
+ func::FuncOp func) {
+ LLVM_DEBUG(llvm::dbgs() << "check if " << func.getName()
+ << " is a constant function\n");
+ SmallVector<DenseIntElementsAttr> returnedConstants;
+ func::ReturnOp ret = *func.getOps<func::ReturnOp>().begin();
+ bool isConstant = llvm::all_of(ret->getOperands(), [&](auto returnVal) {
+ DenseIntElementsAttr attr;
+ Operation* returnOperandDef = returnVal.getDefiningOp();
+ if (returnOperandDef &&
+ matchPattern(returnOperandDef, m_Constant(&attr))) {
+ returnedConstants.push_back(attr);
+ return true;
+ }
+ return false;
+ });
+ if (isConstant) return returnedConstants;
+ return std::nullopt;
+}
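+
+// For example (illustrative), once refinement has folded its body, a callee
+// func.func @f(%arg0: tensor<i32>) -> tensor<i32> {
+// %0 = stablehlo.constant dense<8> : tensor<i32>
+// return %0 : tensor<i32>
+// }
+// is recognized as constant, and RefineCallOpPattern replaces each call to
+// it with a fresh dense<8> constant.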
+
// The patterns below implement shape refinement of individual ops.
// In a nutshell, they use the upstream type inference infrastructure and a
// StableHLO-specific extension to refine return types based on potentially
@@ -297,6 +529,54 @@
return refineReturnShape(rewriter, op, operandType.getShape());
}
+};
+
+struct RefineCallOpPattern : public OpRewritePattern<func::CallOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ RefineCallOpPattern(MLIRContext* context, RefineShapeState& state)
+ : OpRewritePattern<func::CallOp>(context), state(state) {}
+
+ LogicalResult matchAndRewrite(func::CallOp op,
+ PatternRewriter& rewriter) const override {
+ auto refinementKey = RefinementKey::fromCallOp(op);
+ if (failed(refinementKey)) return failure();
+ if (failed(refineFunction(*rewriter.getContext(), state, *refinementKey)))
+ return failure();
+
+ // Is the callee a constant function in this refinement context?
+ auto callee = refinementKey->getFunc();
+ std::optional<SmallVector<DenseIntElementsAttr>> constantAttrs =
+ isConstantFunction(callee);
+ if (constantAttrs.has_value()) {
+ SmallVector<Value> constants;
+ for (auto constAttr : constantAttrs.value()) {
+ constants.push_back(
+ rewriter.create<ConstantOp>(op.getLoc(), constAttr));
+ }
+ rewriter.replaceOp(op, constants);
+ return success();
+ }
+ if (!refinementKey->getGlobalConstants().empty()) {
+ // Drop the global-constant arguments, but only if necessary, or else we
+ // will end up trying to refine the new CallOp forever.
+ SmallVector<Value> newOperands;
+ auto leadingTokenOperands =
+ op.getOperands().take_front(refinementKey->getLeadingTokenOperands());
+ auto functionalOperands = op.getOperands().take_back(
+ refinementKey->getFunctionalArgumentTypes().size());
+ newOperands.append(leadingTokenOperands.begin(),
+ leadingTokenOperands.end());
+ newOperands.append(functionalOperands.begin(), functionalOperands.end());
+ op = rewriter.replaceOpWithNewOp<func::CallOp>(
+ op, op.getResultTypes(), callee.getSymName(), newOperands);
+ LLVM_DEBUG(llvm::dbgs() << "Replaced call with " << op << "\n");
+ }
+ return refineReturnTypes(rewriter, op, callee.getResultTypes());
+ }
+
+ private:
+ RefineShapeState& state;
};
struct RefineConvertOpPattern : public OpRewritePattern<ConvertOp> {
@@ -718,49 +998,116 @@
}
};
+LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
+ RefineShapeState& state) {
+ MLIRContext* context = func.getContext();
+ RewritePatternSet patterns(context);
+ GreedyRewriteConfig config;
+
+ // The algorithm behind this pass consists of a single traversal of each
+ // function being refined.
+ // TODO(#1048): Find out why .maxIterations = 1 no longer works.
+ // There have been recent refactors to applyPatternsAndFoldGreedily
+ // upstream, and that might be the reason.
+ config.useTopDownTraversal = true;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive;
+ config.maxIterations = 2;
+ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
+ config.strictMode = GreedyRewriteStrictness::AnyOp;
+
+ populateStablehloRefineShapesPatterns(&patterns, context);
+ patterns.add<RefineCallOpPattern>(context, state);
+
+ // The folding patterns implement partial evaluation of shape computations
+ // which is a critical part of implementing type refinement for ops like
+ // dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
+ // depends on the value of their shape operands.
+ populateStablehloShapeFolderPatterns(&patterns, context);
+
+ if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config)))
+ func.emitError("Failed to converge StablehloRefineShapes in ")
+ << config.maxIterations << " iterations";
+
+ return success();
+}
+
+LogicalResult refineFunction(MLIRContext& context, RefineShapeState& state,
+ RefinementKey& key) {
+ LLVM_DEBUG(llvm::dbgs() << "Refining: " << key.toString() << "\n");
+ auto refinementState = state.validateFunctionRefinement(key);
+ if (failed(refinementState)) return failure();
+
+ auto func = key.getFunc();
+ if (*refinementState == RefineShapeState::RefinementState::ALREADY_REFINED) {
+ LLVM_DEBUG(llvm::dbgs() << "Function " << func.getName()
+ << " already refined, skipping\n");
+ return success();
+ }
+
+ auto scopedCleanup = state.createScopedFunctionRefinement(key);
+
+ // StableHLO functions must have exactly one block.
+ if (!func.getRegion().hasOneBlock())
+ return func.emitOpError() << "must have exactly one block";
+
+ // Replace the global-constant arguments with their values and drop args.
+ // Wrap non-global-constant arguments with bitcast_convert.
+ OpBuilder builder(func.getRegion());
+ builder.setInsertionPointToStart(&func.getRegion().front());
+ int64_t leadingTokenOperands = key.getLeadingTokenOperands();
+
+ for (auto [i, dimValue] : llvm::enumerate(key.getGlobalConstants())) {
+ int64_t operandIdx = leadingTokenOperands + i;
+ BlockArgument arg = func.getArgument(operandIdx);
+ Type argType = arg.getType();
+ ShapedType argShapedType = dyn_cast<ShapedType>(argType);
+
+ if (!argShapedType)
+ return func.emitOpError()
+ << "expected global constant argument to be shaped";
+
+ auto replacementOp = builder.create<stablehlo::ConstantOp>(
+ arg.getLoc(), argType, DenseElementsAttr::get(argShapedType, dimValue));
+ arg.replaceAllUsesWith(replacementOp);
+ }
+ BitVector argIndices(func.getNumArguments());
+ size_t firstFunctionalArgument =
+ leadingTokenOperands + key.getGlobalConstants().size();
+ argIndices.set(leadingTokenOperands, firstFunctionalArgument);
+ func.eraseArguments(argIndices);
+
+ // Refine the remaining argument types and wrap them with
+ // shape_refinement_operand_wrapper custom calls.
+ SmallVector<Type> refinedTypes =
+ key.getAllNonGlobalConstantArgumentTypes(context);
+ if (failed(refineArguments(func, refinedTypes))) return failure();
+ LLVM_DEBUG(llvm::dbgs() << "Refined function type for " << func.getName()
+ << ": " << func.getFunctionType() << "\n");
+
+ // Now iterate into the function body and apply refinement patterns.
+ if (failed(applyShapeRefinementPatterns(func, state))) return failure();
+
+ LLVM_DEBUG(llvm::dbgs() << "refineFunction " << func.getName()
+ << ": end with type " << func.getFunctionType()
+ << "\n");
+ return success();
+}
+
struct StablehloRefineShapesPass
: public impl::StablehloRefineShapesPassBase<StablehloRefineShapesPass> {
using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase;
- LogicalResult initialize(MLIRContext* context) override {
- // The algorithm behind this pass consists of a single traversal of the
- // function. This is sufficient because we only support one function per
- // program at the moment.
- // TODO(#1048): Find out why .maxIterations = 1 no longer works.
- // There have been recent refactors to applyPatternsAndFoldGreedily
- // upstream, and that might be the reason.
- config.useTopDownTraversal = true;
- config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive;
- config.maxIterations = 2;
- config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
- config.strictMode = GreedyRewriteStrictness::AnyOp;
-
- RewritePatternSet patterns_(context);
- populateStablehloRefineShapesPatterns(&patterns_, context);
-
- // The folding patterns implement partial evaluation of shape computations
- // which is a critical part of implementing type refinement for ops like
- // dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
- // depends on the value of their shape operands.
- populateStablehloShapeFolderPatterns(&patterns_, context);
- patterns = std::move(patterns_);
-
- return success();
- }
-
void runOnOperation() override {
auto func = getStablehloRefineShapesTarget(getOperation());
if (!func) return signalPassFailure();
- if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
- func.emitError("Failed to converge StablehloRefineShapes in ")
- << config.maxIterations << " iterations";
- }
- }
-
- private:
- FrozenRewritePatternSet patterns;
- GreedyRewriteConfig config;
+ // Start with empty state, and no dim args / token args.
+ MLIRContext* context = func.getContext();
+ RefineShapeState state;
+ RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
+ if (failed(refineFunction(*context, state, key)))
+ return signalPassFailure();
+ }
};
} // namespace
@@ -807,6 +1154,7 @@
MLIRContext* context) {
patterns->add<RefineAllGatherOpPattern>(context);
patterns->add<RefineBitcastConvertOpPattern>(context);
+ // RefineCallOpPattern is not added here: it requires a RefineShapeState,
+ // so it is added separately in applyShapeRefinementPatterns.
patterns->add<RefineConvertOpPattern>(context);
patterns->add<RefineConvolutionOpPattern>(context);
patterns->add<RefineCustomCallOpPattern>(context);
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h
@@ -34,6 +34,11 @@
// Returns a nullptr and emits appropriate errors if such a function cannot
// be obtained from the module.
func::FuncOp getStablehloRefineShapesTarget(ModuleOp module);
+
+// Refine the arguments of the given function using the given types.
+// Wraps all operands in a custom call to keep the IR valid during refinement.
+// %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %shape)
+LogicalResult refineArguments(func::FuncOp func, TypeRange refinedTypes);
// Refines the values using the given types.
// Tricky implementation details: