From ed6d0867fdbb95691f4054eef6c4b1bd16099700 Mon Sep 17 00:00:00 2001 From: Richard Sandiford Date: Wed, 20 Nov 2024 10:04:46 +0000 Subject: [PATCH] Extend expand_absneg_bit to vector modes Expand can implement NEG and ABS of scalar floating-point modes by using logic ops to manipulate the sign bit. This patch extends that approach to vectors, since it fits relatively easily into the same structure. The motivating use case was to inline bf16 NEG and ABS operations for AArch64. The patch includes tests for that. get_absneg_bit_mode required a new opt_mode constructor, so that opt_mode can be constructed from opt_mode if T is no less general than U. gcc/ * machmode.h (opt_mode::opt_mode): New overload. * optabs-query.h (get_absneg_bit_mode): Declare. * optabs-query.cc (get_absneg_bit_mode): New function, split out from expand_absneg_bit. (can_open_code_p): Use get_absneg_bit_mode. * optabs.cc (expand_absneg_bit): Likewise. Take an outer and inner mode, rather than just one. Handle vector modes. (expand_unop, expand_abs_nojump): Update calls accordingly. Handle vector modes. gcc/testsuite/ * gcc.target/aarch64/abs_bf_1.c: New test. * gcc.target/aarch64/neg_bf_1.c: Likewise. * gcc.target/aarch64/neg_bf_2.c: Likewise. --- gcc/machmode.h | 10 ++++ gcc/optabs-query.cc | 42 ++++++++++++++++ gcc/optabs-query.h | 2 + gcc/optabs.cc | 53 +++++++++------------ gcc/testsuite/gcc.target/aarch64/abs_bf_1.c | 10 ++++ gcc/testsuite/gcc.target/aarch64/neg_bf_1.c | 11 +++++ gcc/testsuite/gcc.target/aarch64/neg_bf_2.c | 10 ++++ 7 files changed, 107 insertions(+), 31 deletions(-) create mode 100644 gcc/testsuite/gcc.target/aarch64/abs_bf_1.c create mode 100644 gcc/testsuite/gcc.target/aarch64/neg_bf_1.c create mode 100644 gcc/testsuite/gcc.target/aarch64/neg_bf_2.c diff --git a/gcc/machmode.h b/gcc/machmode.h index 4c2a8d943cf..9cf792b5cca 100644 --- a/gcc/machmode.h +++ b/gcc/machmode.h @@ -268,6 +268,8 @@ public: ALWAYS_INLINE CONSTEXPR opt_mode (const T &m) : m_mode (m) {} template ALWAYS_INLINE CONSTEXPR opt_mode (const U &m) : m_mode (T (m)) {} + template + ALWAYS_INLINE CONSTEXPR opt_mode (const opt_mode &); ALWAYS_INLINE CONSTEXPR opt_mode (from_int m) : m_mode (machine_mode (m)) {} machine_mode else_void () const; @@ -285,6 +287,14 @@ private: machine_mode m_mode; }; +template +template +ALWAYS_INLINE CONSTEXPR +opt_mode::opt_mode (const opt_mode &m) + : m_mode (m.exists () ? T (m.require ()) : E_VOIDmode) +{ +} + /* If the object contains a T, return its enum value, otherwise return E_VOIDmode. */ diff --git a/gcc/optabs-query.cc b/gcc/optabs-query.cc index 6d28d620eb5..8ab4164e82c 100644 --- a/gcc/optabs-query.cc +++ b/gcc/optabs-query.cc @@ -782,6 +782,39 @@ can_vec_extract (machine_mode mode, machine_mode extr_mode) return true; } +/* OP is either neg_optab or abs_optab and FMODE is the floating-point inner + mode of MODE. Check whether we can implement OP for mode MODE by using + xor_optab to flip the sign bit (for neg_optab) or and_optab to clear the + sign bit (for abs_optab). If so, return the integral mode that should be + used to do the operation and set *BITPOS to the index of the sign bit + (counting from the lsb). */ + +opt_machine_mode +get_absneg_bit_mode (optab op, machine_mode mode, + scalar_float_mode fmode, int *bitpos) +{ + /* The format has to have a simple sign bit. */ + auto fmt = REAL_MODE_FORMAT (fmode); + if (fmt == NULL) + return {}; + + *bitpos = fmt->signbit_rw; + if (*bitpos < 0) + return {}; + + /* Don't create negative zeros if the format doesn't support them. */ + if (op == neg_optab && !fmt->has_signed_zero) + return {}; + + if (VECTOR_MODE_P (mode)) + return related_int_vector_mode (mode); + + if (GET_MODE_SIZE (fmode) <= UNITS_PER_WORD) + return int_mode_for_mode (fmode); + + return word_mode; +} + /* Return true if we can implement OP for mode MODE directly, without resorting to a libfunc. This usually means that OP will be implemented inline. @@ -800,6 +833,15 @@ can_open_code_p (optab op, machine_mode mode) if (op == smul_highpart_optab) return can_mult_highpart_p (mode, false); + machine_mode new_mode; + scalar_float_mode fmode; + int bitpos; + if ((op == neg_optab || op == abs_optab) + && is_a (GET_MODE_INNER (mode), &fmode) + && get_absneg_bit_mode (op, mode, fmode, &bitpos).exists (&new_mode) + && can_implement_p (op == neg_optab ? xor_optab : and_optab, new_mode)) + return true; + return false; } diff --git a/gcc/optabs-query.h b/gcc/optabs-query.h index 89a7b02ef43..60c8021a1b7 100644 --- a/gcc/optabs-query.h +++ b/gcc/optabs-query.h @@ -171,6 +171,8 @@ bool lshift_cheap_p (bool); bool supports_vec_gather_load_p (machine_mode = E_VOIDmode, vec * = nullptr); bool supports_vec_scatter_store_p (machine_mode = E_VOIDmode); +opt_machine_mode get_absneg_bit_mode (optab, machine_mode, + scalar_float_mode, int *); bool can_vec_extract (machine_mode, machine_mode); bool can_open_code_p (optab, machine_mode); bool can_implement_p (optab, machine_mode); diff --git a/gcc/optabs.cc b/gcc/optabs.cc index fa51e498a98..b9c51f78af4 100644 --- a/gcc/optabs.cc +++ b/gcc/optabs.cc @@ -3101,48 +3101,37 @@ expand_ffs (scalar_int_mode mode, rtx op0, rtx target) } /* Expand a floating point absolute value or negation operation via a - logical operation on the sign bit. */ + logical operation on the sign bit. MODE is the mode of the operands + and FMODE is the scalar inner mode. */ static rtx -expand_absneg_bit (enum rtx_code code, scalar_float_mode mode, - rtx op0, rtx target) +expand_absneg_bit (rtx_code code, machine_mode mode, + scalar_float_mode fmode, rtx op0, rtx target) { - const struct real_format *fmt; int bitpos, word, nwords, i; + machine_mode new_mode; scalar_int_mode imode; rtx temp; rtx_insn *insns; - /* The format has to have a simple sign bit. */ - fmt = REAL_MODE_FORMAT (mode); - if (fmt == NULL) + auto op = code == NEG ? neg_optab : abs_optab; + if (!get_absneg_bit_mode (op, mode, fmode, &bitpos).exists (&new_mode)) return NULL_RTX; - bitpos = fmt->signbit_rw; - if (bitpos < 0) - return NULL_RTX; - - /* Don't create negative zeros if the format doesn't support them. */ - if (code == NEG && !fmt->has_signed_zero) - return NULL_RTX; - - if (GET_MODE_SIZE (mode) <= UNITS_PER_WORD) + imode = as_a (GET_MODE_INNER (new_mode)); + if (VECTOR_MODE_P (mode) || GET_MODE_SIZE (fmode) <= UNITS_PER_WORD) { - if (!int_mode_for_mode (mode).exists (&imode)) - return NULL_RTX; word = 0; nwords = 1; } else { - imode = word_mode; - if (FLOAT_WORDS_BIG_ENDIAN) - word = (GET_MODE_BITSIZE (mode) - bitpos) / BITS_PER_WORD; + word = (GET_MODE_BITSIZE (fmode) - bitpos) / BITS_PER_WORD; else word = bitpos / BITS_PER_WORD; bitpos = bitpos % BITS_PER_WORD; - nwords = (GET_MODE_BITSIZE (mode) + BITS_PER_WORD - 1) / BITS_PER_WORD; + nwords = (GET_MODE_BITSIZE (fmode) + BITS_PER_WORD - 1) / BITS_PER_WORD; } wide_int mask = wi::set_bit_in_zero (bitpos, GET_MODE_PRECISION (imode)); @@ -3184,11 +3173,13 @@ expand_absneg_bit (enum rtx_code code, scalar_float_mode mode, } else { - temp = expand_binop (imode, code == ABS ? and_optab : xor_optab, - gen_lowpart (imode, op0), - immed_wide_int_const (mask, imode), - gen_lowpart (imode, target), 1, OPTAB_LIB_WIDEN); - target = force_lowpart_subreg (mode, temp, imode); + rtx mask_rtx = immed_wide_int_const (mask, imode); + if (VECTOR_MODE_P (new_mode)) + mask_rtx = gen_const_vec_duplicate (new_mode, mask_rtx); + temp = expand_binop (new_mode, code == ABS ? and_optab : xor_optab, + gen_lowpart (new_mode, op0), mask_rtx, + gen_lowpart (new_mode, target), 1, OPTAB_LIB_WIDEN); + target = force_lowpart_subreg (mode, temp, new_mode); set_dst_reg_note (get_last_insn (), REG_EQUAL, gen_rtx_fmt_e (code, mode, copy_rtx (op0)), @@ -3478,9 +3469,9 @@ expand_unop (machine_mode mode, optab unoptab, rtx op0, rtx target, if (optab_to_code (unoptab) == NEG) { /* Try negating floating point values by flipping the sign bit. */ - if (is_a (mode, &float_mode)) + if (is_a (GET_MODE_INNER (mode), &float_mode)) { - temp = expand_absneg_bit (NEG, float_mode, op0, target); + temp = expand_absneg_bit (NEG, mode, float_mode, op0, target); if (temp) return temp; } @@ -3698,9 +3689,9 @@ expand_abs_nojump (machine_mode mode, rtx op0, rtx target, /* For floating point modes, try clearing the sign bit. */ scalar_float_mode float_mode; - if (is_a (mode, &float_mode)) + if (is_a (GET_MODE_INNER (mode), &float_mode)) { - temp = expand_absneg_bit (ABS, float_mode, op0, target); + temp = expand_absneg_bit (ABS, mode, float_mode, op0, target); if (temp) return temp; } diff --git a/gcc/testsuite/gcc.target/aarch64/abs_bf_1.c b/gcc/testsuite/gcc.target/aarch64/abs_bf_1.c new file mode 100644 index 00000000000..42e03bca0be --- /dev/null +++ b/gcc/testsuite/gcc.target/aarch64/abs_bf_1.c @@ -0,0 +1,10 @@ +/* { dg-options "-O2 -ffast-math" } */ + +void +foo (__bf16 *ptr) +{ + for (int i = 0; i < 8; ++i) + ptr[i] = __builtin_fabsf (ptr[i]); +} + +/* { dg-final { scan-assembler {\t(?:bic|and)\t[zv]} } } */ diff --git a/gcc/testsuite/gcc.target/aarch64/neg_bf_1.c b/gcc/testsuite/gcc.target/aarch64/neg_bf_1.c new file mode 100644 index 00000000000..564ff1ec9cb --- /dev/null +++ b/gcc/testsuite/gcc.target/aarch64/neg_bf_1.c @@ -0,0 +1,11 @@ +/* { dg-options "-O2" } */ + +typedef __bf16 v8bf __attribute__((vector_size(16))); +typedef __bf16 v16bf __attribute__((vector_size(32))); +typedef __bf16 v64bf __attribute__((vector_size(128))); + +v8bf f1(v8bf x) { return -x; } +v16bf f2(v16bf x) { return -x; } +v64bf f3(v64bf x) { return -x; } + +/* { dg-final { scan-assembler-times {\teor\t[zv]} 11 } } */ diff --git a/gcc/testsuite/gcc.target/aarch64/neg_bf_2.c b/gcc/testsuite/gcc.target/aarch64/neg_bf_2.c new file mode 100644 index 00000000000..07292428418 --- /dev/null +++ b/gcc/testsuite/gcc.target/aarch64/neg_bf_2.c @@ -0,0 +1,10 @@ +/* { dg-options "-O2" } */ + +void +foo (__bf16 *ptr) +{ + for (int i = 0; i < 8; ++i) + ptr[i] = -ptr[i]; +} + +/* { dg-final { scan-assembler {\teor\t[zv]} } } */