[commit: ghc] wip/T14068: Detect functions where all recursive calls are tail-recursive (456cbbd)

git at git.haskell.org git at git.haskell.org
Tue Aug 1 13:49:10 UTC 2017


Repository : ssh://git@git.haskell.org/ghc

On branch  : wip/T14068
Link       : http://ghc.haskell.org/trac/ghc/changeset/456cbbde2bead57319c9fa0d27bc88f1d5523625/ghc

>---------------------------------------------------------------

commit 456cbbde2bead57319c9fa0d27bc88f1d5523625
Author: Joachim Breitner <mail at joachim-breitner.de>
Date:   Tue Aug 1 09:47:49 2017 -0400

    Detect functions where all recursive calls are tail-recursive
    
    This is the first half of #14068.


>---------------------------------------------------------------

456cbbde2bead57319c9fa0d27bc88f1d5523625
 compiler/basicTypes/BasicTypes.hs | 20 ++++++++++++++------
 compiler/simplCore/OccurAnal.hs   | 25 +++++++++++++++++++++----
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/compiler/basicTypes/BasicTypes.hs b/compiler/basicTypes/BasicTypes.hs
index 90a043d..284ddfe 100644
--- a/compiler/basicTypes/BasicTypes.hs
+++ b/compiler/basicTypes/BasicTypes.hs
@@ -935,6 +935,7 @@ notOneBranch = False
 
 -----------------
 data TailCallInfo = AlwaysTailCalled JoinArity -- See Note [TailCallInfo]
+                  | RecursiveTailCalled JoinArity -- See Note [TailCallInfo]
                   | NoTailCallInfo
   deriving (Eq)
 
@@ -948,12 +949,14 @@ zapOccTailCallInfo occ       = occ { occ_tail = NoTailCallInfo }
 
 isAlwaysTailCalled :: OccInfo -> Bool
 isAlwaysTailCalled occ
-  = case tailCallInfo occ of AlwaysTailCalled{} -> True
-                             NoTailCallInfo     -> False
+  = case tailCallInfo occ of AlwaysTailCalled{}    -> True
+                             RecursiveTailCalled{} -> False
+                             NoTailCallInfo        -> False
 
 instance Outputable TailCallInfo where
-  ppr (AlwaysTailCalled ar) = sep [ text "Tail", int ar ]
-  ppr _                     = empty
+  ppr (AlwaysTailCalled ar)    = sep [ text "Tail", int ar ]
+  ppr (RecursiveTailCalled ar) = sep [ text "Tail(rec)", int ar ]
+  ppr NoTailCallInfo           = empty   -- explicit, so new constructors warn
 
 -----------------
 strongLoopBreaker, weakLoopBreaker :: OccInfo
@@ -1003,8 +1006,9 @@ instance Outputable OccInfo where
           pp_tail             = pprShortTailCallInfo tail_info
 
 pprShortTailCallInfo :: TailCallInfo -> SDoc
-pprShortTailCallInfo (AlwaysTailCalled ar) = char 'T' <> brackets (int ar)
-pprShortTailCallInfo NoTailCallInfo        = empty
+pprShortTailCallInfo NoTailCallInfo           = empty
+pprShortTailCallInfo (AlwaysTailCalled ar)    = char 'T'  <> brackets (int ar)
+pprShortTailCallInfo (RecursiveTailCalled ar) = text "TR" <> brackets (int ar)
 
 {-
 Note [TailCallInfo]
@@ -1037,6 +1041,10 @@ point can also be invoked from other join points, not just from case branches:
 Here both 'j1' and 'j2' will get marked AlwaysTailCalled, but j1 will get
 ManyOccs and j2 will get `OneOcc { occ_one_br = True }`.
 
+The RecursiveTailCalled marker, which is only valid for a recursive binder,
+says: all recursive calls (those in the binder's own RHS) are tail calls in
+the sense of AlwaysTailCalled, even if some calls in the let body are not.
+
 ************************************************************************
 *                                                                      *
                 Default method specification
diff --git a/compiler/simplCore/OccurAnal.hs b/compiler/simplCore/OccurAnal.hs
index dbe1c48..a652e1c 100644
--- a/compiler/simplCore/OccurAnal.hs
+++ b/compiler/simplCore/OccurAnal.hs
@@ -2633,8 +2633,9 @@ tagRecBinders lvl body_uds triples
 
      -- 1. Determine join-point-hood of whole group, as determined by
      --    the *unadjusted* usage details
-     unadj_uds     = body_uds +++ combineUsageDetailsList rhs_udss
-     will_be_joins = decideJoinPointHood lvl unadj_uds bndrs
+     unadj_uds_rhss = combineUsageDetailsList rhs_udss  -- RHS-only usage, for step 4a
+     unadj_uds      = body_uds +++ unadj_uds_rhss
+     will_be_joins  = decideJoinPointHood lvl unadj_uds bndrs
 
      -- 2. Adjust usage details of each RHS, taking into account the
      --    join-point-hood decision
@@ -2658,8 +2659,20 @@ tagRecBinders lvl body_uds triples
      adj_uds   = body_uds +++ combineUsageDetailsList rhs_udss'
 
      -- 4. Tag each binder with its adjusted details
-     bndrs'    = [ setBinderOcc (lookupDetails adj_uds bndr) bndr
-                 | bndr <- bndrs ]
+     bndrs'
+        -- 4a. A single function (not a join point) whose recursive calls are
+        --     all tail calls: the simplifier can turn it into a non-recursive
+        --     function with a local joinrec.  See Note [TailCallInfo] in BasicTypes.
+        | not will_be_joins   -- keep a join point's AlwaysTailCalled intact
+        , [bndr] <- bndrs
+        , AlwaysTailCalled arity <- tailCallInfo (lookupDetails unadj_uds_rhss bndr)
+        = let occ  = lookupDetails adj_uds bndr
+              occ' = markRecursiveTailCalled arity occ
+          in [ setBinderOcc occ' bndr ]
+        -- 4b. Otherwise, just use the adjusted details
+        | otherwise
+        = [ setBinderOcc (lookupDetails adj_uds bndr) bndr
+          | bndr <- bndrs ]
 
      -- 5. Drop the binders from the adjusted details and return
      usage'    = adj_uds `delDetailsList` bndrs
@@ -2744,6 +2757,10 @@ markInsideLam occ             = occ
 markNonTailCalled IAmDead = IAmDead
 markNonTailCalled occ     = occ { occ_tail = NoTailCallInfo }
 
+markRecursiveTailCalled :: JoinArity -> OccInfo -> OccInfo  -- Note [TailCallInfo]
+markRecursiveTailCalled _     IAmDead = IAmDead
+markRecursiveTailCalled arity occ     = occ { occ_tail = RecursiveTailCalled arity }
+
 addOccInfo, orOccInfo :: OccInfo -> OccInfo -> OccInfo
 
 addOccInfo a1 a2  = ASSERT( not (isDeadOcc a1 || isDeadOcc a2) )



More information about the ghc-commits mailing list