diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
@@ -3,14 +3,14 @@
func.func @error_illformed(%arg0: tensor<3xf32>, %arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.abs %arg0 : (tensor<3xf32>) -> tensor<?xf32>
%1 = stablehlo.abs %arg1 : (tensor<4xf32>) -> tensor<?xf32>
- // expected-error@+1{{requires the same shape for all operands and results}}
+ // expected-error@+1{{'stablehlo.add' op requires the same shape for all operands and results}}
%2 = stablehlo.add %0, %1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func.return %2 : tensor<?xf32>
}

// -----

-// expected-error@+1{{must have exactly one block}}
+// expected-error@+1{{'func.func' op must have exactly one block}}
func.func @error_too_many_blocks(%arg0: tensor<f32>) -> tensor<f32> {
cf.br ^bb1(%arg0 : tensor<f32>)
^bb1(%arg1 : tensor<f32>):
@@ -49,6 +49,7 @@

// -----

+// CHECK-LABEL: func @error_unsupported_operation
func.func @error_unsupported_operation(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> index {
// CHECK: stablehlo.add{{.*}} -> tensor<?xf32>
%0 = stablehlo.add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<?xf32>
@@ -596,10 +597,288 @@
// -----

// CHECK-LABEL: func @refine_bitcast_convert_same_bitwidth
-func.func @refine_bitcast_convert_same_bitwidth(%arg0 : tensor<4xf32>) -> tensor<?xi32> {
- // CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<4xi32>
- %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<?xi32>
- func.return %0 : tensor<?xi32>
+func.func @refine_bitcast_convert_same_bitwidth() -> tensor<?x?x0xf32> {
+ %0 = stablehlo.constant dense<[3, 5, 0]> : tensor<3xi32>
+ %21 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<3xi32>) -> tensor<?x?x0xui32>
+ // CHECK: stablehlo.bitcast_convert{{.*}} -> tensor<3x5x0xf32>
+ %48 = stablehlo.bitcast_convert %21 : (tensor<?x?x0xui32>) -> tensor<?x?x0xf32>
+ return %48 : tensor<?x?x0xf32>
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call
+module @refine_call {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %1 = stablehlo.constant dense<4> : tensor<i32>
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ %2 = call @refine_call_callee(%1, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // CHECK: refine_call_callee(%arg0: tensor<4xf32>) -> tensor<4xf32>
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: stablehlo.constant dense<4>
+ %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
+ %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
+ return %1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_arguments
+module @refine_call_dimension_arguments {
+ func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT:%.*]] = call @callee
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %1 = call @callee(%0, %0, %arg0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // %arg0 and %arg1 are dimension arguments
+ // CHECK: @callee([[ARG0:%.*]]: tensor<i32>) -> tensor<i32>
+ func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
+ // CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg1: tensor<i32>
+ %1 = stablehlo.add %0, %arg2: tensor<i32>
+ return %1 : tensor<i32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_prefix_token_and_dimension_arguments
+module @refine_call_prefix_token_and_dimension_arguments {
+ func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT:%.*]] = call @callee
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %token = stablehlo.create_token : !stablehlo.token
+ %1 = call @callee(%token, %0, %0, %arg0) : (!stablehlo.token, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // %arg0 and %arg1 are dimension arguments
+ // CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
+ func.func private @callee(%arg_token: !stablehlo.token, %arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
+ // CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg1: tensor<i32>
+ %1 = stablehlo.add %0, %arg2: tensor<i32>
+ return %1 : tensor<i32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_arguments_followed_by_token
+module @refine_call_dimension_arguments_followed_by_token {
+ func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT:%.*]] = call @callee
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %token = stablehlo.create_token : !stablehlo.token
+ %1 = call @callee(%0, %0, %token, %arg0) : (tensor<i32>, tensor<i32>, !stablehlo.token, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // %arg0 and %arg1 are dimension arguments
+ // CHECK: @callee([[ARG_TOKEN:%.*]]: !stablehlo.token, [[ARG0:%.*]]: tensor<i32>
+ func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg_token: !stablehlo.token, %arg2: tensor<i32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<6>
+ // CHECK: [[RESULT1:%.*]] = stablehlo.add [[RESULT0]], [[ARG0]]
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg1: tensor<i32>
+ %1 = stablehlo.add %0, %arg2: tensor<i32>
+ return %1 : tensor<i32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_multiple_call_with_same_context
+module @refine_multiple_call_with_same_context {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ %2 = call @refine_call_callee(%arg0_new, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // CHECK: refine_call_callee{{.*}}-> tensor<4xf32>
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_multiple_call_constant_function
+module @refine_multiple_call_constant_function {
+ func.func @main(%arg0: tensor<5xf32>) -> tensor<i32> {
+ // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<16>
+ // CHECK: return [[RESULT0]]
+ %0 = stablehlo.constant dense<4> : tensor<i32>
+ %1 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
+ %2 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
+ %3 = stablehlo.add %1, %2: tensor<i32>
+ return %3 : tensor<i32>
+ }
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> tensor<i32> {
+ // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<8>
+ // CHECK: return [[RESULT1]]
+ %0 = stablehlo.add %arg0, %arg0: tensor<i32>
+ return %0 : tensor<i32>
+ }
+}
+
+// -----
+
+module @refine_call_multiple_with_different_number_dimension_arguments {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ // Ensure that the first argument is not a constant at the second call site
+ %arg0_different_f32 = stablehlo.bitcast_convert %arg0_new : (tensor<i32>) -> tensor<f32>
+ %arg0_different_i32 = stablehlo.bitcast_convert %arg0_different_f32 : (tensor<f32>) -> tensor<i32>
+ // expected-error@+1{{incorrect number of operands for callee}}
+ %2 = call @refine_call_callee(%arg0_different_i32, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+module @refine_call_multiple_different_dimension_arguments {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ %arg0_different = stablehlo.add %arg0_new, %arg0_new : tensor<i32>
+ // expected-error@+1{{incorrect number of operands for callee}}
+ %2 = call @refine_call_callee(%arg0_different, %1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+ }
+ // expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+module @refine_call_multiple_different_non_dimension_arguments {
+ func.func @main(%arg1: tensor<4xf32>) -> tensor<?xf32> {
+ %0 = stablehlo.bitcast_convert %arg1 : (tensor<4xf32>) -> tensor<?xf32>
+ %arg0_new = "stablehlo.get_dimension_size"(%0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
+ %1 = call @refine_call_callee(%arg0_new, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ %2 = stablehlo.constant dense<[1., 2.]> : tensor<2xf32>
+ %3 = stablehlo.concatenate %1, %2, dim = 0 : (tensor<?xf32>, tensor<2xf32>) -> tensor<?xf32>
+ // expected-error@+1{{incorrect number of operands for callee}}
+ %4 = call @refine_call_callee(%arg0_new, %3) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+ return %4 : tensor<?xf32>
+ }
+ // expected-error@+1{{'func.func' op refined with incompatible refinement keys}}
+ func.func @refine_call_callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ return %arg1 : tensor<?xf32>
+ }
+}
+
+// -----
+
+module @refine_call_recursive {
+ func.func @main() -> tensor<i32> {
+ %0 = stablehlo.constant dense<3> : tensor<i32>
+ %1 = call @refine_call_callee(%0) : (tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+ }
+ // expected-error@+1{{Function refine_call_callee is being refined recursively}}
+ func.func @refine_call_callee(%arg0: tensor<i32>) -> tensor<i32> {
+ // expected-error@+1{{incorrect number of operands}}
+ %0 = call @refine_call_callee(%arg0) : (tensor<i32>) -> tensor<i32>
+ return %0 : tensor<i32>
+ }
+}
+
+// -----
+
+module @refine_call_main_argument_unranked {
+ // CHECK-LABEL: func.func public @main(%arg0: tensor<*xi32>) -> tensor<*xi32>
+ func.func public @main(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ %2 = call @callee(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+ return %2 : tensor<*xi32>
+ }
+ func.func private @callee(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ return %arg0 : tensor<*xi32>
+ }
+}
+
+// -----
+
+module @refine_call_main_argument_dynamic_shape {
+ // CHECK: func.func public @main(%arg0: tensor<?xi32>) -> tensor<?xi32>
+ func.func public @main(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ %2 = call @callee(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+ return %2 : tensor<?xi32>
+ }
+ func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ return %arg0 : tensor<?xi32>
+ }
+}
+
+// -----
+
+module @refine_call_callee_argument_dynamic_shape {
+ // CHECK: func.func public @main(%arg0: tensor<1xi64>) -> tensor<?xi32>
+ func.func public @main(%arg0: tensor<1xi64>) -> tensor<?xi32> {
+ %1 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xi64>) -> tensor<?xi32>
+ %2 = call @callee(%1) : (tensor<?xi32>) -> tensor<?xi32>
+ return %2 : tensor<?xi32>
+ }
+ func.func private @callee(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ return %arg0 : tensor<?xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_argument_non_scalar
+// The non-scalar constant is not folded into the callee
+module @refine_call_dimension_argument_non_scalar {
+ func.func public @main() -> tensor<4xi32> {
+ // CHECK: dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %1 = call @callee(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ return %1 : tensor<4xi32>
+ }
+ func.func private @callee(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: return %arg0 : tensor<4xi32>
+ return %arg0 : tensor<4xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @refine_call_dimension_argument_not_integer
+module @refine_call_dimension_argument_not_integer {
+ func.func public @main() -> tensor<f32> {
+ %0 = stablehlo.constant dense<3.> : tensor<f32>
+ // CHECK: call @callee({{.*}}) : (tensor<f32>) -> tensor<f32>
+ %2 = call @callee(%0) : (tensor<f32>) -> tensor<f32>
+ return %2 : tensor<f32>
+ }
+ func.func private @callee(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+ }
}

// -----
@@ -656,6 +935,17 @@

// -----

+// CHECK-LABEL: @refine_custom_call_operand_wrapper_unranked
+func.func @refine_custom_call_operand_wrapper_unranked(%arg0: tensor<4xi32>) -> tensor<*xi32> {
+ // CHECK-NOT: stablehlo.shape_refinement_operand_wrapper
+ %0 = stablehlo.constant dense<[4]> : tensor<1xi64>
+ %1 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %0) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<*xi32>
+ // CHECK: return %arg0 : tensor<4xi32>
+ func.return %1 : tensor<*xi32>
+}
+
+// -----
+
// CHECK-LABEL: @refine_dot_general
func.func @refine_dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<?x?x?xf32> {
// CHECK: stablehlo.dot_general{{.*}} -> tensor<2x4x5xf32>
@@ -755,6 +1045,8 @@
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
func.return %1 : tensor<?x?xf32>
}
+
+

// -----

@@ -908,6 +1200,7 @@
// -----

// TODO: Implement support for these ops.
+// * dynamic_conv (#867).
// * dynamic_fft (#1366).
// * dynamic_reduce_window (#1258).
// * dynamic_rng_bit_generator (#1344).
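Taken together, the new tests above pin down the call-refinement contract. As a rough before/after sketch of the rewrite they check for (illustrative only; the module and function names here are hypothetical and not part of the test suite), a scalar dimension operand is folded into the callee as a constant, dropped from the call, and the remaining types become static:

```mlir
// Before: the callee takes a dimension operand plus a dynamic tensor.
module @sketch_before {
  func.func @main(%arg0: tensor<4xf32>) -> tensor<?xf32> {
    %size = stablehlo.constant dense<4> : tensor<i32>
    %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<?xf32>
    %1 = call @callee(%size, %0) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
    return %1 : tensor<?xf32>
  }
  func.func private @callee(%arg0: tensor<i32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    return %arg1 : tensor<?xf32>
  }
}
// After: the dimension operand is materialized inside the callee and
// erased from the signature; all shapes refine to static.
module @sketch_after {
  func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %0 = stablehlo.bitcast_convert %arg0 : (tensor<4xf32>) -> tensor<4xf32>
    %1 = call @callee(%0) : (tensor<4xf32>) -> tensor<4xf32>
    return %1 : tensor<4xf32>
  }
  func.func private @callee(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    return %arg0 : tensor<4xf32>
  }
}
```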
diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
--- stablehlo/stablehlo/transforms/Passes.td
+++ stablehlo/stablehlo/transforms/Passes.td
@@ -356,6 +356,25 @@

%1 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
```
+
+ Modules valid for shape refinement must have the following properties:
+
+ * All the dynamic shapes depend only on the input shapes (no shape
+ dependency on the input array contents). We refer to the operations that
+ depend transitively only on the input shapes (e.g., as given by
+ `stablehlo.get_dimension_size`) or on global constants like the resolved
+ values of symbolic integers (e.g. tensor<Axf32> : A = 5) as `dimension`
+ operations. All dimension values can be resolved to constants through
+ inter-procedural constant folding.
+ * Intermediate functions may take a number of token arguments (of type
+ !stablehlo.token) at the start of the argument list, followed by some
+ global constant arguments, which are constant integer scalars such as the
+ resolved values of symbolic integers (e.g. tensor<Axf32> : A = 5).
+ * Some intermediate functions may return computations on global constants,
+ e.g. `floordiv` on symint values. Such functions are recognized by
+ returning only constant values after refinement; they are inlined.
+ * All calls to a single function resolve to the same argument shapes, and no
+ recursive / co-recursive function calls are made.
}];
}

@@ -375,4 +394,5 @@
Option<"targetVersionOption", "target", "std::string", "",
"The target version. Must be a version of the form #.#.# .">,
];
-}
+ let dependentDialects = ["mlir::vhlo::VhloDialect"];
+}
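To make the four properties documented in the Passes.td hunk above concrete, a minimal module that satisfies all of them might look like the following (a hedged sketch with hypothetical names, not taken from the test suite): a token prefix, one scalar-integer global-constant argument, then the functional arguments, with every call site passing the same shapes and no recursion:

```mlir
module @valid_for_refinement {
  func.func @main(%arg0: tensor<8xf32>) -> tensor<?xf32> {
    %token = stablehlo.create_token : !stablehlo.token
    %dim = stablehlo.constant dense<8> : tensor<i32>
    %0 = stablehlo.bitcast_convert %arg0 : (tensor<8xf32>) -> tensor<?xf32>
    // Tokens first, then the global-constant (dimension) argument,
    // then functional arguments.
    %1 = call @intermediate(%token, %dim, %0) : (!stablehlo.token, tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
    return %1 : tensor<?xf32>
  }
  func.func private @intermediate(%t: !stablehlo.token, %dim: tensor<i32>, %x: tensor<?xf32>) -> tensor<?xf32> {
    return %x : tensor<?xf32>
  }
}
```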
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp b/stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineArguments.cpp
@@ -18,6 +18,7 @@

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
@@ -201,15 +202,7 @@
return signalPassFailure();
}

- // Verify that refinements are valid
- if (failed(validateRefinedTypes(func, refinedTypes)))
- return signalPassFailure();
-
- // Wrap refined operands in operand wrapper to keep IR valid for refinement
- wrapRefinedOperands(func, refinedTypes);
-
- // Actually update main's input types.
- refineOperandsAndUpdateFunctionSignature(func, refinedTypes);
+ if (failed(refineArguments(func, refinedTypes))) return signalPassFailure();
}

private:
@@ -218,6 +211,19 @@

} // namespace

+LogicalResult refineArguments(func::FuncOp func, TypeRange refinedTypes) {
+ // Verify that refinements are valid
+ if (failed(validateRefinedTypes(func, refinedTypes))) return failure();
+
+ // Wrap refined operands in operand wrapper to keep IR valid for refinement
+ wrapRefinedOperands(func, refinedTypes);
+
+ // Actually update main's input types.
+ refineOperandsAndUpdateFunctionSignature(func, refinedTypes);
+
+ return success();
+}
+
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
TypeRange refinedTypes) {
return std::make_unique<StablehloRefineArgumentsPass>(refinedTypes);
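The `refineArguments` entry point extracted here bundles validation, operand wrapping, and the signature update. The intermediate IR it produces is the operand-wrapper form exercised by the tests above; as a sketch of that transient state (illustrative only, not a stable interface), refining main's argument from tensor<?xf32> to tensor<4xf32> passes through:

```mlir
// After wrapRefinedOperands and the signature update, before the wrapper
// is folded away by the refinement patterns:
func.func @main(%arg0: tensor<4xf32>) -> tensor<?xf32> {
  %shape = stablehlo.constant dense<[4]> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %shape) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<1xi64>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```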
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -14,15 +14,20 @@

#include "stablehlo/transforms/StablehloRefineShapes.h"

+#include <cstddef>
#include <cstdint>
+#include <tuple>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -49,6 +54,8 @@
#include "stablehlo/dialect/TypeInference.h"
#include "stablehlo/transforms/Passes.h"

+#define DEBUG_TYPE "stablehlo-refine-shapes"
+
namespace mlir {
namespace stablehlo {

@@ -63,10 +70,10 @@
<< values.size() << " types, got " << types.size();
});

- // Check whether `types` contain any new information with respect to existing
- // return types. Even if just a single dimension size out of an entire tensor
- // type got updated, using `inferMostSpecificType` ensures that we don't
- // miss that.
+ // Check whether `types` contain any new information with respect to
+ // existing return types. Even if just a single dimension size out of an
+ // entire tensor type got updated, using `inferMostSpecificType` ensures
+ // that we don't miss that.
bool needsRefinement = false;
SmallVector<Type> refinedTypes;
for (auto it : llvm::zip(values.getTypes(), types)) {
@@ -76,11 +83,12 @@
auto refinement = std::get<1>(it);
auto refinedType = hlo::inferMostSpecificType(
/*location=*/{}, {currentType, refinement});
- if (failed(refinedType))
+ if (failed(refinedType)) {
return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) {
diag << "inferMostSpecificType failed for " << currentType << " and "
<< refinement;
});
+ }
refinedTypes.push_back(*refinedType);
needsRefinement |= (currentType != *refinedType);
}
@@ -106,11 +114,8 @@

// Simply changing operand type of `func.return` won't work because
// that won't update the FunctionType of the enclosing `func.func`.
- // Nonetheless, we still want to support these ops because they are widely
- // used in StableHLO programs (although the plan of record is to replace
- // `func.return` ops in StableHLO programs with `stablehlo.return`:
- // https://github.com/openxla/stablehlo/issues/425).
if (isa<func::ReturnOp>(user)) continue;
+ if (isa<func::CallOp>(user)) continue;

// Unlike in TensorFlow's type inference pass, here we work only with
// allowlisted ops to focus our support on well-defined semantics of
@@ -244,6 +249,233 @@

namespace {

+class RefinementKey {
+ public:
+ RefinementKey(func::FuncOp func, int64_t leadingTokenOperands,
+ SmallVector<APSInt> const& globalConstants,
+ SmallVector<Type> const& functionalArgumentTypes)
+ : func(func),
+ leadingTokenOperands(leadingTokenOperands),
+ globalConstants(globalConstants),
+ functionalArgumentTypes(functionalArgumentTypes) {}
+
+ static FailureOr<RefinementKey> fromCallOp(func::CallOp callOp) {
+ LLVM_DEBUG(llvm::dbgs() << "RefinementKey::fromCallOp: "
+ << callOp.getCalleeType() << "\n");
+ int64_t leadingTokenOperands = countLeadingTokenOperands(callOp);
+ SmallVector<APSInt> globalConstants =
+ getGlobalConstants(callOp, leadingTokenOperands);
+ SmallVector<Type> functionalArgumentTypes = getFunctionalArgumentTypes(
+ callOp, leadingTokenOperands, globalConstants.size());
+
+ FlatSymbolRefAttr calleeName = callOp.getCalleeAttr();
+ const SymbolTable symbolTable(callOp->getParentOfType<ModuleOp>());
+ auto callee = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
+ callOp, calleeName.getAttr());
+ if (!callee) return callOp.emitOpError() << "cannot resolve function call";
+ return RefinementKey(callee, leadingTokenOperands, globalConstants,
+ functionalArgumentTypes);
+ }
+
+ // Getters
+ func::FuncOp getFunc() const { return func; }
+ int64_t getLeadingTokenOperands() const { return leadingTokenOperands; }
+ SmallVector<APSInt> const& getGlobalConstants() const {
+ return globalConstants;
+ }
+ SmallVector<Type> const& getFunctionalArgumentTypes() const {
+ return functionalArgumentTypes;
+ }
+
+ // Get all non global-constant args, including tokens and functional args.
+ SmallVector<Type> getAllNonGlobalConstantArgumentTypes(
+ MLIRContext& context) const {
+ SmallVector<Type> types(getLeadingTokenOperands() +
+ getFunctionalArgumentTypes().size());
+ for (size_t i = 0; i < leadingTokenOperands; ++i)
+ types[i] = stablehlo::TokenType::get(&context);
+ for (auto [i, refinedType] : llvm::enumerate(getFunctionalArgumentTypes()))
+ types[i + leadingTokenOperands] = refinedType;
+ return types;
+ }
+
+ // Utilities
+ inline std::string toString() {
+ std::string buffer;
+ llvm::raw_string_ostream os(buffer);
+ os << "RefinementKey(" << func.getName()
+ << ", toks=" << leadingTokenOperands << ", dim_args=[";
+ llvm::interleaveComma(globalConstants, os);
+ os << "], fn_args=[";
+ llvm::interleaveComma(functionalArgumentTypes, os);
+ os << "])";
+ return buffer;
+ }
+
+ private:
+ static int64_t countLeadingTokenOperands(func::CallOp callOp) {
+ int64_t nrLeadingTokenOperands = 0;
+ for (auto operand : callOp.getOperands()) {
+ if (!isa<TokenType>(operand.getType())) break;
+ nrLeadingTokenOperands++;
+ }
+ return nrLeadingTokenOperands;
+ }
+
+ // global-constant arguments follow token args, and are scalar integer
+ // constants. These represent the known values of symbolic shape sizes, e.g.
+ // tensor<Axf32> : A = constant(5)
+ static SmallVector<APSInt> getGlobalConstants(func::CallOp callOp,
+ int64_t leadingTokenOperands) {
+ SmallVector<APSInt> globalConstants;
+ auto operands = callOp.getOperands();
+ for (size_t i = leadingTokenOperands; i < operands.size(); ++i) {
+ auto operandType = dyn_cast<RankedTensorType>(operands[i].getType());
+ if (!operandType || operandType.getRank() != 0 ||
+ !operandType.getElementType().isInteger())
+ break;
+
+ SmallVector<APSInt> operand_int;
+ if (failed(hlo::matchInts(operands[i], operand_int))) break;
+ globalConstants.push_back(operand_int[0]);
+ }
+ return globalConstants;
+ }
+
+ // Functional operands are the arguments that are not global-constant
+ // arguments. These are the values that will remain after symbolic shape
+ // refinement.
+ static SmallVector<Type> getFunctionalArgumentTypes(
+ func::CallOp callOp, int64_t leadingTokenOperands,
+ int64_t globalConstantsSize) {
+ SmallVector<Type> functionalArgumentTypes;
+ auto operands = callOp.getOperands();
+ for (size_t i = leadingTokenOperands + globalConstantsSize;
+ i < operands.size(); ++i) {
+ functionalArgumentTypes.push_back(operands[i].getType());
+ }
+ return functionalArgumentTypes;
+ }
+
+ private:
+ func::FuncOp func;
+ int64_t leadingTokenOperands;
+ SmallVector<APSInt> globalConstants;
+ SmallVector<Type> functionalArgumentTypes;
+};
+
+// Per-module state for shape refinement.
+// An entry's key is <FuncOp, SmallVector<APSInt>, SmallVector<Type>>,
+// which corresponds to <func, sym_int_values, arg_types>.
+class RefineShapeState {
+ public:
+ enum class RefinementState {
+ NOT_ALREADY_REFINED,
+ ALREADY_REFINED,
+ };
+
+ // Validates that we are not attempting to refine a function with a different
+ // context than previously, and are not attempting recursive refinement.
+ // Returns failure() if validation fails. On success, returns a refinement
+ // state that specifies whether the function has already been refined.
+ FailureOr<RefinementState> validateFunctionRefinement(RefinementKey key) {
+ func::FuncOp func = key.getFunc();
+ StringRef funcName = func.getName();
+
+ auto found = refinementContexts.find(func);
+ if (found == refinementContexts.end())
+ return RefinementState::NOT_ALREADY_REFINED;
+ RefinementKey prevKey = found->second;
+
+ // Since we refine until fixed point, we will refine a call to a function
+ // both for the original function and for the refined one. In the latter
+ // case, we should have empty globalConstants but everything else the
+ // same.
+ if (!key.getGlobalConstants().empty() &&
+ prevKey.getGlobalConstants() != key.getGlobalConstants())
+ return emitDifferentRefinementContextError(key.getFunc(), key, prevKey);
+
+ // Check that all non-global-constant arguments are the same.
+ // Must compare all non-global-constant types, since tokens may become
+ // leading:
+ // Refine iter1: `token, dim, token, arg` : 1 leading token
+ // Refine iter2: `token, token, arg` : 2 leading tokens
+ MLIRContext& context = *func.getContext();
+ if (key.getAllNonGlobalConstantArgumentTypes(context) !=
+ prevKey.getAllNonGlobalConstantArgumentTypes(context))
+ return emitDifferentRefinementContextError(key.getFunc(), key, prevKey);
+
+ // Don't allow recursive refinement.
+ if (llvm::is_contained(functionsBeingRefined, funcName))
+ return func.emitOpError()
+ << "Function " << funcName << " is being refined recursively\n";
+
+ return RefinementState::ALREADY_REFINED;
+ }
+
+ // Updates the state to signal the start of a function refinement.
+ // The caller must keep the returned scope-exit cleanup alive until
+ // refinement of the function is done.
+ [[nodiscard]] auto createScopedFunctionRefinement(RefinementKey& key) {
+ func::FuncOp func = key.getFunc();
+ auto funcName = func.getName();
+ functionsBeingRefined.push_back(funcName);
+ refinementContexts.try_emplace(func, key);
+ // Return a cleanup function that will pop the function from the stack
+ // when it goes out of scope. This can only use values that will have the
+ // same lifetime as cleanup fn. In this case, `this` and `key` are safe.
+ return llvm::make_scope_exit([this, &key]() {
+ if (key.getFunc().getName() != functionsBeingRefined.back())
+ llvm::report_fatal_error(
+ "Stack mismatch in createScopedFunctionRefinement");
+ functionsBeingRefined.pop_back();
+ });
+ }
+
+ private:
+ // Maps refined functions to the refinement context: the values of dimension
+ // arguments and the types of non-global-constant arguments. A function is
+ // added here when we start refining it.
+ DenseMap<func::FuncOp, RefinementKey> refinementContexts;
+
+ // A stack of functions that are in the process of being refined, the current
+ // one is last.
+ SmallVector<llvm::StringRef> functionsBeingRefined;
+
+ LogicalResult emitDifferentRefinementContextError(func::FuncOp func,
+ RefinementKey key,
+ RefinementKey prevKey) {
+ return func.emitOpError()
+ << "refined with invompatible refinement keys:" << "\n curr="
|
|
+ << key.toString() << "\n prev=" << prevKey.toString();
|
|
+ }
|
|
+};
|
|
+
|
|
+// Forward declaration
|
|
+LogicalResult refineFunction(MLIRContext& context, RefineShapeState& state,
|
|
+ RefinementKey& key);
|
|
+
|
|
+// Check if a function only returns constant values; if so, return the
+// constant values that it returns.
+std::optional<SmallVector<DenseIntElementsAttr>> isConstantFunction(
+ func::FuncOp func) {
+ LLVM_DEBUG(llvm::dbgs() << "check if " << func.getName()
+ << " is a constant function\n");
+ SmallVector<DenseIntElementsAttr> returnedConstants;
+ func::ReturnOp ret = *func.getOps<func::ReturnOp>().begin();
+ bool isConstant = llvm::all_of(ret->getOperands(), [&](auto returnVal) {
+ DenseIntElementsAttr attr;
+ Operation* return_operand_def = returnVal.getDefiningOp();
+ if (return_operand_def &&
+ matchPattern(return_operand_def, m_Constant(&attr))) {
+ returnedConstants.push_back(attr);
+ return true;
+ }
+ return false;
+ });
+ if (isConstant) return returnedConstants;
+ return std::nullopt;
+}
+
// The patterns below implement shape refinement of individual ops.
// In a nutshell, they use the upstream type inference infrastructure and a
// StableHLO-specific extension to refine return types based on potentially
@@ -297,6 +529,54 @@

return refineReturnShape(rewriter, op, operandType.getShape());
}
+};
+
+struct RefineCallOpPattern : public OpRewritePattern<func::CallOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ RefineCallOpPattern(MLIRContext* context, RefineShapeState& state)
+ : OpRewritePattern<func::CallOp>(context), state(state) {}
+
+ LogicalResult matchAndRewrite(func::CallOp op,
+ PatternRewriter& rewriter) const override {
+ auto refinementKey = RefinementKey::fromCallOp(op);
+ if (failed(refinementKey)) return failure();
+ if (failed(refineFunction(*rewriter.getContext(), state, *refinementKey)))
+ return failure();
+
+ // Is the callee a constant function in this refinement context?
+ auto callee = refinementKey->getFunc();
+ std::optional<SmallVector<DenseIntElementsAttr>> constantAttrs =
+ isConstantFunction(callee);
+ if (constantAttrs.has_value()) {
+ SmallVector<Value> constants;
+ for (auto constAttr : constantAttrs.value()) {
+ constants.push_back(
+ rewriter.create<ConstantOp>(op.getLoc(), constAttr));
+ }
+ rewriter.replaceOp(op, constants);
+ return success();
+ }
+ if (!refinementKey->getGlobalConstants().empty()) {
+ // Drop the global-constant arguments, but only if necessary, or else we
+ // will end up trying to refine the new CallOp forever.
+ SmallVector<Value> newOperands;
+ auto leadingTokenOperands =
+ op.getOperands().take_front(refinementKey->getLeadingTokenOperands());
+ auto functionalOperands = op.getOperands().take_back(
+ refinementKey->getFunctionalArgumentTypes().size());
+ newOperands.append(leadingTokenOperands.begin(),
+ leadingTokenOperands.end());
+ newOperands.append(functionalOperands.begin(), functionalOperands.end());
+ op = rewriter.replaceOpWithNewOp<func::CallOp>(
+ op, op.getResultTypes(), callee.getSymName(), newOperands);
+ LLVM_DEBUG(llvm::dbgs() << "Replaced call with " << op << "\n");
+ }
+ return refineReturnTypes(rewriter, op, callee.getResultTypes());
+ }
+
+ private:
+ RefineShapeState& state;
};

struct RefineConvertOpPattern : public OpRewritePattern<ConvertOp> {
@@ -718,49 +998,116 @@
}
};

+LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
+ RefineShapeState& state) {
+ MLIRContext* context = func.getContext();
+ RewritePatternSet patterns(context);
+ GreedyRewriteConfig config;
+
+ // The algorithm behind this pass consists of a single traversal of the
+ // function. This is sufficient because we only support one function per
+ // program at the moment.
+ // TODO(#1048): Find out why .maxIterations = 1 no longer works.
+ // There have been recent refactors to applyPatternsAndFoldGreedily
+ // upstream, and that might be the reason.
+ config.useTopDownTraversal = true;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive;
+ config.maxIterations = 2;
+ config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
+ config.strictMode = GreedyRewriteStrictness::AnyOp;
+
+ populateStablehloRefineShapesPatterns(&patterns, context);
+ patterns.add<RefineCallOpPattern>(context, state);
+
+ // The folding patterns implement partial evaluation of shape computations
+ // which is a critical part of implementing type refinement for ops like
+ // dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
+ // depends on the value of their shape operands.
+ populateStablehloShapeFolderPatterns(&patterns, context);
+
+ if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config)))
+ func.emitError("Failed to converge StablehloRefineShapes in ")
+ << config.maxIterations << " iterations";
+
+ return success();
+}
+
+LogicalResult refineFunction(MLIRContext& context, RefineShapeState& state,
+ RefinementKey& key) {
+ LLVM_DEBUG(llvm::dbgs() << "Refining: " << key.toString() << "\n");
+ auto refinementState = state.validateFunctionRefinement(key);
+ if (failed(refinementState)) return failure();
+
+ auto func = key.getFunc();
+ if (*refinementState == RefineShapeState::RefinementState::ALREADY_REFINED) {
+ LLVM_DEBUG(llvm::dbgs() << "Function " << func.getName()
+ << " already refined, skipping\n");
+ return success();
+ }
+
+ auto scopedCleanup = state.createScopedFunctionRefinement(key);
+
+ // StableHLO functions must have exactly one block.
+ if (!func.getRegion().hasOneBlock())
+ return func.emitOpError() << "must have exactly one block";
+
+ // Replace the global-constant arguments with their values and drop args.
+ // Wrap non-global-constant arguments with bitcast_convert.
+ OpBuilder builder(func.getRegion());
+ builder.setInsertionPointToStart(&func.getRegion().front());
+ int64_t leadingTokenOperands = key.getLeadingTokenOperands();
+
+ for (auto [i, dimValue] : llvm::enumerate(key.getGlobalConstants())) {
+ int64_t operandIdx = leadingTokenOperands + i;
+ BlockArgument arg = func.getArgument(operandIdx);
+ Type argType = arg.getType();
+ ShapedType argShapedType = dyn_cast<ShapedType>(argType);
+
+ if (!argShapedType)
+ return func.emitOpError()
+ << "expected global constant argument to be shaped";
+
+ auto replacement_op = builder.create<stablehlo::ConstantOp>(
+ arg.getLoc(), argType, DenseElementsAttr::get(argShapedType, dimValue));
+ arg.replaceAllUsesWith(replacement_op);
+ }
+ BitVector argIndices(func.getNumArguments());
+ size_t firstFunctionalArgument =
+ leadingTokenOperands + key.getGlobalConstants().size();
+ argIndices.set(leadingTokenOperands, firstFunctionalArgument);
+ func.eraseArguments(argIndices);
+
+ // Refine the remaining argument types, wrap with shape buffer custom calls.
+ SmallVector<Type> refinedTypes =
+ key.getAllNonGlobalConstantArgumentTypes(context);
+ if (failed(refineArguments(func, refinedTypes))) return failure();
+ LLVM_DEBUG(llvm::dbgs() << "Refined function type for " << func.getName()
+ << ": " << func.getFunctionType() << "\n");
+
+ // Now iterate into the function body and apply refinement patterns.
+ if (failed(applyShapeRefinementPatterns(func, state))) return failure();
+
+ LLVM_DEBUG(llvm::dbgs() << "refineFunction " << func.getName()
+ << ": end with type " << func.getFunctionType()
+ << "\n");
+ return success();
+}
+
struct StablehloRefineShapesPass
: public impl::StablehloRefineShapesPassBase<StablehloRefineShapesPass> {
using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase;

- LogicalResult initialize(MLIRContext* context) override {
- // The algorithm behind this pass consists of a single traversal of the
- // function. This is sufficient because we only support one function per
- // program at the moment.
- // TODO(#1048): Find out why .maxIterations = 1 no longer works.
- // There have been recent refactors to applyPatternsAndFoldGreedily
- // upstream, and that might be the reason.
- config.useTopDownTraversal = true;
- config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive;
- config.maxIterations = 2;
- config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
- config.strictMode = GreedyRewriteStrictness::AnyOp;
-
- RewritePatternSet patterns_(context);
- populateStablehloRefineShapesPatterns(&patterns_, context);
-
- // The folding patterns implement partial evaluation of shape computations
- // which is a critical part of implementing type refinement for ops like
- // dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
- // depends on the value of their shape operands.
- populateStablehloShapeFolderPatterns(&patterns_, context);
- patterns = std::move(patterns_);
-
- return success();
- }
-
void runOnOperation() override {
auto func = getStablehloRefineShapesTarget(getOperation());
if (!func) return signalPassFailure();

- if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
- func.emitError("Failed to converge StablehloRefineShapes in ")
- << config.maxIterations << " iterations";
- }
- }
-
- private:
- FrozenRewritePatternSet patterns;
- GreedyRewriteConfig config;
+ // Start with empty state, and no dim args / token args.
+ MLIRContext* context = func.getContext();
+ RefineShapeState state;
+ RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
+ if (failed(refineFunction(*context, state, key)))
+ return signalPassFailure();
+ }
};

} // namespace
@@ -807,6 +1154,7 @@
MLIRContext* context) {
patterns->add<RefineAllGatherOpPattern>(context);
patterns->add<RefineBitcastConvertOpPattern>(context);
+ // patterns->add<RefineCallOpPattern>(context); // Populate requires inline
patterns->add<RefineConvertOpPattern>(context);
patterns->add<RefineConvolutionOpPattern>(context);
patterns->add<RefineCustomCallOpPattern>(context);
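The `isConstantFunction` check plus the constant path in `RefineCallOpPattern` is what drives the `@refine_multiple_call_constant_function` test in the hunk at the top of this patch: once the dimension argument is materialized as a constant inside the callee, the callee returns only constants, so every call site is replaced by those constants and folds onward. A before/after sketch of that test's main function, paraphrased from its CHECK lines:

```mlir
// Before refinement: both calls pass the dimension constant 4.
func.func @main(%arg0: tensor<5xf32>) -> tensor<i32> {
  %0 = stablehlo.constant dense<4> : tensor<i32>
  %1 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
  %2 = call @refine_call_callee(%0, %arg0) : (tensor<i32>, tensor<5xf32>) -> tensor<i32>
  %3 = stablehlo.add %1, %2 : tensor<i32>
  return %3 : tensor<i32>
}
// After refinement: the callee body folds to `stablehlo.constant dense<8>`,
// both calls are replaced by that constant, and the add folds to dense<16>.
func.func @main(%arg0: tensor<5xf32>) -> tensor<i32> {
  %0 = stablehlo.constant dense<16> : tensor<i32>
  return %0 : tensor<i32>
}
```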
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h
@@ -34,6 +34,11 @@
// Returns a nullptr and emits appropriate errors if such a function cannot
// be obtained from the module.
func::FuncOp getStablehloRefineShapesTarget(ModuleOp module);
+
+// Refines the arguments of the given function using the given types.
+// Wraps all operands in a custom call to keep the IR valid during refinement.
+// %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0)
+LogicalResult refineArguments(func::FuncOp func, TypeRange refinedTypes);

// Refines the values using the given types.
// Tricky implementation details: