[Git][ghc/ghc][master] Improve performance of genericWordQuotRem2Op (#22966)
Marge Bot (@marge-bot)
gitlab at gitlab.haskell.org
Mon Jun 3 06:10:56 UTC 2024
Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC
Commits:
4998a6ed by Alex Mason at 2024-06-03T02:09:29-04:00
Improve performance of genericWordQuotRem2Op (#22966)
Implements the algorithm from compiler-rt's udiv128by64to64default. This
rewrite results in a roughly 24x improvement in runtime on AArch64 (and
likely on any other architecture that uses this generic implementation).
- - - - -
3 changed files:
- compiler/GHC/StgToCmm/Prim.hs
- testsuite/tests/numeric/should_run/all.T
- testsuite/tests/numeric/should_run/quotRem2Large.hs (new file)
Changes:
=====================================
compiler/GHC/StgToCmm/Prim.hs
=====================================
@@ -1894,53 +1894,179 @@ genericWordQuotRemOp width [res_q, res_r] [arg_x, arg_y]
(CmmMachOp (MO_U_Rem width) [arg_x, arg_y])
genericWordQuotRemOp _ _ _ = panic "genericWordQuotRemOp"
+-- Based on the algorithm from LLVM's compiler-rt:
+-- https://github.com/llvm/llvm-project/blob/7339f7ba3053db7595ece1ca5f49bd2e4c3c8305/compiler-rt/lib/builtins/udivmodti4.c#L23
+-- See that file for licensing and copyright.
+--
+-- Emits Cmm that divides the two-word value (arg_u1 : arg_u0), where arg_u1
+-- is the high word, by the one-word divisor arg_v. The quotient is assigned
+-- to res_q and the remainder to res_r. The division is done as a two-digit
+-- long division in base 2^(wordsize/2), after normalising the divisor so
+-- that its top bit is set.
+-- NOTE(review): like the C original, this presumably requires arg_v /= 0
+-- and arg_u1 < arg_v (so that the quotient fits in one word) -- confirm
+-- that callers guarantee this.
genericWordQuotRem2Op :: Platform -> GenericOp
-genericWordQuotRem2Op platform [res_q, res_r] [arg_x_high, arg_x_low, arg_y]
- = emit =<< f (widthInBits (wordWidth platform)) zero arg_x_high arg_x_low
- where ty = cmmExprType platform arg_x_high
- shl x i = CmmMachOp (MO_Shl (wordWidth platform)) [x, i]
- shr x i = CmmMachOp (MO_U_Shr (wordWidth platform)) [x, i]
- or x y = CmmMachOp (MO_Or (wordWidth platform)) [x, y]
- ge x y = CmmMachOp (MO_U_Ge (wordWidth platform)) [x, y]
- ne x y = CmmMachOp (MO_Ne (wordWidth platform)) [x, y]
- minus x y = CmmMachOp (MO_Sub (wordWidth platform)) [x, y]
- times x y = CmmMachOp (MO_Mul (wordWidth platform)) [x, y]
- zero = lit 0
- one = lit 1
- negone = lit (fromIntegral (platformWordSizeInBits platform) - 1)
- lit i = CmmLit (CmmInt i (wordWidth platform))
-
- f :: Int -> CmmExpr -> CmmExpr -> CmmExpr -> FCode CmmAGraph
- f 0 acc high _ = return (mkAssign (CmmLocal res_q) acc <*>
- mkAssign (CmmLocal res_r) high)
- f i acc high low =
- do roverflowedBit <- newTemp ty
- rhigh' <- newTemp ty
- rhigh'' <- newTemp ty
- rlow' <- newTemp ty
- risge <- newTemp ty
- racc' <- newTemp ty
- let high' = CmmReg (CmmLocal rhigh')
- isge = CmmReg (CmmLocal risge)
- overflowedBit = CmmReg (CmmLocal roverflowedBit)
- let this = catAGraphs
- [mkAssign (CmmLocal roverflowedBit)
- (shr high negone),
- mkAssign (CmmLocal rhigh')
- (or (shl high one) (shr low negone)),
- mkAssign (CmmLocal rlow')
- (shl low one),
- mkAssign (CmmLocal risge)
- (or (overflowedBit `ne` zero)
- (high' `ge` arg_y)),
- mkAssign (CmmLocal rhigh'')
- (high' `minus` (arg_y `times` isge)),
- mkAssign (CmmLocal racc')
- (or (shl acc one) isge)]
- rest <- f (i - 1) (CmmReg (CmmLocal racc'))
- (CmmReg (CmmLocal rhigh''))
- (CmmReg (CmmLocal rlow'))
- return (this <*> rest)
+genericWordQuotRem2Op platform [res_q, res_r] [arg_u1, arg_u0, arg_v]
+ = do
+ -- v gets modified below based on clz v: copy arg_v into a fresh local
+ -- register so that `go` can shift it in place during normalisation.
+ v <- newTemp ty
+ emit $ mkAssign (CmmLocal v) arg_v
+ go arg_u1 arg_u0 v
+ where ty = cmmExprType platform arg_u1
+ -- Word-width Cmm operators. The ordering comparisons, the right shift
+ -- and the division are all the unsigned variants.
+ shl x i = CmmMachOp (MO_Shl (wordWidth platform)) [x, i]
+ shr x i = CmmMachOp (MO_U_Shr (wordWidth platform)) [x, i]
+ or x y = CmmMachOp (MO_Or (wordWidth platform)) [x, y]
+ ge x y = CmmMachOp (MO_U_Ge (wordWidth platform)) [x, y]
+ le x y = CmmMachOp (MO_U_Le (wordWidth platform)) [x, y]
+ eq x y = CmmMachOp (MO_Eq (wordWidth platform)) [x, y]
+ plus x y = CmmMachOp (MO_Add (wordWidth platform)) [x, y]
+ minus x y = CmmMachOp (MO_Sub (wordWidth platform)) [x, y]
+ times x y = CmmMachOp (MO_Mul (wordWidth platform)) [x, y]
+ udiv x y = CmmMachOp (MO_U_Quot (wordWidth platform)) [x, y]
+ and x y = CmmMachOp (MO_And (wordWidth platform)) [x, y]
+ lit i = CmmLit (CmmInt i (wordWidth platform))
+ one = lit 1
+ zero = lit 0
+ -- Mask selecting the low half of a word (0xFFFFFFFF on 64-bit platforms).
+ masklow = lit ((1 `shiftL` (platformWordSizeInBits platform `div` 2)) - 1)
+ -- Jump to `target` when `pred` holds; otherwise continue with whatever
+ -- code is emitted next.
+ gotoIf pred target = emit =<< mkCmmIfGoto pred target
+ -- Allocate a fresh local register, returning both the register and an
+ -- expression that reads it.
+ mkTmp ty = do
+ t <- newTemp ty
+ pure (t, CmmReg (CmmLocal t))
+ -- Emit an assignment. At infixr 8 it binds more loosely than the
+ -- backticked helpers above (infixl 9 by default), so right-hand sides
+ -- such as a `shl` b need no extra parentheses.
+ infixr 8 .=
+ r .= e = emit $ mkAssign (CmmLocal r) e
+
+ go :: CmmActual -> CmmActual -> LocalReg -> FCode ()
+ go u1 u0 v = do
+ -- Computes (res_q, res_r) = ((u1 << word bits) + u0) `quotRem` v, with
+ -- u1 the high word (unsigned, so this coincides with divMod), mirroring:
+ -- du_int udiv128by64to64default(du_int u1, du_int u0, du_int v, du_int *r)
+ -- const unsigned n_udword_bits = sizeof(du_int) * CHAR_BIT;
+ let n_udword_bits' = widthInBits (wordWidth platform)
+ n_udword_bits = fromIntegral n_udword_bits'
+ -- const du_int b = (1ULL << (n_udword_bits / 2)); // Number base (half a word: 32 bits on 64-bit platforms)
+ b = 1 `shiftL` (n_udword_bits' `div` 2)
+ v' = CmmReg (CmmLocal v)
+ -- du_int un1, un0; // Norm. dividend LSD's
+ (un1, un1') <- mkTmp ty
+ (un0, un0') <- mkTmp ty
+ -- du_int vn1, vn0; // Norm. divisor digits
+ (vn1, vn1') <- mkTmp ty
+ (vn0, vn0') <- mkTmp ty
+ -- du_int q1, q0; // Quotient digits
+ (q1, q1') <- mkTmp ty
+ (q0, q0') <- mkTmp ty
+ -- du_int un64, un21, un10; // Dividend digit pairs
+ (un64, un64') <- mkTmp ty
+ (un21, un21') <- mkTmp ty
+ (un10, un10') <- mkTmp ty
+
+ -- du_int rhat; // A remainder
+ (rhat, rhat') <- mkTmp ty
+ -- si_int s; // Shift amount for normalization
+ (s, s') <- mkTmp ty
+
+ -- s = __builtin_clzll(v);
+ -- GHC's clz of 0 yields N on N-bit systems, whereas __builtin_clzll(0)
+ -- is undefined in C; the s == N case is handled explicitly below.
+ emitClzCall s v' (wordWidth platform)
+
+ if_else <- newBlockId
+ if_done <- newBlockId
+ -- if (s > 0) {
+ -- actually if (s > 0 && s /= wordSizeInBits) {
+ -- (s == wordSizeInBits only happens when v == 0)
+ gotoIf (s' `eq` zero) if_else
+ gotoIf (s' `eq` lit n_udword_bits) if_else
+ do
+ -- // Normalize the divisor.
+ -- v = v << s;
+ v .= shl v' s'
+ -- un64 = (u1 << s) | (u0 >> (n_udword_bits - s));
+ un64 .= (u1 `shl` s') `or` (u0 `shr` (lit n_udword_bits `minus` s'))
+ -- un10 = u0 << s; // Shift dividend left
+ un10 .= shl u0 s'
+ emit $ mkBranch if_done
+ -- } else {
+ do
+ -- // Avoid undefined behavior of (u0 >> 64).
+ emitLabel if_else
+ -- un64 = u1;
+ un64 .= u1
+ -- un10 = u0;
+ un10 .= u0
+ s .= lit 0 -- s is the word size when v == 0; reset it so the final `shr` by s is not a full-word shift
+ -- }
+ emitLabel if_done
+
+ -- // Break divisor up into two 32-bit digits.
+ -- vn1 = v >> (n_udword_bits / 2);
+ vn1 .= v' `shr` lit (n_udword_bits `div` 2)
+ -- vn0 = v & 0xFFFFFFFF;
+ vn0 .= v' `and` masklow
+
+ -- // Break right half of dividend into two digits.
+ -- un1 = un10 >> (n_udword_bits / 2);
+ un1 .= un10' `shr` lit (n_udword_bits `div` 2)
+ -- un0 = un10 & 0xFFFFFFFF;
+ un0 .= un10' `and` masklow
+
+ -- // Compute the first quotient digit, q1.
+ -- q1 = un64 / vn1;
+ q1 .= un64' `udiv` vn1'
+ -- rhat = un64 - q1 * vn1;
+ rhat .= un64' `minus` times q1' vn1'
+
+ -- The C while loop (including its break) becomes three blocks. The
+ -- entry block tests the condition, the body adjusts q1 and rhat, and
+ -- both a false condition and the break jump to while_1_done.
+ while_1_entry <- newBlockId
+ while_1_body <- newBlockId
+ while_1_done <- newBlockId
+ -- // q1 has at most error 2. No more than 2 iterations.
+ -- while (q1 >= b || q1 * vn0 > b * rhat + un1) {
+ emitLabel while_1_entry
+ gotoIf (q1' `ge` lit b) while_1_body
+ -- Leave the loop when the second disjunct is false,
+ -- i.e. q1 * vn0 <= b * rhat + un1.
+ gotoIf (le (times q1' vn0')
+ (times (lit b) rhat' `plus` un1'))
+ while_1_done
+ do
+ emitLabel while_1_body
+ -- q1 = q1 - 1;
+ q1 .= q1' `minus` one
+ -- rhat = rhat + vn1;
+ rhat .= rhat' `plus` vn1'
+ -- if (rhat >= b)
+ -- break;
+ gotoIf (rhat' `ge` lit b)
+ while_1_done
+ emit $ mkBranch while_1_entry
+ -- }
+ emitLabel while_1_done
+
+ -- un21 = un64 * b + un1 - q1 * v;
+ un21 .= (times un64' (lit b) `plus` un1') `minus` times q1' v'
+
+ -- // Compute the second quotient digit.
+ -- q0 = un21 / vn1;
+ q0 .= un21' `udiv` vn1'
+ -- rhat = un21 - q0 * vn1;
+ rhat .= un21' `minus` times q0' vn1'
+
+ -- // q0 has at most error 2. No more than 2 iterations.
+ -- Same block structure as the first correction loop, this time for q0.
+ while_2_entry <- newBlockId
+ while_2_body <- newBlockId
+ while_2_done <- newBlockId
+ emitLabel while_2_entry
+ -- while (q0 >= b || q0 * vn0 > b * rhat + un0) {
+ gotoIf (q0' `ge` lit b)
+ while_2_body
+ -- Leave the loop when q0 * vn0 <= b * rhat + un0.
+ gotoIf (le (times q0' vn0')
+ (times (lit b) rhat' `plus` un0'))
+ while_2_done
+ do
+ emitLabel while_2_body
+ -- q0 = q0 - 1;
+ q0 .= q0' `minus` one
+ -- rhat = rhat + vn1;
+ rhat .= rhat' `plus` vn1'
+ -- if (rhat >= b)
+ -- break;
+ gotoIf (rhat' `ge` lit b) while_2_done
+ emit $ mkBranch while_2_entry
+ -- }
+ emitLabel while_2_done
+
+ -- r = (un21 * b + un0 - q0 * v) >> s;
+ -- (shifting right by s undoes the normalisation applied to the dividend)
+ res_r .= ((times un21' (lit b) `plus` un0') `minus` times q0' v') `shr` s'
+ -- return q1 * b + q0;
+ res_q .= times q1' (lit b) `plus` q0'
genericWordQuotRem2Op _ _ _ = panic "genericWordQuotRem2Op"
genericWordAdd2Op :: GenericOp
=====================================
testsuite/tests/numeric/should_run/all.T
=====================================
@@ -52,6 +52,7 @@ test('add2', normal, compile_and_run, ['-fobject-code'])
test('mul2', normal, compile_and_run, ['-fobject-code'])
test('mul2int', normal, compile_and_run, ['-fobject-code'])
test('quotRem2', normal, compile_and_run, ['-fobject-code'])
+# Added alongside the genericWordQuotRem2Op rewrite (#22966).
+test('quotRem2Large', normal, compile_and_run, ['-fobject-code'])
test('T5863', normal, compile_and_run, [''])
test('T7014', js_skip, makefile_test, [])
=====================================
testsuite/tests/numeric/should_run/quotRem2Large.hs
=====================================
The diff for this file was not included because it is too large.
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4998a6edb61d3c3f5542106322cee56105b88f91
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4998a6edb61d3c3f5542106322cee56105b88f91
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240603/14789bdf/attachment-0001.html>
More information about the ghc-commits mailing list