[Git][ghc/ghc][wip/T22491] Add a missing varToCoreExpr in etaBodyForJoinPoint

Simon Peyton Jones (@simonpj) gitlab at gitlab.haskell.org
Tue Nov 22 14:59:23 UTC 2022



Simon Peyton Jones pushed to branch wip/T22491 at Glasgow Haskell Compiler / GHC


Commits:
c0e02e0e by Simon Peyton Jones at 2022-11-22T15:00:35+00:00
Add a missing varToCoreExpr in etaBodyForJoinPoint

This subtle bug showed up when compiling a library with 9.4.
See #22491.  The bug is present in master, but it is hard to
trigger; the new regression test T22491 fails in 9.4.

The fix is definitely right though!

I also moved the preInlineUnconditionally test in simplExprF1 to
before the call to joinPointBinding_maybe, to avoid fruitless
eta-expansion.  This is just a minor refactoring.

- - - - -


4 changed files:

- compiler/GHC/Core/Opt/Arity.hs
- compiler/GHC/Core/Opt/Simplify/Iteration.hs
- + testsuite/tests/simplCore/should_compile/T22491.hs
- testsuite/tests/simplCore/should_compile/all.T


Changes:

=====================================
compiler/GHC/Core/Opt/Arity.hs
=====================================
@@ -3104,9 +3104,13 @@ etaBodyForJoinPoint need_args body
       | Just (tv, res_ty) <- splitForAllTyCoVar_maybe ty
       , let (subst', tv') = substVarBndr subst tv
       = go (n-1) res_ty subst' (tv' : rev_bs) (e `App` varToCoreExpr tv')
+        -- The varToCoreExpr is important: `tv` might be a coercion variable
+
       | Just (_, mult, arg_ty, res_ty) <- splitFunTy_maybe ty
       , let (subst', b) = freshEtaId n subst (Scaled mult arg_ty)
-      = go (n-1) res_ty subst' (b : rev_bs) (e `App` Var b)
+      = go (n-1) res_ty subst' (b : rev_bs) (e `App` varToCoreExpr b)
+        -- The varToCoreExpr is important: `b` might be a coercion variable
+
       | otherwise
       = pprPanic "etaBodyForJoinPoint" $ int need_args $$
                                          ppr body $$ ppr (exprType body)


=====================================
compiler/GHC/Core/Opt/Simplify/Iteration.hs
=====================================
@@ -1227,6 +1227,14 @@ simplExprF1 env (Let (NonRec bndr rhs) body) cont
     do { ty' <- simplType env ty
        ; simplExprF (extendTvSubst env bndr ty') body cont }
 
+  | Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
+  = do { tick (PreInlineUnconditionally bndr)
+       ; simplExprF env' body cont }
+
+  -- Now check for a join point.  It's better to do the preInlineUnconditionally
+  -- test first, because joinPointBinding_maybe has to eta-expand, so a trivial
+  -- binding like { j = j2 |> co } would first be eta-expanded and then inlined.
+  -- Testing preInlineUnconditionally first avoids that fruitless eta-expansion.
   | Just (bndr', rhs') <- joinPointBinding_maybe bndr rhs
   = {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' body cont
 
@@ -1680,12 +1688,17 @@ simpl_lam env bndr body (ApplyToVal { sc_arg = arg, sc_env = arg_se
                                     , sc_cont = cont, sc_dup = dup })
   | isSimplified dup  -- Don't re-simplify if we've simplified it once
                       -- See Note [Avoiding exponential behaviour]
-  =  do { tick (BetaReduction bndr)
-        ; completeBindX env bndr arg body cont }
+  = do { tick (BetaReduction bndr)
+       ; completeBindX env bndr arg body cont }
+
+  | Just env' <- preInlineUnconditionally env NotTopLevel bndr arg arg_se
+  = do { tick (PreInlineUnconditionally bndr)
+       ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr arg) $
+         simplLam env' body cont }
 
   | otherwise         -- See Note [Avoiding exponential behaviour]
-  = do  { tick (BetaReduction bndr)
-        ; simplNonRecE env bndr (arg, arg_se) body cont }
+  = do { tick (BetaReduction bndr)
+       ; simplNonRecE env bndr (arg, arg_se) body cont }
 
 -- Discard a non-counting tick on a lambda.  This may change the
 -- cost attribution slightly (moving the allocation of the
@@ -1735,6 +1748,8 @@ simplNonRecE :: SimplEnv
 -- It deals with strict bindings, via the StrictBind continuation,
 -- which may abort the whole process.
 --
+-- Caller is expected to do the preInlineUnconditionally test
+--
 -- The RHS may not satisfy the let-can-float invariant yet
 
 simplNonRecE env bndr (rhs, rhs_se) body cont
@@ -1742,23 +1757,14 @@ simplNonRecE env bndr (rhs, rhs_se) body cont
     do { (env1, bndr1) <- simplNonRecBndr env bndr
        ; let needs_case_binding = needsCaseBinding (idType bndr1) rhs
          -- See Note [Dark corner with representation polymorphism]
-       ; if | not needs_case_binding
-            , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs rhs_se ->
-            do { tick (PreInlineUnconditionally bndr)
-               ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr rhs) $
-                 simplLam env' body cont }
-
 
-             -- Deal with strict bindings
-             -- See Note [Dark corner with representation polymorphism]
-            | isStrictId bndr1 && seCaseCase env
-            || needs_case_binding ->
+       ; if ((isStrictId bndr1 && seCaseCase env) || needs_case_binding)
+         then    -- Deal with strict bindings
             simplExprF (rhs_se `setInScopeFromE` env) rhs
                        (StrictBind { sc_bndr = bndr, sc_body = body
                                    , sc_env = env, sc_cont = cont, sc_dup = NoDup })
 
-            -- Deal with lazy bindings
-            | otherwise ->
+         else   -- Deal with lazy bindings
             do { (env2, bndr2)    <- addBndrRules env1 bndr bndr1 (BC_Let NotTopLevel NonRecursive)
                ; (floats1, env3)  <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
                ; (floats2, expr') <- simplLam env3 body cont
@@ -1806,7 +1812,7 @@ care here.
 Note [Avoiding exponential behaviour]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 One way in which we can get exponential behaviour is if we simplify a
-big expression, and the re-simplify it -- and then this happens in a
+big expression, and then re-simplify it -- and then this happens in a
 deeply-nested way.  So we must be jolly careful about re-simplifying
 an expression.  That is why simplNonRecX does not try
 preInlineUnconditionally (unlike simplNonRecE).
@@ -1864,13 +1870,8 @@ simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
                      -> InExpr -> SimplCont
                      -> SimplM (SimplFloats, OutExpr)
 simplNonRecJoinPoint env bndr rhs body cont
-  | assert (isJoinId bndr ) True
-  , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
-  = do { tick (PreInlineUnconditionally bndr)
-       ; simplExprF env' body cont }
-
-   | otherwise
-   = wrapJoinCont env cont $ \ env cont ->
+   = assert (isJoinId bndr ) $
+     wrapJoinCont env cont $ \ env cont ->
      do { -- We push join_cont into the join RHS and the body;
           -- and wrap wrap_cont around the whole thing
         ; let mult   = contHoleScaling cont


=====================================
testsuite/tests/simplCore/should_compile/T22491.hs
=====================================
@@ -0,0 +1,319 @@
+{-# LANGUAGE Haskell2010 #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module T22491 (heapster_add_block_hints) where
+
+import qualified Control.Exception as X
+import Control.Applicative
+import Control.Monad
+import Control.Monad.Catch (MonadThrow(..), MonadCatch(..), catches, Handler(..))
+import Control.Monad.IO.Class
+import qualified Control.Monad.Fail as Fail
+import Control.Monad.Trans.Class (MonadTrans(..))
+import Control.Monad.Trans.Reader (ReaderT)
+import Data.Coerce (Coercible, coerce)
+import Data.IORef
+import Data.Kind (Type)
+import Data.Monoid
+import GHC.Exts (build)
+
+failOnNothing :: Fail.MonadFail m => String -> Maybe a -> m a
+failOnNothing err_str Nothing = Fail.fail err_str
+failOnNothing _ (Just a) = return a
+
+lookupLLVMSymbolModAndCFG :: HeapsterEnv -> String -> IO (Maybe (AnyCFG LLVM))
+lookupLLVMSymbolModAndCFG _ _ = pure Nothing
+
+heapster_add_block_hints :: HeapsterEnv -> String -> [Int] ->
+                            (forall ext blocks ret.
+                             CFG ext blocks ret ->
+                             TopLevel Hint) ->
+                            TopLevel ()
+heapster_add_block_hints henv nm blks hintF =
+  do env <- liftIO $ readIORef $ heapsterEnvPermEnvRef henv
+     AnyCFG cfg <-
+       failOnNothing ("Could not find symbol definition: " ++ nm) =<<
+         io (lookupLLVMSymbolModAndCFG henv nm)
+     let blocks = fmapFC blockInputs $ cfgBlockMap cfg
+         block_idxs = fmapFC (blockIDIndex . blockID) $ cfgBlockMap cfg
+     blkIDs <- case blks of
+       [] -> pure $ toListFC (Some . BlockID) block_idxs
+       _ -> forM blks $ \blk ->
+         failOnNothing ("Block ID " ++ show blk ++
+                        " not found in function " ++ nm)
+                       (fmapF BlockID <$> intIndex blk (size blocks))
+     env' <- foldM (\env' _ ->
+                     permEnvAddHint env' <$>
+                     hintF cfg)
+       env blkIDs
+     liftIO $ writeIORef (heapsterEnvPermEnvRef henv) env'
+
+-----
+
+data Some (f:: k -> Type) = forall x . Some (f x)
+
+class FunctorF m where
+  fmapF :: (forall x . f x -> g x) -> m f -> m g
+
+mapSome :: (forall tp . f tp -> g tp) -> Some f -> Some g
+mapSome f (Some x) = Some $! f x
+
+instance FunctorF Some where fmapF = mapSome
+
+type SingleCtx x = EmptyCtx ::> x
+
+data Ctx k
+  = EmptyCtx
+  | Ctx k ::> k
+
+type family (<+>) (x :: Ctx k) (y :: Ctx k) :: Ctx k where
+  x <+> EmptyCtx = x
+  x <+> (y ::> e) = (x <+> y) ::> e
+
+data Height = Zero | Succ Height
+
+data BalancedTree h (f :: k -> Type) (p :: Ctx k) where
+  BalLeaf :: !(f x) -> BalancedTree 'Zero f (SingleCtx x)
+  BalPair :: !(BalancedTree h f x)
+          -> !(BalancedTree h f y)
+          -> BalancedTree ('Succ h) f (x <+> y)
+
+data BinomialTree (h::Height) (f :: k -> Type) :: Ctx k -> Type where
+  Empty :: BinomialTree h f EmptyCtx
+
+  PlusOne  :: !Int
+           -> !(BinomialTree ('Succ h) f x)
+           -> !(BalancedTree h f y)
+           -> BinomialTree h f (x <+> y)
+
+  PlusZero  :: !Int
+            -> !(BinomialTree ('Succ h) f x)
+            -> BinomialTree h f x
+
+tsize :: BinomialTree h f a -> Int
+tsize Empty = 0
+tsize (PlusOne s _ _) = 2*s+1
+tsize (PlusZero  s _) = 2*s
+
+fmap_bin :: (forall tp . f tp -> g tp)
+         -> BinomialTree h f c
+         -> BinomialTree h g c
+fmap_bin _ Empty = Empty
+fmap_bin f (PlusOne s t x) = PlusOne s (fmap_bin f t) (fmap_bal f x)
+fmap_bin f (PlusZero s t)  = PlusZero s (fmap_bin f t)
+{-# INLINABLE fmap_bin #-}
+
+fmap_bal :: (forall tp . f tp -> g tp)
+         -> BalancedTree h f c
+         -> BalancedTree h g c
+fmap_bal = go
+  where go :: (forall tp . f tp -> g tp)
+              -> BalancedTree h f c
+              -> BalancedTree h g c
+        go f (BalLeaf x) = BalLeaf (f x)
+        go f (BalPair x y) = BalPair (go f x) (go f y)
+{-# INLINABLE fmap_bal #-}
+
+traverse_bin :: Applicative m
+             => (forall tp . f tp -> m (g tp))
+             -> BinomialTree h f c
+             -> m (BinomialTree h g c)
+traverse_bin _ Empty = pure Empty
+traverse_bin f (PlusOne s t x) = PlusOne s  <$> traverse_bin f t <*> traverse_bal f x
+traverse_bin f (PlusZero s t)  = PlusZero s <$> traverse_bin f t
+{-# INLINABLE traverse_bin #-}
+
+traverse_bal :: Applicative m
+             => (forall tp . f tp -> m (g tp))
+             -> BalancedTree h f c
+             -> m (BalancedTree h g c)
+traverse_bal = go
+  where go :: Applicative m
+              => (forall tp . f tp -> m (g tp))
+              -> BalancedTree h f c
+              -> m (BalancedTree h g c)
+        go f (BalLeaf x) = BalLeaf <$> f x
+        go f (BalPair x y) = BalPair <$> go f x <*> go f y
+{-# INLINABLE traverse_bal #-}
+
+data Assignment (f :: k -> Type) (ctx :: Ctx k)
+      = Assignment (BinomialTree 'Zero f ctx)
+
+newtype Index (ctx :: Ctx k) (tp :: k) = Index { indexVal :: Int }
+
+newtype Size (ctx :: Ctx k) = Size Int
+
+intIndex :: Int -> Size ctx -> Maybe (Some (Index ctx))
+intIndex i n | 0 <= i && i < sizeInt n = Just (Some (Index i))
+             | otherwise = Nothing
+
+size :: Assignment f ctx -> Size ctx
+size (Assignment t) = Size (tsize t)
+
+sizeInt :: Size ctx -> Int
+sizeInt (Size n) = n
+
+class FunctorFC (t :: (k -> Type) -> l -> Type) where
+  fmapFC :: forall f g. (forall x. f x -> g x) ->
+                        (forall x. t f x -> t g x)
+
+(#.) :: Coercible b c => (b -> c) -> (a -> b) -> (a -> c)
+(#.) _f = coerce
+
+class FoldableFC (t :: (k -> Type) -> l -> Type) where
+  foldMapFC :: forall f m. Monoid m => (forall x. f x -> m) -> (forall x. t f x -> m)
+  foldMapFC f = foldrFC (mappend . f) mempty
+
+  foldrFC :: forall f b. (forall x. f x -> b -> b) -> (forall x. b -> t f x -> b)
+  foldrFC f z t = appEndo (foldMapFC (Endo #. f) t) z
+
+  toListFC :: forall f a. (forall x. f x -> a) -> (forall x. t f x -> [a])
+  toListFC f t = build (\c n -> foldrFC (\e v -> c (f e) v) n t)
+
+foldMapFCDefault :: (TraversableFC t, Monoid m) => (forall x. f x -> m) -> (forall x. t f x -> m)
+foldMapFCDefault = \f -> getConst . traverseFC (Const . f)
+{-# INLINE foldMapFCDefault #-}
+
+class (FunctorFC t, FoldableFC t) => TraversableFC (t :: (k -> Type) -> l -> Type) where
+  traverseFC :: forall f g m. Applicative m
+             => (forall x. f x -> m (g x))
+             -> (forall x. t f x -> m (t g x))
+
+instance FunctorFC Assignment where
+  fmapFC = \f (Assignment x) -> Assignment (fmap_bin f x)
+  {-# INLINE fmapFC #-}
+
+instance FoldableFC Assignment where
+  foldMapFC = foldMapFCDefault
+  {-# INLINE foldMapFC #-}
+
+instance TraversableFC Assignment where
+  traverseFC = \f (Assignment x) -> Assignment <$> traverse_bin f x
+  {-# INLINE traverseFC #-}
+
+data CrucibleType
+
+data TypeRepr (tp::CrucibleType) where
+
+type CtxRepr = Assignment TypeRepr
+
+data CFG (ext :: Type)
+         (blocks :: Ctx (Ctx CrucibleType))
+         (ret :: CrucibleType)
+   = CFG { cfgBlockMap :: !(BlockMap ext blocks ret)
+         }
+
+type BlockMap ext blocks ret = Assignment (Block ext blocks ret) blocks
+
+data Block ext (blocks :: Ctx (Ctx CrucibleType)) (ret :: CrucibleType) ctx
+   = Block { blockID        :: !(BlockID blocks ctx)
+           , blockInputs    :: !(CtxRepr ctx)
+           }
+
+newtype BlockID (blocks :: Ctx (Ctx CrucibleType)) (tp :: Ctx CrucibleType)
+      = BlockID { blockIDIndex :: Index blocks tp }
+
+data LLVM
+
+data AnyCFG ext where
+  AnyCFG :: CFG ext blocks ret
+         -> AnyCFG ext
+
+newtype StateContT s r m a
+      = StateContT { runStateContT :: (a -> s -> m r)
+                                   -> s
+                                   -> m r
+                   }
+
+fmapStateContT :: (a -> b) -> StateContT s r m a -> StateContT s r m b
+fmapStateContT = \f m -> StateContT $ \c -> runStateContT m (\v s -> (c $! f v) s)
+{-# INLINE fmapStateContT #-}
+
+applyStateContT :: StateContT s r m (a -> b) -> StateContT s r m a -> StateContT s r m b
+applyStateContT = \mf mv ->
+  StateContT $ \c ->
+    runStateContT mf (\f -> runStateContT mv (\v s -> (c $! f v) s))
+{-# INLINE applyStateContT #-}
+
+returnStateContT :: a -> StateContT s r m a
+returnStateContT = \v -> seq v $ StateContT $ \c -> c v
+{-# INLINE returnStateContT #-}
+
+bindStateContT :: StateContT s r m a -> (a -> StateContT s r m b) -> StateContT s r m b
+bindStateContT = \m n -> StateContT $ \c -> runStateContT m (\a -> runStateContT (n a) c)
+{-# INLINE bindStateContT #-}
+
+instance Functor (StateContT s r m) where
+  fmap = fmapStateContT
+
+instance Applicative (StateContT s r m) where
+  pure  = returnStateContT
+  (<*>) = applyStateContT
+
+instance Monad (StateContT s r m) where
+  (>>=) = bindStateContT
+
+instance MonadFail m => MonadFail (StateContT s r m) where
+  fail = \msg -> StateContT $ \_ _ -> fail msg
+
+instance MonadTrans (StateContT s r) where
+  lift = \m -> StateContT $ \c s -> m >>= \v -> seq v (c v s)
+
+instance MonadIO m => MonadIO (StateContT s r m) where
+  liftIO = lift . liftIO
+
+instance MonadThrow m => MonadThrow (StateContT s r m) where
+  throwM e = StateContT (\_k _s -> throwM e)
+
+instance MonadCatch m => MonadCatch (StateContT s r m) where
+  catch m hdl =
+    StateContT $ \k s ->
+      catch
+        (runStateContT m k s)
+        (\e -> runStateContT (hdl e) k s)
+
+data TopLevelRO
+data TopLevelRW
+data Value
+
+newtype TopLevel a =
+  TopLevel_ (ReaderT TopLevelRO (StateContT TopLevelRW (Value, TopLevelRW) IO) a)
+ deriving (Applicative, Functor, Monad, MonadFail, MonadThrow, MonadCatch)
+
+instance MonadIO TopLevel where
+  liftIO = io
+
+io :: IO a -> TopLevel a
+io f = TopLevel_ (liftIO f) `catches` [Handler handleIO]
+  where
+    rethrow :: X.Exception ex => ex -> TopLevel a
+    rethrow ex = throwM (X.SomeException ex)
+
+    handleIO :: X.IOException -> TopLevel a
+    handleIO = rethrow
+
+data HeapsterEnv = HeapsterEnv {
+  heapsterEnvPermEnvRef :: IORef PermEnv
+  }
+
+data Hint where
+
+data PermEnv = PermEnv {
+  permEnvHints :: [Hint]
+  }
+
+permEnvAddHint :: PermEnv -> Hint -> PermEnv
+permEnvAddHint env hint = env { permEnvHints = hint : permEnvHints env }
+
+type family CtxToRList (ctx :: Ctx k) :: RList k where
+  CtxToRList EmptyCtx = RNil
+  CtxToRList (ctx' ::> x) = CtxToRList ctx' :> x
+
+data RList a
+  = RNil
+  | (RList a) :> a


=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -452,3 +452,5 @@ test('T22375', normal, compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeab
 test('T21851_2', [grep_errmsg(r'wwombat') ], multimod_compile, ['T21851_2', '-O -dno-typeable-binds -dsuppress-uniques'])
 # Should not inline m, so there shouldn't be a single YES
 test('T22317', [grep_errmsg(r'ANSWER = YES') ], compile, ['-O -dinline-check m -ddebug-output'])
+
+test('T22491', normal, compile, ['-O2'])



View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c0e02e0e0e4f08e6a9458e25b90b18149a411485

-- 
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c0e02e0e0e4f08e6a9458e25b90b18149a411485
You're receiving this email because of your account on gitlab.haskell.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20221122/b2709b96/attachment-0001.html>


More information about the ghc-commits mailing list