[Git][ghc/ghc][master] Improve performance of genericWordQuotRem2Op (#22966)

Marge Bot (@marge-bot) gitlab at gitlab.haskell.org
Mon Jun 3 06:10:56 UTC 2024



Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC


Commits:
4998a6ed by Alex Mason at 2024-06-03T02:09:29-04:00
Improve performance of genericWordQuotRem2Op (#22966)

Implements the algorithm from compiler-rt's udiv128by64to64default. This
rewrite results in a roughly 24x improvement in runtime on AArch64 (and
likely on any other architecture that uses this generic implementation).

- - - - -


3 changed files:

- compiler/GHC/StgToCmm/Prim.hs
- testsuite/tests/numeric/should_run/all.T
- + testsuite/tests/numeric/should_run/quotRem2Large.hs


Changes:

=====================================
compiler/GHC/StgToCmm/Prim.hs
=====================================
@@ -1894,53 +1894,179 @@ genericWordQuotRemOp width [res_q, res_r] [arg_x, arg_y]
                (CmmMachOp (MO_U_Rem  width) [arg_x, arg_y])
 genericWordQuotRemOp _ _ _ = panic "genericWordQuotRemOp"
 
+-- Two-word by one-word unsigned quotRem, emitted as Cmm. Based on the algorithm (udiv128by64to64default) from LLVM's compiler-rt:
+-- https://github.com/llvm/llvm-project/blob/7339f7ba3053db7595ece1ca5f49bd2e4c3c8305/compiler-rt/lib/builtins/udivmodti4.c#L23
+-- See that file for licensing and copyright.
 genericWordQuotRem2Op :: Platform -> GenericOp
-genericWordQuotRem2Op platform [res_q, res_r] [arg_x_high, arg_x_low, arg_y]
-    = emit =<< f (widthInBits (wordWidth platform)) zero arg_x_high arg_x_low
-    where    ty = cmmExprType platform arg_x_high
-             shl   x i = CmmMachOp (MO_Shl   (wordWidth platform)) [x, i]
-             shr   x i = CmmMachOp (MO_U_Shr (wordWidth platform)) [x, i]
-             or    x y = CmmMachOp (MO_Or    (wordWidth platform)) [x, y]
-             ge    x y = CmmMachOp (MO_U_Ge  (wordWidth platform)) [x, y]
-             ne    x y = CmmMachOp (MO_Ne    (wordWidth platform)) [x, y]
-             minus x y = CmmMachOp (MO_Sub   (wordWidth platform)) [x, y]
-             times x y = CmmMachOp (MO_Mul   (wordWidth platform)) [x, y]
-             zero   = lit 0
-             one    = lit 1
-             negone = lit (fromIntegral (platformWordSizeInBits platform) - 1)
-             lit i = CmmLit (CmmInt i (wordWidth platform))
-
-             f :: Int -> CmmExpr -> CmmExpr -> CmmExpr -> FCode CmmAGraph
-             f 0 acc high _ = return (mkAssign (CmmLocal res_q) acc <*>
-                                      mkAssign (CmmLocal res_r) high)
-             f i acc high low =
-                 do roverflowedBit <- newTemp ty
-                    rhigh'         <- newTemp ty
-                    rhigh''        <- newTemp ty
-                    rlow'          <- newTemp ty
-                    risge          <- newTemp ty
-                    racc'          <- newTemp ty
-                    let high'         = CmmReg (CmmLocal rhigh')
-                        isge          = CmmReg (CmmLocal risge)
-                        overflowedBit = CmmReg (CmmLocal roverflowedBit)
-                    let this = catAGraphs
-                               [mkAssign (CmmLocal roverflowedBit)
-                                          (shr high negone),
-                                mkAssign (CmmLocal rhigh')
-                                          (or (shl high one) (shr low negone)),
-                                mkAssign (CmmLocal rlow')
-                                          (shl low one),
-                                mkAssign (CmmLocal risge)
-                                          (or (overflowedBit `ne` zero)
-                                              (high' `ge` arg_y)),
-                                mkAssign (CmmLocal rhigh'')
-                                          (high' `minus` (arg_y `times` isge)),
-                                mkAssign (CmmLocal racc')
-                                          (or (shl acc one) isge)]
-                    rest <- f (i - 1) (CmmReg (CmmLocal racc'))
-                                      (CmmReg (CmmLocal rhigh''))
-                                      (CmmReg (CmmLocal rlow'))
-                    return (this <*> rest)
+genericWordQuotRem2Op platform [res_q, res_r] [arg_u1, arg_u0, arg_v]
+    = do
+      -- Copy arg_v into a fresh temp: go may shift it left in place to normalise it (v <<= clz v).
+      v <- newTemp ty
+      emit $ mkAssign (CmmLocal v) arg_v
+      go arg_u1 arg_u0 v
+  where   ty = cmmExprType platform arg_u1
+          shl   x i = CmmMachOp (MO_Shl    (wordWidth platform)) [x, i]
+          shr   x i = CmmMachOp (MO_U_Shr  (wordWidth platform)) [x, i]
+          or    x y = CmmMachOp (MO_Or     (wordWidth platform)) [x, y]
+          ge    x y = CmmMachOp (MO_U_Ge   (wordWidth platform)) [x, y]
+          le    x y = CmmMachOp (MO_U_Le   (wordWidth platform)) [x, y]
+          eq    x y = CmmMachOp (MO_Eq     (wordWidth platform)) [x, y]
+          plus  x y = CmmMachOp (MO_Add    (wordWidth platform)) [x, y]
+          minus x y = CmmMachOp (MO_Sub    (wordWidth platform)) [x, y]
+          times x y = CmmMachOp (MO_Mul    (wordWidth platform)) [x, y]
+          udiv  x y = CmmMachOp (MO_U_Quot (wordWidth platform)) [x, y]
+          and   x y = CmmMachOp (MO_And    (wordWidth platform)) [x, y]
+          lit i     = CmmLit (CmmInt i (wordWidth platform))
+          one       = lit 1
+          zero      = lit 0
+          masklow   = lit ((1 `shiftL` (platformWordSizeInBits platform `div` 2)) - 1)
+          gotoIf pred target = emit =<< mkCmmIfGoto pred target
+          mkTmp ty = do
+            t <- newTemp ty
+            pure (t, CmmReg (CmmLocal t))
+          infixr 8 .=
+          r .= e = emit $ mkAssign (CmmLocal r) e
+
+          go :: CmmActual -> CmmActual -> LocalReg -> FCode ()
+          go u1 u0 v = do
+            -- Emits (res_q, res_r) = ((u1 << n_udword_bits) + u0) `quotRem` v. The quotient fits in one word only if u1 < v; callers presumably guarantee this (TODO confirm).
+            -- du_int udiv128by64to64default(du_int u1, du_int u0, du_int v, du_int *r)
+            -- const unsigned n_udword_bits = sizeof(du_int) * CHAR_BIT;
+            let n_udword_bits' = widthInBits (wordWidth platform)
+                n_udword_bits = fromIntegral n_udword_bits'
+            -- const du_int b = (1ULL << (n_udword_bits / 2)); // Number base (half a word; 32 bits on 64-bit platforms)
+                b = 1 `shiftL` (n_udword_bits' `div` 2)
+                v' = CmmReg (CmmLocal v)
+            -- du_int un1, un0;                                // Norm. dividend LSD's
+            (un1, un1')   <- mkTmp ty
+            (un0, un0')   <- mkTmp ty
+            -- du_int vn1, vn0;                                // Norm. divisor digits
+            (vn1, vn1')   <- mkTmp ty
+            (vn0, vn0')   <- mkTmp ty
+            -- du_int q1, q0;                                  // Quotient digits
+            (q1, q1')     <- mkTmp ty
+            (q0, q0')     <- mkTmp ty
+            -- du_int un64, un21, un10;                        // Dividend digit pairs
+            (un64, un64') <- mkTmp ty
+            (un21, un21') <- mkTmp ty
+            (un10, un10') <- mkTmp ty
+
+            -- du_int rhat;                                    // A remainder
+            (rhat, rhat') <- mkTmp ty
+            -- si_int s;                                       // Shift amount for normalization
+            (s, s')       <- mkTmp ty
+
+            -- s = __builtin_clzll(v);
+            -- Unlike __builtin_clzll (undefined for 0), GHC's clz yields N for v == 0 on
+            -- N-bit platforms; the s == N case is routed to the else branch below.
+            emitClzCall s v' (wordWidth platform)
+
+            if_else <- newBlockId
+            if_done <- newBlockId
+            -- if (s > 0) {
+            -- here: if (s /= 0 && s /= n_udword_bits) {  -- s == n_udword_bits only when v == 0
+            gotoIf (s' `eq` zero) if_else
+            gotoIf (s' `eq` lit n_udword_bits) if_else
+            do
+              --   // Normalize the divisor.
+              --   v = v << s;
+              v .= shl v' s'
+              --   un64 = (u1 << s) | (u0 >> (n_udword_bits - s));
+              un64 .= (u1 `shl` s') `or` (u0 `shr` (lit n_udword_bits `minus` s'))
+              --   un10 = u0 << s; // Shift dividend left
+              un10 .= shl u0 s'
+              emit $ mkBranch if_done
+            -- } else {
+            do
+              --   // Avoid undefined behavior of (u0 >> n_udword_bits).
+              emitLabel if_else
+              --   un64 = u1;
+              un64 .= u1
+              --   un10 = u0;
+              un10 .= u0
+              s .= lit 0 -- s may be n_udword_bits here (v == 0); reset so the final `shr` by s is a no-op
+              -- }
+            emitLabel if_done
+
+            -- // Break divisor up into two half-word digits (32-bit on 64-bit platforms).
+            -- vn1 = v >> (n_udword_bits / 2);
+            vn1 .= v' `shr` lit (n_udword_bits `div` 2)
+            -- vn0 = v & 0xFFFFFFFF;  (masklow = low half-word mask)
+            vn0 .= v' `and` masklow
+
+            -- // Break right half of dividend into two digits.
+            -- un1 = un10 >> (n_udword_bits / 2);
+            un1 .= un10' `shr`  lit (n_udword_bits `div` 2)
+            -- un0 = un10 & 0xFFFFFFFF;  (masklow = low half-word mask)
+            un0 .= un10' `and` masklow
+
+            -- // Compute the first quotient digit, q1.
+            -- q1 = un64 / vn1;
+            q1 .= un64' `udiv` vn1'
+            -- rhat = un64 - q1 * vn1;
+            rhat .= un64' `minus` times q1' vn1'
+
+            while_1_entry <- newBlockId
+            while_1_body  <- newBlockId
+            while_1_done  <- newBlockId
+            -- // q1 has at most error 2. No more than 2 iterations.
+            -- while (q1 >= b || q1 * vn0 > b * rhat + un1) {
+            emitLabel while_1_entry
+            gotoIf (q1' `ge` lit b) while_1_body
+            gotoIf (le (times q1' vn0')
+                       (times (lit b) rhat' `plus` un1'))
+                  while_1_done
+            do
+              emitLabel while_1_body
+              --   q1 = q1 - 1;
+              q1 .= q1' `minus` one
+              --   rhat = rhat + vn1;
+              rhat .= rhat' `plus` vn1'
+              --   if (rhat >= b)
+              --     break;
+              gotoIf (rhat' `ge` lit b)
+                      while_1_done
+              emit $ mkBranch while_1_entry
+            -- }
+            emitLabel while_1_done
+
+            -- un21 = un64 * b + un1 - q1 * v;
+            un21 .= (times un64' (lit b) `plus` un1') `minus` times q1' v'
+
+            -- // Compute the second quotient digit.
+            -- q0 = un21 / vn1;
+            q0 .= un21' `udiv` vn1'
+            -- rhat = un21 - q0 * vn1;
+            rhat .= un21' `minus` times q0' vn1'
+
+            -- // q0 has at most error 2. No more than 2 iterations.
+            while_2_entry <- newBlockId
+            while_2_body  <- newBlockId
+            while_2_done  <- newBlockId
+            emitLabel while_2_entry
+            -- while (q0 >= b || q0 * vn0 > b * rhat + un0) {
+            gotoIf (q0' `ge` lit b)
+                   while_2_body
+            gotoIf (le (times q0' vn0')
+                       (times (lit b) rhat' `plus` un0'))
+                   while_2_done
+            do
+              emitLabel while_2_body
+              --   q0 = q0 - 1;
+              q0 .= q0' `minus` one
+              --   rhat = rhat + vn1;
+              rhat .= rhat' `plus` vn1'
+              --   if (rhat >= b)
+              --     break;
+              gotoIf (rhat' `ge` lit b) while_2_done
+              emit $ mkBranch while_2_entry
+            -- }
+            emitLabel while_2_done
+
+            --   r = (un21 * b + un0 - q0 * v) >> s;  (the shift by s undoes the normalisation)
+            res_r .= ((times un21' (lit b) `plus` un0') `minus` times q0' v') `shr` s'
+            -- return q1 * b + q0;
+            res_q .= times q1' (lit b) `plus` q0'
 genericWordQuotRem2Op _ _ _ = panic "genericWordQuotRem2Op"
 
 genericWordAdd2Op :: GenericOp


=====================================
testsuite/tests/numeric/should_run/all.T
=====================================
@@ -52,6 +52,7 @@ test('add2', normal, compile_and_run, ['-fobject-code'])
 test('mul2', normal, compile_and_run, ['-fobject-code'])
 test('mul2int', normal, compile_and_run, ['-fobject-code'])
 test('quotRem2', normal, compile_and_run, ['-fobject-code'])
+test('quotRem2Large', normal, compile_and_run, ['-fobject-code'])  # added with the genericWordQuotRem2Op rewrite (#22966)
 test('T5863', normal, compile_and_run, [''])
 
 test('T7014', js_skip, makefile_test, [])


=====================================
testsuite/tests/numeric/should_run/quotRem2Large.hs
=====================================
The diff for this file was not included because it is too large.


View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4998a6edb61d3c3f5542106322cee56105b88f91

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/4998a6edb61d3c3f5542106322cee56105b88f91
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20240603/14789bdf/attachment-0001.html>


More information about the ghc-commits mailing list