Consider fully pipelined FMA in get_reassociation_width

Add a new parameter param_fully_pipelined_fma. If it is non-zero,
reassociation considers the benefit of parallelizing FMA's
multiplication part and addition part, assuming FMUL and FMA use the
same units that can also do FADD.

With the patch and new option, there's ~2% improvement in spec2017
508.namd on AmpereOne. (The other options are "-Ofast -mcpu=ampere1
 -flto".)

	PR tree-optimization/110279

gcc/ChangeLog:

	* doc/invoke.texi: New parameter fully-pipelined-fma.
	* params.opt: New parameter fully-pipelined-fma.
	* tree-ssa-reassoc.cc (get_mult_latency_consider_fma): Return
	the latency of MULT_EXPRs that can't be hidden by the FMAs.
	(get_reassociation_width): Search for a smaller width
	considering the benefit of fully pipelined FMA.
	(rank_ops_for_fma): Return the number of MULT_EXPRs.
	(reassociate_bb): Pass the number of MULT_EXPRs to
	get_reassociation_width; avoid calling
	get_reassociation_width twice.

gcc/testsuite/ChangeLog:

	* gcc.dg/pr110279-2.c: New test.
Author: Di Zhao
Date:   2023-12-15 03:22:32 +08:00
Commit: 8afdbcdd7a (parent 95b7054533)
4 changed files with 178 additions and 28 deletions

gcc/doc/invoke.texi

@@ -16583,6 +16583,12 @@ Maximum number of basic blocks for VRP to use a basic cache vector.
@item avoid-fma-max-bits
Maximum number of bits for which we avoid creating FMAs.
@item fully-pipelined-fma
Whether the target fully pipelines FMA instructions. If non-zero,
reassociation considers the benefit of parallelizing FMA's multiplication
part and addition part, assuming FMUL and FMA use the same units that can
also do FADD.
@item sms-loop-average-count-threshold
A threshold on the average loop count considered by the swing modulo scheduler.

gcc/params.opt

@@ -134,6 +134,13 @@ Maximal estimated growth of function body caused by early inlining of single call.
Common Joined UInteger Var(param_fsm_scale_path_stmts) Init(2) IntegerRange(1, 10) Param Optimization
Scale factor to apply to the number of statements in a threading path crossing a loop backedge when comparing to max-jump-thread-duplication-stmts.
-param=fully-pipelined-fma=
Common Joined UInteger Var(param_fully_pipelined_fma) Init(0) IntegerRange(0, 1) Param Optimization
Whether the target fully pipelines FMA instructions. If non-zero,
reassociation considers the benefit of parallelizing FMA's multiplication
part and addition part, assuming FMUL and FMA use the same units that can
also do FADD.
-param=gcse-after-reload-critical-fraction=
Common Joined UInteger Var(param_gcse_after_reload_critical_fraction) Init(10) Param Optimization
The threshold ratio of critical edges execution count that permit performing redundancy elimination after reload.

gcc/testsuite/gcc.dg/pr110279-2.c

@@ -0,0 +1,41 @@
/* PR tree-optimization/110279 */
/* { dg-do compile } */
/* { dg-options "-Ofast --param tree-reassoc-width=4 --param fully-pipelined-fma=1 -fdump-tree-reassoc2-details -fdump-tree-optimized" } */
/* { dg-additional-options "-march=armv8.2-a" { target aarch64-*-* } } */
#define LOOP_COUNT 800000000
typedef double data_e;
#include <stdio.h>
__attribute_noinline__ data_e
foo (data_e in)
{
data_e a1, a2, a3, a4;
data_e tmp, result = 0;
a1 = in + 0.1;
a2 = in * 0.1;
a3 = in + 0.01;
a4 = in * 0.59;
data_e result2 = 0;
for (int ic = 0; ic < LOOP_COUNT; ic++)
{
/* Test that a complete FMA chain with length=4 is not broken. */
tmp = a1 + a2 * a2 + a3 * a3 + a4 * a4 ;
result += tmp - ic;
result2 = result2 / 2 - tmp;
a1 += 0.91;
a2 += 0.1;
a3 -= 0.01;
a4 -= 0.89;
}
return result + result2;
}
/* { dg-final { scan-tree-dump-not "was chosen for reassociation" "reassoc2"} } */
/* { dg-final { scan-tree-dump-times {\.FMA } 3 "optimized"} } */
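/* Illustrative sketch, not part of the test: with the chain kept intact,
   the optimized dump is expected to contain roughly the following three
   chained FMAs for tmp (SSA names are placeholders), which is what the
   second scan above counts:
     _1 = .FMA (a2, a2, a1);
     _2 = .FMA (a3, a3, _1);
     tmp = .FMA (a4, a4, _2);  */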

gcc/tree-ssa-reassoc.cc

@@ -5430,13 +5430,35 @@ get_required_cycles (int ops_num, int cpu_width)
return res;
}
/* Given that the target fully pipelines FMA instructions, return the latency
of MULT_EXPRs that can't be hidden by the FMAs. WIDTH is the number of
pipes. */
static inline int
get_mult_latency_consider_fma (int ops_num, int mult_num, int width)
{
gcc_checking_assert (mult_num && mult_num <= ops_num);
/* For each partition, if mult_num == ops_num, there's latency(MULT)*2.
e.g:
A * B + C * D
=>
_1 = A * B;
_2 = .FMA (C, D, _1);
Otherwise there's latency(MULT)*1 in the first FMA. */
return CEIL (ops_num, width) == CEIL (mult_num, width) ? 2 : 1;
}
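/* Illustrative note, not part of the patch: for the chain in the new test,
   a1 + a2*a2 + a3*a3 + a4*a4, ops_num is 4 and mult_num is 3.  With
   width == 1, CEIL (4, 1) != CEIL (3, 1), so the helper returns 1: only
   the multiply inside the first FMA is exposed.  If all four operands
   were products (mult_num == ops_num), it would return 2, as in the
   A * B + C * D example above where a plain multiply must complete before
   the first FMA can start.  */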
/* Returns an optimal number of registers to use for computation of
given statements.
LHS is the result ssa name of OPS. */
LHS is the result ssa name of OPS. MULT_NUM is number of sub-expressions
that are MULT_EXPRs, when OPS are PLUS_EXPRs or MINUS_EXPRs. */
static int
get_reassociation_width (vec<operand_entry *> *ops, tree lhs,
get_reassociation_width (vec<operand_entry *> *ops, int mult_num, tree lhs,
enum tree_code opc, machine_mode mode)
{
int param_width = param_tree_reassoc_width;
@@ -5462,16 +5484,68 @@ get_reassociation_width (vec<operand_entry *> *ops, tree lhs,
so we can perform a binary search for the minimal width that still
results in the optimal cycle count. */
width_min = 1;
while (width > width_min)
{
int width_mid = (width + width_min) / 2;
if (get_required_cycles (ops_num, width_mid) == cycles_best)
width = width_mid;
else if (width_min < width_mid)
width_min = width_mid;
else
break;
/* If the target fully pipelines FMA instructions, the multiply part can start
already if its operands are ready. Assuming symmetric pipes are used for
FMUL/FADD/FMA, then for a sequence of FMA like:
_8 = .FMA (_2, _3, _1);
_9 = .FMA (_5, _4, _8);
_10 = .FMA (_7, _6, _9);
, if width=1, the latency is latency(MULT) + latency(ADD)*3.
While with width=2:
_8 = _4 * _5;
_9 = .FMA (_2, _3, _1);
_10 = .FMA (_6, _7, _8);
_11 = _9 + _10;
, it is latency(MULT)*2 + latency(ADD)*2. Assuming latency(MULT) >=
latency(ADD), the first variant is preferred.
Find out if we can get a smaller width considering FMA. */
if (width > 1 && mult_num && param_fully_pipelined_fma)
{
/* When param_fully_pipelined_fma is set, assume FMUL and FMA use the
same units that can also do FADD. For other scenarios, such as when
FMUL and FADD use separate units, the following code may not apply.  */
int width_mult = targetm.sched.reassociation_width (MULT_EXPR, mode);
gcc_checking_assert (width_mult <= width);
/* Latency of MULT_EXPRs. */
int lat_mul
= get_mult_latency_consider_fma (ops_num, mult_num, width_mult);
/* Quick search might not apply. So start from 1. */
for (int i = 1; i < width_mult; i++)
{
int lat_mul_new
= get_mult_latency_consider_fma (ops_num, mult_num, i);
int lat_add_new = get_required_cycles (ops_num, i);
/* Assume latency(MULT) >= latency(ADD). */
if (lat_mul - lat_mul_new >= lat_add_new - cycles_best)
{
width = i;
break;
}
}
}
else
{
while (width > width_min)
{
int width_mid = (width + width_min) / 2;
if (get_required_cycles (ops_num, width_mid) == cycles_best)
width = width_mid;
else if (width_min < width_mid)
width_min = width_mid;
else
break;
}
}
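/* Illustrative note, not part of the patch: with hypothetical latencies of
   4 cycles for MULT and 2 for ADD, the width-1 FMA chain in the comment
   above costs 4 + 3*2 = 10 cycles, while the width-2 variant costs
   2*4 + 2*2 = 12, so the search settles on the smaller width.  The loop
   accepts a smaller width as soon as the multiply latency it saves
   (lat_mul - lat_mul_new) is at least the extra addition latency it
   introduces (lat_add_new - cycles_best), relying on the stated assumption
   that latency(MULT) >= latency(ADD).  */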
/* If there's loop dependent FMA result, return width=2 to avoid it. This is
@@ -6841,8 +6915,10 @@ transform_stmt_to_multiply (gimple_stmt_iterator *gsi, gimple *stmt,
Rearrange ops to -> e + a * b + c * d generates:
_4 = .FMA (c_7(D), d_8(D), _3);
_11 = .FMA (a_5(D), b_6(D), _4); */
static bool
_11 = .FMA (a_5(D), b_6(D), _4);
Return the number of MULT_EXPRs in the chain. */
static int
rank_ops_for_fma (vec<operand_entry *> *ops)
{
operand_entry *oe;
@@ -6856,9 +6932,26 @@ rank_ops_for_fma (vec<operand_entry *> *ops)
if (TREE_CODE (oe->op) == SSA_NAME)
{
gimple *def_stmt = SSA_NAME_DEF_STMT (oe->op);
if (is_gimple_assign (def_stmt)
&& gimple_assign_rhs_code (def_stmt) == MULT_EXPR)
ops_mult.safe_push (oe);
if (is_gimple_assign (def_stmt))
{
if (gimple_assign_rhs_code (def_stmt) == MULT_EXPR)
ops_mult.safe_push (oe);
/* A negate on the multiplication leads to FNMA. */
else if (gimple_assign_rhs_code (def_stmt) == NEGATE_EXPR
&& TREE_CODE (gimple_assign_rhs1 (def_stmt)) == SSA_NAME)
{
gimple *neg_def_stmt
= SSA_NAME_DEF_STMT (gimple_assign_rhs1 (def_stmt));
if (is_gimple_assign (neg_def_stmt)
&& gimple_bb (neg_def_stmt) == gimple_bb (def_stmt)
&& gimple_assign_rhs_code (neg_def_stmt) == MULT_EXPR)
ops_mult.safe_push (oe);
else
ops_others.safe_push (oe);
}
else
ops_others.safe_push (oe);
}
else
ops_others.safe_push (oe);
}
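As an illustrative aside (not part of the patch), the negated-multiplication
case handled above corresponds to source like the following, where the
negated product can be contracted into an FNMA instead of blocking the
chain (data_e, a, b and acc are placeholders):

	data_e t = -(a * b);	/* _1 = a * b;  _2 = -_1;  */
	acc = acc + t;		/* candidate for acc_3 = .FNMA (a, b, acc);  */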
@@ -6874,7 +6967,8 @@ rank_ops_for_fma (vec<operand_entry *> *ops)
Putting ops that are not defined by a mult in front can generate more FMAs.
2. If all ops are defined with mult, we don't need to rearrange them. */
if (ops_mult.length () >= 2 && ops_mult.length () != ops_length)
unsigned mult_num = ops_mult.length ();
if (mult_num >= 2 && mult_num != ops_length)
{
/* Put no-mult ops and mult ops alternately at the end of the
queue, which is conducive to generating more FMA and reducing the
@@ -6890,9 +6984,8 @@ rank_ops_for_fma (vec<operand_entry *> *ops)
if (opindex > 0)
opindex--;
}
return true;
}
return false;
return mult_num;
}
/* Reassociate expressions in basic block BB and its post-dominator as
children.
@@ -7057,8 +7150,8 @@ reassociate_bb (basic_block bb)
{
machine_mode mode = TYPE_MODE (TREE_TYPE (lhs));
int ops_num = ops.length ();
int width;
bool has_fma = false;
int width = 0;
int mult_num = 0;
/* For binary bit operations, if there are at least 3
operands and the last operand in OPS is a constant,
@@ -7081,16 +7174,17 @@ reassociate_bb (basic_block bb)
opt_type)
&& (rhs_code == PLUS_EXPR || rhs_code == MINUS_EXPR))
{
has_fma = rank_ops_for_fma (&ops);
mult_num = rank_ops_for_fma (&ops);
}
/* Only rewrite the expression tree to parallel in the
last reassoc pass to avoid useless work back-and-forth
with initial linearization. */
bool has_fma = mult_num >= 2 && mult_num != ops_num;
if (!reassoc_insert_powi_p
&& ops.length () > 3
&& (width
= get_reassociation_width (&ops, lhs, rhs_code, mode))
&& (width = get_reassociation_width (&ops, mult_num, lhs,
rhs_code, mode))
> 1)
{
if (dump_file && (dump_flags & TDF_DETAILS))
@@ -7111,10 +7205,12 @@ reassociate_bb (basic_block bb)
if (len >= 3
&& (!has_fma
/* width > 1 means ranking ops results in better
parallelism. */
|| get_reassociation_width (&ops, lhs, rhs_code,
mode)
> 1))
parallelism. Check current value to avoid
calling get_reassociation_width again. */
|| (width != 1
&& get_reassociation_width (
&ops, mult_num, lhs, rhs_code, mode)
> 1)))
swap_ops_for_binary_stmt (ops, len - 3);
new_lhs = rewrite_expr_tree (stmt, rhs_code, 0, ops,