Extend expand_absneg_bit to vector modes

Expand can implement NEG and ABS of scalar floating-point modes
by using logic ops to manipulate the sign bit.  This patch extends
that approach to vectors, since it fits relatively easily into the
same structure.

The motivating use case was to inline bf16 NEG and ABS operations
for AArch64.  The patch includes tests for that.

get_absneg_bit_mode required a new opt_mode constructor, so that
opt_mode<T> can be constructed from opt_mode<U> if T is no less
general than U.

gcc/
	* machmode.h (opt_mode::opt_mode): New overload.
	* optabs-query.h (get_absneg_bit_mode): Declare.
	* optabs-query.cc (get_absneg_bit_mode): New function, split
	out from expand_absneg_bit.
	(can_open_code_p): Use get_absneg_bit_mode.
	* optabs.cc (expand_absneg_bit): Likewise.  Take an outer and inner
	mode, rather than just one.  Handle vector modes.
	(expand_unop, expand_abs_nojump): Update calls accordingly.
	Handle vector modes.

gcc/testsuite/
	* gcc.target/aarch64/abs_bf_1.c: New test.
	* gcc.target/aarch64/neg_bf_1.c: Likewise.
	* gcc.target/aarch64/neg_bf_2.c: Likewise.
This commit is contained in:
Richard Sandiford 2024-11-20 10:04:46 +00:00
parent 0abb5fa523
commit ed6d0867fd
7 changed files with 107 additions and 31 deletions

View File

@ -268,6 +268,8 @@ public:
ALWAYS_INLINE CONSTEXPR opt_mode (const T &m) : m_mode (m) {} ALWAYS_INLINE CONSTEXPR opt_mode (const T &m) : m_mode (m) {}
template<typename U> template<typename U>
ALWAYS_INLINE CONSTEXPR opt_mode (const U &m) : m_mode (T (m)) {} ALWAYS_INLINE CONSTEXPR opt_mode (const U &m) : m_mode (T (m)) {}
template<typename U>
ALWAYS_INLINE CONSTEXPR opt_mode (const opt_mode<U> &);
ALWAYS_INLINE CONSTEXPR opt_mode (from_int m) : m_mode (machine_mode (m)) {} ALWAYS_INLINE CONSTEXPR opt_mode (from_int m) : m_mode (machine_mode (m)) {}
machine_mode else_void () const; machine_mode else_void () const;
@ -285,6 +287,14 @@ private:
machine_mode m_mode; machine_mode m_mode;
}; };
template<typename T>
template<typename U>
ALWAYS_INLINE CONSTEXPR
opt_mode<T>::opt_mode (const opt_mode<U> &m)
: m_mode (m.exists () ? T (m.require ()) : E_VOIDmode)
{
}
/* If the object contains a T, return its enum value, otherwise return /* If the object contains a T, return its enum value, otherwise return
E_VOIDmode. */ E_VOIDmode. */

View File

@ -782,6 +782,39 @@ can_vec_extract (machine_mode mode, machine_mode extr_mode)
return true; return true;
} }
/* OP is either neg_optab or abs_optab and FMODE is the floating-point inner
mode of MODE. Check whether we can implement OP for mode MODE by using
xor_optab to flip the sign bit (for neg_optab) or and_optab to clear the
sign bit (for abs_optab). If so, return the integral mode that should be
used to do the operation and set *BITPOS to the index of the sign bit
(counting from the lsb). */
opt_machine_mode
get_absneg_bit_mode (optab op, machine_mode mode,
scalar_float_mode fmode, int *bitpos)
{
/* The format has to have a simple sign bit. */
auto fmt = REAL_MODE_FORMAT (fmode);
if (fmt == NULL)
return {};
*bitpos = fmt->signbit_rw;
if (*bitpos < 0)
return {};
/* Don't create negative zeros if the format doesn't support them. */
if (op == neg_optab && !fmt->has_signed_zero)
return {};
if (VECTOR_MODE_P (mode))
return related_int_vector_mode (mode);
if (GET_MODE_SIZE (fmode) <= UNITS_PER_WORD)
return int_mode_for_mode (fmode);
return word_mode;
}
/* Return true if we can implement OP for mode MODE directly, without resorting /* Return true if we can implement OP for mode MODE directly, without resorting
to a libfunc. This usually means that OP will be implemented inline. to a libfunc. This usually means that OP will be implemented inline.
@ -800,6 +833,15 @@ can_open_code_p (optab op, machine_mode mode)
if (op == smul_highpart_optab) if (op == smul_highpart_optab)
return can_mult_highpart_p (mode, false); return can_mult_highpart_p (mode, false);
machine_mode new_mode;
scalar_float_mode fmode;
int bitpos;
if ((op == neg_optab || op == abs_optab)
&& is_a<scalar_float_mode> (GET_MODE_INNER (mode), &fmode)
&& get_absneg_bit_mode (op, mode, fmode, &bitpos).exists (&new_mode)
&& can_implement_p (op == neg_optab ? xor_optab : and_optab, new_mode))
return true;
return false; return false;
} }

View File

@ -171,6 +171,8 @@ bool lshift_cheap_p (bool);
bool supports_vec_gather_load_p (machine_mode = E_VOIDmode, bool supports_vec_gather_load_p (machine_mode = E_VOIDmode,
vec<int> * = nullptr); vec<int> * = nullptr);
bool supports_vec_scatter_store_p (machine_mode = E_VOIDmode); bool supports_vec_scatter_store_p (machine_mode = E_VOIDmode);
opt_machine_mode get_absneg_bit_mode (optab, machine_mode,
scalar_float_mode, int *);
bool can_vec_extract (machine_mode, machine_mode); bool can_vec_extract (machine_mode, machine_mode);
bool can_open_code_p (optab, machine_mode); bool can_open_code_p (optab, machine_mode);
bool can_implement_p (optab, machine_mode); bool can_implement_p (optab, machine_mode);

View File

@ -3101,48 +3101,37 @@ expand_ffs (scalar_int_mode mode, rtx op0, rtx target)
} }
/* Expand a floating point absolute value or negation operation via a /* Expand a floating point absolute value or negation operation via a
logical operation on the sign bit. */ logical operation on the sign bit. MODE is the mode of the operands
and FMODE is the scalar inner mode. */
static rtx static rtx
expand_absneg_bit (enum rtx_code code, scalar_float_mode mode, expand_absneg_bit (rtx_code code, machine_mode mode,
rtx op0, rtx target) scalar_float_mode fmode, rtx op0, rtx target)
{ {
const struct real_format *fmt;
int bitpos, word, nwords, i; int bitpos, word, nwords, i;
machine_mode new_mode;
scalar_int_mode imode; scalar_int_mode imode;
rtx temp; rtx temp;
rtx_insn *insns; rtx_insn *insns;
/* The format has to have a simple sign bit. */ auto op = code == NEG ? neg_optab : abs_optab;
fmt = REAL_MODE_FORMAT (mode); if (!get_absneg_bit_mode (op, mode, fmode, &bitpos).exists (&new_mode))
if (fmt == NULL)
return NULL_RTX; return NULL_RTX;
bitpos = fmt->signbit_rw; imode = as_a<scalar_int_mode> (GET_MODE_INNER (new_mode));
if (bitpos < 0) if (VECTOR_MODE_P (mode) || GET_MODE_SIZE (fmode) <= UNITS_PER_WORD)
return NULL_RTX;
/* Don't create negative zeros if the format doesn't support them. */
if (code == NEG && !fmt->has_signed_zero)
return NULL_RTX;
if (GET_MODE_SIZE (mode) <= UNITS_PER_WORD)
{ {
if (!int_mode_for_mode (mode).exists (&imode))
return NULL_RTX;
word = 0; word = 0;
nwords = 1; nwords = 1;
} }
else else
{ {
imode = word_mode;
if (FLOAT_WORDS_BIG_ENDIAN) if (FLOAT_WORDS_BIG_ENDIAN)
word = (GET_MODE_BITSIZE (mode) - bitpos) / BITS_PER_WORD; word = (GET_MODE_BITSIZE (fmode) - bitpos) / BITS_PER_WORD;
else else
word = bitpos / BITS_PER_WORD; word = bitpos / BITS_PER_WORD;
bitpos = bitpos % BITS_PER_WORD; bitpos = bitpos % BITS_PER_WORD;
nwords = (GET_MODE_BITSIZE (mode) + BITS_PER_WORD - 1) / BITS_PER_WORD; nwords = (GET_MODE_BITSIZE (fmode) + BITS_PER_WORD - 1) / BITS_PER_WORD;
} }
wide_int mask = wi::set_bit_in_zero (bitpos, GET_MODE_PRECISION (imode)); wide_int mask = wi::set_bit_in_zero (bitpos, GET_MODE_PRECISION (imode));
@ -3184,11 +3173,13 @@ expand_absneg_bit (enum rtx_code code, scalar_float_mode mode,
} }
else else
{ {
temp = expand_binop (imode, code == ABS ? and_optab : xor_optab, rtx mask_rtx = immed_wide_int_const (mask, imode);
gen_lowpart (imode, op0), if (VECTOR_MODE_P (new_mode))
immed_wide_int_const (mask, imode), mask_rtx = gen_const_vec_duplicate (new_mode, mask_rtx);
gen_lowpart (imode, target), 1, OPTAB_LIB_WIDEN); temp = expand_binop (new_mode, code == ABS ? and_optab : xor_optab,
target = force_lowpart_subreg (mode, temp, imode); gen_lowpart (new_mode, op0), mask_rtx,
gen_lowpart (new_mode, target), 1, OPTAB_LIB_WIDEN);
target = force_lowpart_subreg (mode, temp, new_mode);
set_dst_reg_note (get_last_insn (), REG_EQUAL, set_dst_reg_note (get_last_insn (), REG_EQUAL,
gen_rtx_fmt_e (code, mode, copy_rtx (op0)), gen_rtx_fmt_e (code, mode, copy_rtx (op0)),
@ -3478,9 +3469,9 @@ expand_unop (machine_mode mode, optab unoptab, rtx op0, rtx target,
if (optab_to_code (unoptab) == NEG) if (optab_to_code (unoptab) == NEG)
{ {
/* Try negating floating point values by flipping the sign bit. */ /* Try negating floating point values by flipping the sign bit. */
if (is_a <scalar_float_mode> (mode, &float_mode)) if (is_a <scalar_float_mode> (GET_MODE_INNER (mode), &float_mode))
{ {
temp = expand_absneg_bit (NEG, float_mode, op0, target); temp = expand_absneg_bit (NEG, mode, float_mode, op0, target);
if (temp) if (temp)
return temp; return temp;
} }
@ -3698,9 +3689,9 @@ expand_abs_nojump (machine_mode mode, rtx op0, rtx target,
/* For floating point modes, try clearing the sign bit. */ /* For floating point modes, try clearing the sign bit. */
scalar_float_mode float_mode; scalar_float_mode float_mode;
if (is_a <scalar_float_mode> (mode, &float_mode)) if (is_a <scalar_float_mode> (GET_MODE_INNER (mode), &float_mode))
{ {
temp = expand_absneg_bit (ABS, float_mode, op0, target); temp = expand_absneg_bit (ABS, mode, float_mode, op0, target);
if (temp) if (temp)
return temp; return temp;
} }

View File

@ -0,0 +1,10 @@
/* { dg-options "-O2 -ffast-math" } */
void
foo (__bf16 *ptr)
{
for (int i = 0; i < 8; ++i)
ptr[i] = __builtin_fabsf (ptr[i]);
}
/* { dg-final { scan-assembler {\t(?:bic|and)\t[zv]} } } */

View File

@ -0,0 +1,11 @@
/* { dg-options "-O2" } */
typedef __bf16 v8bf __attribute__((vector_size(16)));
typedef __bf16 v16bf __attribute__((vector_size(32)));
typedef __bf16 v64bf __attribute__((vector_size(128)));
v8bf f1(v8bf x) { return -x; }
v16bf f2(v16bf x) { return -x; }
v64bf f3(v64bf x) { return -x; }
/* { dg-final { scan-assembler-times {\teor\t[zv]} 11 } } */

View File

@ -0,0 +1,10 @@
/* { dg-options "-O2" } */
void
foo (__bf16 *ptr)
{
for (int i = 0; i < 8; ++i)
ptr[i] = -ptr[i];
}
/* { dg-final { scan-assembler {\teor\t[zv]} } } */