[Git][ghc/ghc][wip/T22491] Add a missing varToCoreExpr in etaBodyForJoinPoint
Simon Peyton Jones (@simonpj)
gitlab at gitlab.haskell.org
Tue Nov 22 14:59:23 UTC 2022
Simon Peyton Jones pushed to branch wip/T22491 at Glasgow Haskell Compiler / GHC
Commits:
c0e02e0e by Simon Peyton Jones at 2022-11-22T15:00:35+00:00
Add a missing varToCoreExpr in etaBodyForJoinPoint
This subtle bug showed up when compiling a library with 9.4.
See #22491. The bug is present in master, but it is hard to
trigger; the new regression test T22491 fails in 9.4.
The fix is definitely right though!
I also moved the preInlineUnconditionally test in simplExprF1 to
before the call to joinPointBinding_maybe, to avoid fruitless
eta-expansion. This is just a minor refactoring.
- - - - -
4 changed files:
- compiler/GHC/Core/Opt/Arity.hs
- compiler/GHC/Core/Opt/Simplify/Iteration.hs
- + testsuite/tests/simplCore/should_compile/T22491.hs
- testsuite/tests/simplCore/should_compile/all.T
Changes:
=====================================
compiler/GHC/Core/Opt/Arity.hs
=====================================
@@ -3104,9 +3104,13 @@ etaBodyForJoinPoint need_args body
| Just (tv, res_ty) <- splitForAllTyCoVar_maybe ty
, let (subst', tv') = substVarBndr subst tv
= go (n-1) res_ty subst' (tv' : rev_bs) (e `App` varToCoreExpr tv')
+ -- The varToCoreExpr is important: `tv` might be a coercion variable
+
| Just (_, mult, arg_ty, res_ty) <- splitFunTy_maybe ty
, let (subst', b) = freshEtaId n subst (Scaled mult arg_ty)
- = go (n-1) res_ty subst' (b : rev_bs) (e `App` Var b)
+ = go (n-1) res_ty subst' (b : rev_bs) (e `App` varToCoreExpr b)
+ -- The varToCoreExpr is important: `b` might be a coercion variable
+
| otherwise
= pprPanic "etaBodyForJoinPoint" $ int need_args $$
ppr body $$ ppr (exprType body)
=====================================
compiler/GHC/Core/Opt/Simplify/Iteration.hs
=====================================
@@ -1227,6 +1227,14 @@ simplExprF1 env (Let (NonRec bndr rhs) body) cont
do { ty' <- simplType env ty
; simplExprF (extendTvSubst env bndr ty') body cont }
+ | Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
+ = do { tick (PreInlineUnconditionally bndr)
+ ; simplExprF env' body cont }
+
+ -- Now check for a join point. It's better to do the preInlineUnconditionally
+ -- test first, because joinPointBinding_maybe has to eta-expand, so a trivial
+ -- binding like { j = j2 |> co } would first be eta-expanded and then inlined.
+ -- Testing preInlineUnconditionally first avoids that fruitless eta-expansion.
| Just (bndr', rhs') <- joinPointBinding_maybe bndr rhs
= {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' body cont
@@ -1680,12 +1688,17 @@ simpl_lam env bndr body (ApplyToVal { sc_arg = arg, sc_env = arg_se
, sc_cont = cont, sc_dup = dup })
| isSimplified dup -- Don't re-simplify if we've simplified it once
-- See Note [Avoiding exponential behaviour]
- = do { tick (BetaReduction bndr)
- ; completeBindX env bndr arg body cont }
+ = do { tick (BetaReduction bndr)
+ ; completeBindX env bndr arg body cont }
+
+ | Just env' <- preInlineUnconditionally env NotTopLevel bndr arg arg_se
+ = do { tick (PreInlineUnconditionally bndr)
+ ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr arg) $
+ simplLam env' body cont }
| otherwise -- See Note [Avoiding exponential behaviour]
- = do { tick (BetaReduction bndr)
- ; simplNonRecE env bndr (arg, arg_se) body cont }
+ = do { tick (BetaReduction bndr)
+ ; simplNonRecE env bndr (arg, arg_se) body cont }
-- Discard a non-counting tick on a lambda. This may change the
-- cost attribution slightly (moving the allocation of the
@@ -1735,6 +1748,8 @@ simplNonRecE :: SimplEnv
-- It deals with strict bindings, via the StrictBind continuation,
-- which may abort the whole process.
--
+-- Caller is expected to do the preInlineUnconditionally test
+--
-- The RHS may not satisfy the let-can-float invariant yet
simplNonRecE env bndr (rhs, rhs_se) body cont
@@ -1742,23 +1757,14 @@ simplNonRecE env bndr (rhs, rhs_se) body cont
do { (env1, bndr1) <- simplNonRecBndr env bndr
; let needs_case_binding = needsCaseBinding (idType bndr1) rhs
-- See Note [Dark corner with representation polymorphism]
- ; if | not needs_case_binding
- , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs rhs_se ->
- do { tick (PreInlineUnconditionally bndr)
- ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr rhs) $
- simplLam env' body cont }
-
- -- Deal with strict bindings
- -- See Note [Dark corner with representation polymorphism]
- | isStrictId bndr1 && seCaseCase env
- || needs_case_binding ->
+ ; if ((isStrictId bndr1 && seCaseCase env) || needs_case_binding)
+ then -- Deal with strict bindings
simplExprF (rhs_se `setInScopeFromE` env) rhs
(StrictBind { sc_bndr = bndr, sc_body = body
, sc_env = env, sc_cont = cont, sc_dup = NoDup })
- -- Deal with lazy bindings
- | otherwise ->
+ else -- Deal with lazy bindings
do { (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Let NotTopLevel NonRecursive)
; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
; (floats2, expr') <- simplLam env3 body cont
@@ -1806,7 +1812,7 @@ care here.
Note [Avoiding exponential behaviour]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
One way in which we can get exponential behaviour is if we simplify a
-big expression, and the re-simplify it -- and then this happens in a
+big expression, and then re-simplify it -- and then this happens in a
deeply-nested way. So we must be jolly careful about re-simplifying
an expression. That is why simplNonRecX does not try
preInlineUnconditionally (unlike simplNonRecE).
@@ -1864,13 +1870,8 @@ simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
-> InExpr -> SimplCont
-> SimplM (SimplFloats, OutExpr)
simplNonRecJoinPoint env bndr rhs body cont
- | assert (isJoinId bndr ) True
- , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
- = do { tick (PreInlineUnconditionally bndr)
- ; simplExprF env' body cont }
-
- | otherwise
- = wrapJoinCont env cont $ \ env cont ->
+ = assert (isJoinId bndr ) $
+ wrapJoinCont env cont $ \ env cont ->
do { -- We push join_cont into the join RHS and the body;
-- and wrap wrap_cont around the whole thing
; let mult = contHoleScaling cont
=====================================
testsuite/tests/simplCore/should_compile/T22491.hs
=====================================
@@ -0,0 +1,319 @@
+{-# LANGUAGE Haskell2010 #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module T22491 (heapster_add_block_hints) where
+
+import qualified Control.Exception as X
+import Control.Applicative
+import Control.Monad
+import Control.Monad.Catch (MonadThrow(..), MonadCatch(..), catches, Handler(..))
+import Control.Monad.IO.Class
+import qualified Control.Monad.Fail as Fail
+import Control.Monad.Trans.Class (MonadTrans(..))
+import Control.Monad.Trans.Reader (ReaderT)
+import Data.Coerce (Coercible, coerce)
+import Data.IORef
+import Data.Kind (Type)
+import Data.Monoid
+import GHC.Exts (build)
+
+failOnNothing :: Fail.MonadFail m => String -> Maybe a -> m a
+failOnNothing err_str Nothing = Fail.fail err_str
+failOnNothing _ (Just a) = return a
+
+lookupLLVMSymbolModAndCFG :: HeapsterEnv -> String -> IO (Maybe (AnyCFG LLVM))
+lookupLLVMSymbolModAndCFG _ _ = pure Nothing
+
+heapster_add_block_hints :: HeapsterEnv -> String -> [Int] ->
+ (forall ext blocks ret.
+ CFG ext blocks ret ->
+ TopLevel Hint) ->
+ TopLevel ()
+heapster_add_block_hints henv nm blks hintF =
+ do env <- liftIO $ readIORef $ heapsterEnvPermEnvRef henv
+ AnyCFG cfg <-
+ failOnNothing ("Could not find symbol definition: " ++ nm) =<<
+ io (lookupLLVMSymbolModAndCFG henv nm)
+ let blocks = fmapFC blockInputs $ cfgBlockMap cfg
+ block_idxs = fmapFC (blockIDIndex . blockID) $ cfgBlockMap cfg
+ blkIDs <- case blks of
+ [] -> pure $ toListFC (Some . BlockID) block_idxs
+ _ -> forM blks $ \blk ->
+ failOnNothing ("Block ID " ++ show blk ++
+ " not found in function " ++ nm)
+ (fmapF BlockID <$> intIndex blk (size blocks))
+ env' <- foldM (\env' _ ->
+ permEnvAddHint env' <$>
+ hintF cfg)
+ env blkIDs
+ liftIO $ writeIORef (heapsterEnvPermEnvRef henv) env'
+
+-----
+
+data Some (f:: k -> Type) = forall x . Some (f x)
+
+class FunctorF m where
+ fmapF :: (forall x . f x -> g x) -> m f -> m g
+
+mapSome :: (forall tp . f tp -> g tp) -> Some f -> Some g
+mapSome f (Some x) = Some $! f x
+
+instance FunctorF Some where fmapF = mapSome
+
+type SingleCtx x = EmptyCtx ::> x
+
+data Ctx k
+ = EmptyCtx
+ | Ctx k ::> k
+
+type family (<+>) (x :: Ctx k) (y :: Ctx k) :: Ctx k where
+ x <+> EmptyCtx = x
+ x <+> (y ::> e) = (x <+> y) ::> e
+
+data Height = Zero | Succ Height
+
+data BalancedTree h (f :: k -> Type) (p :: Ctx k) where
+ BalLeaf :: !(f x) -> BalancedTree 'Zero f (SingleCtx x)
+ BalPair :: !(BalancedTree h f x)
+ -> !(BalancedTree h f y)
+ -> BalancedTree ('Succ h) f (x <+> y)
+
+data BinomialTree (h::Height) (f :: k -> Type) :: Ctx k -> Type where
+ Empty :: BinomialTree h f EmptyCtx
+
+ PlusOne :: !Int
+ -> !(BinomialTree ('Succ h) f x)
+ -> !(BalancedTree h f y)
+ -> BinomialTree h f (x <+> y)
+
+ PlusZero :: !Int
+ -> !(BinomialTree ('Succ h) f x)
+ -> BinomialTree h f x
+
+tsize :: BinomialTree h f a -> Int
+tsize Empty = 0
+tsize (PlusOne s _ _) = 2*s+1
+tsize (PlusZero s _) = 2*s
+
+fmap_bin :: (forall tp . f tp -> g tp)
+ -> BinomialTree h f c
+ -> BinomialTree h g c
+fmap_bin _ Empty = Empty
+fmap_bin f (PlusOne s t x) = PlusOne s (fmap_bin f t) (fmap_bal f x)
+fmap_bin f (PlusZero s t) = PlusZero s (fmap_bin f t)
+{-# INLINABLE fmap_bin #-}
+
+fmap_bal :: (forall tp . f tp -> g tp)
+ -> BalancedTree h f c
+ -> BalancedTree h g c
+fmap_bal = go
+ where go :: (forall tp . f tp -> g tp)
+ -> BalancedTree h f c
+ -> BalancedTree h g c
+ go f (BalLeaf x) = BalLeaf (f x)
+ go f (BalPair x y) = BalPair (go f x) (go f y)
+{-# INLINABLE fmap_bal #-}
+
+traverse_bin :: Applicative m
+ => (forall tp . f tp -> m (g tp))
+ -> BinomialTree h f c
+ -> m (BinomialTree h g c)
+traverse_bin _ Empty = pure Empty
+traverse_bin f (PlusOne s t x) = PlusOne s <$> traverse_bin f t <*> traverse_bal f x
+traverse_bin f (PlusZero s t) = PlusZero s <$> traverse_bin f t
+{-# INLINABLE traverse_bin #-}
+
+traverse_bal :: Applicative m
+ => (forall tp . f tp -> m (g tp))
+ -> BalancedTree h f c
+ -> m (BalancedTree h g c)
+traverse_bal = go
+ where go :: Applicative m
+ => (forall tp . f tp -> m (g tp))
+ -> BalancedTree h f c
+ -> m (BalancedTree h g c)
+ go f (BalLeaf x) = BalLeaf <$> f x
+ go f (BalPair x y) = BalPair <$> go f x <*> go f y
+{-# INLINABLE traverse_bal #-}
+
+data Assignment (f :: k -> Type) (ctx :: Ctx k)
+ = Assignment (BinomialTree 'Zero f ctx)
+
+newtype Index (ctx :: Ctx k) (tp :: k) = Index { indexVal :: Int }
+
+newtype Size (ctx :: Ctx k) = Size Int
+
+intIndex :: Int -> Size ctx -> Maybe (Some (Index ctx))
+intIndex i n | 0 <= i && i < sizeInt n = Just (Some (Index i))
+ | otherwise = Nothing
+
+size :: Assignment f ctx -> Size ctx
+size (Assignment t) = Size (tsize t)
+
+sizeInt :: Size ctx -> Int
+sizeInt (Size n) = n
+
+class FunctorFC (t :: (k -> Type) -> l -> Type) where
+ fmapFC :: forall f g. (forall x. f x -> g x) ->
+ (forall x. t f x -> t g x)
+
+(#.) :: Coercible b c => (b -> c) -> (a -> b) -> (a -> c)
+(#.) _f = coerce
+
+class FoldableFC (t :: (k -> Type) -> l -> Type) where
+ foldMapFC :: forall f m. Monoid m => (forall x. f x -> m) -> (forall x. t f x -> m)
+ foldMapFC f = foldrFC (mappend . f) mempty
+
+ foldrFC :: forall f b. (forall x. f x -> b -> b) -> (forall x. b -> t f x -> b)
+ foldrFC f z t = appEndo (foldMapFC (Endo #. f) t) z
+
+ toListFC :: forall f a. (forall x. f x -> a) -> (forall x. t f x -> [a])
+ toListFC f t = build (\c n -> foldrFC (\e v -> c (f e) v) n t)
+
+foldMapFCDefault :: (TraversableFC t, Monoid m) => (forall x. f x -> m) -> (forall x. t f x -> m)
+foldMapFCDefault = \f -> getConst . traverseFC (Const . f)
+{-# INLINE foldMapFCDefault #-}
+
+class (FunctorFC t, FoldableFC t) => TraversableFC (t :: (k -> Type) -> l -> Type) where
+ traverseFC :: forall f g m. Applicative m
+ => (forall x. f x -> m (g x))
+ -> (forall x. t f x -> m (t g x))
+
+instance FunctorFC Assignment where
+ fmapFC = \f (Assignment x) -> Assignment (fmap_bin f x)
+ {-# INLINE fmapFC #-}
+
+instance FoldableFC Assignment where
+ foldMapFC = foldMapFCDefault
+ {-# INLINE foldMapFC #-}
+
+instance TraversableFC Assignment where
+ traverseFC = \f (Assignment x) -> Assignment <$> traverse_bin f x
+ {-# INLINE traverseFC #-}
+
+data CrucibleType
+
+data TypeRepr (tp::CrucibleType) where
+
+type CtxRepr = Assignment TypeRepr
+
+data CFG (ext :: Type)
+ (blocks :: Ctx (Ctx CrucibleType))
+ (ret :: CrucibleType)
+ = CFG { cfgBlockMap :: !(BlockMap ext blocks ret)
+ }
+
+type BlockMap ext blocks ret = Assignment (Block ext blocks ret) blocks
+
+data Block ext (blocks :: Ctx (Ctx CrucibleType)) (ret :: CrucibleType) ctx
+ = Block { blockID :: !(BlockID blocks ctx)
+ , blockInputs :: !(CtxRepr ctx)
+ }
+
+newtype BlockID (blocks :: Ctx (Ctx CrucibleType)) (tp :: Ctx CrucibleType)
+ = BlockID { blockIDIndex :: Index blocks tp }
+
+data LLVM
+
+data AnyCFG ext where
+ AnyCFG :: CFG ext blocks ret
+ -> AnyCFG ext
+
+newtype StateContT s r m a
+ = StateContT { runStateContT :: (a -> s -> m r)
+ -> s
+ -> m r
+ }
+
+fmapStateContT :: (a -> b) -> StateContT s r m a -> StateContT s r m b
+fmapStateContT = \f m -> StateContT $ \c -> runStateContT m (\v s -> (c $! f v) s)
+{-# INLINE fmapStateContT #-}
+
+applyStateContT :: StateContT s r m (a -> b) -> StateContT s r m a -> StateContT s r m b
+applyStateContT = \mf mv ->
+ StateContT $ \c ->
+ runStateContT mf (\f -> runStateContT mv (\v s -> (c $! f v) s))
+{-# INLINE applyStateContT #-}
+
+returnStateContT :: a -> StateContT s r m a
+returnStateContT = \v -> seq v $ StateContT $ \c -> c v
+{-# INLINE returnStateContT #-}
+
+bindStateContT :: StateContT s r m a -> (a -> StateContT s r m b) -> StateContT s r m b
+bindStateContT = \m n -> StateContT $ \c -> runStateContT m (\a -> runStateContT (n a) c)
+{-# INLINE bindStateContT #-}
+
+instance Functor (StateContT s r m) where
+ fmap = fmapStateContT
+
+instance Applicative (StateContT s r m) where
+ pure = returnStateContT
+ (<*>) = applyStateContT
+
+instance Monad (StateContT s r m) where
+ (>>=) = bindStateContT
+
+instance MonadFail m => MonadFail (StateContT s r m) where
+ fail = \msg -> StateContT $ \_ _ -> fail msg
+
+instance MonadTrans (StateContT s r) where
+ lift = \m -> StateContT $ \c s -> m >>= \v -> seq v (c v s)
+
+instance MonadIO m => MonadIO (StateContT s r m) where
+ liftIO = lift . liftIO
+
+instance MonadThrow m => MonadThrow (StateContT s r m) where
+ throwM e = StateContT (\_k _s -> throwM e)
+
+instance MonadCatch m => MonadCatch (StateContT s r m) where
+ catch m hdl =
+ StateContT $ \k s ->
+ catch
+ (runStateContT m k s)
+ (\e -> runStateContT (hdl e) k s)
+
+data TopLevelRO
+data TopLevelRW
+data Value
+
+newtype TopLevel a =
+ TopLevel_ (ReaderT TopLevelRO (StateContT TopLevelRW (Value, TopLevelRW) IO) a)
+ deriving (Applicative, Functor, Monad, MonadFail, MonadThrow, MonadCatch)
+
+instance MonadIO TopLevel where
+ liftIO = io
+
+io :: IO a -> TopLevel a
+io f = TopLevel_ (liftIO f) `catches` [Handler handleIO]
+ where
+ rethrow :: X.Exception ex => ex -> TopLevel a
+ rethrow ex = throwM (X.SomeException ex)
+
+ handleIO :: X.IOException -> TopLevel a
+ handleIO = rethrow
+
+data HeapsterEnv = HeapsterEnv {
+ heapsterEnvPermEnvRef :: IORef PermEnv
+ }
+
+data Hint where
+
+data PermEnv = PermEnv {
+ permEnvHints :: [Hint]
+ }
+
+permEnvAddHint :: PermEnv -> Hint -> PermEnv
+permEnvAddHint env hint = env { permEnvHints = hint : permEnvHints env }
+
+type family CtxToRList (ctx :: Ctx k) :: RList k where
+ CtxToRList EmptyCtx = RNil
+ CtxToRList (ctx' ::> x) = CtxToRList ctx' :> x
+
+data RList a
+ = RNil
+ | (RList a) :> a
=====================================
testsuite/tests/simplCore/should_compile/all.T
=====================================
@@ -452,3 +452,5 @@ test('T22375', normal, compile, ['-O -ddump-simpl -dsuppress-uniques -dno-typeab
test('T21851_2', [grep_errmsg(r'wwombat') ], multimod_compile, ['T21851_2', '-O -dno-typeable-binds -dsuppress-uniques'])
# Should not inline m, so there shouldn't be a single YES
test('T22317', [grep_errmsg(r'ANSWER = YES') ], compile, ['-O -dinline-check m -ddebug-output'])
+
+test('T22491', normal, compile, ['-O2'])
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c0e02e0e0e4f08e6a9458e25b90b18149a411485
--
View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/c0e02e0e0e4f08e6a9458e25b90b18149a411485
You're receiving this email because of your account on gitlab.haskell.org.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.haskell.org/pipermail/ghc-commits/attachments/20221122/b2709b96/attachment-0001.html>
More information about the ghc-commits
mailing list