379bdcf8 by Josh Meredith at 2023-05-02T12:19:36+00:00
JS: refactor jsSaturate to return a saturated JStat (#23328)

6 changed files:

- compiler/GHC/JS/Transform.hs
- compiler/GHC/StgToJS/CodeGen.hs
- compiler/GHC/StgToJS/FFI.hs
- compiler/GHC/StgToJS/Linker/Linker.hs
- compiler/GHC/StgToJS/Monad.hs
- compiler/GHC/StgToJS/Rts/Rts.hs


@@ -6,6 +6,7 @@
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE TupleSections #-}
 module GHC.JS.Transform
   ( identsS
@@ -22,7 +23,6 @@ module GHC.JS.Transform
   , composOpM_
   , composOpFold
   , satJExpr
-  , satJStat
@@ -33,7 +33,6 @@ import GHC.JS.Unsat.Syntax
 import Data.Functor.Identity
 import Control.Monad
-import Control.Arrow ((***))
 import GHC.Data.FastString
 import GHC.Utils.Monad.State.Strict
@@ -200,54 +199,34 @@ jmcompos ret app f' v =
 -- | Given an optional prefix, fills in all free variable names with a supply
 -- of names generated by the prefix.
-jsSaturate :: (JMacro a) => Maybe FastString -> a -> a
-jsSaturate str x = evalState (runIdentSupply $ jsSaturate_ x) (newIdentSupply str)
-jsSaturate_ :: (JMacro a) => a -> IdentSupply a
-jsSaturate_ e = IS $ jfromGADT <$> go (jtoGADT e)
-    where
-      go :: forall a. JMGadt a -> State [Ident] (JMGadt a)
-      go v = case v of
-               JMGStat (UnsatBlock us) -> go =<< (JMGStat <$> runIdentSupply us)
-               JMGExpr (UnsatExpr  us) -> go =<< (JMGExpr <$> runIdentSupply us)
-               JMGVal  (UnsatVal   us) -> go =<< (JMGVal  <$> runIdentSupply us)
-               _ -> composOpM go v
+--
+-- Saturation and conversion to the 'Sat.JStat' AST happen in a single
+-- traversal: every 'UnsatBlock' is run against the one name supply created
+-- from @str@, and the statement it produces is saturated in turn.
+--
+-- NOTE(review): sub-expressions are converted with the pure 'satJExpr', which
+-- has no access to the name supply, so an 'UnsatExpr' or 'UnsatVal' nested
+-- inside an expression is not expanded by this function ('satJVal' calls
+-- 'error' on 'UnsatVal'). Function bodies met inside expressions are
+-- saturated by 'satJVal' through a separate @jsSaturate Nothing@ call, i.e.
+-- with a fresh supply that ignores @str@. Confirm no caller relies on either
+-- case, or thread the supply through expressions as well.
+jsSaturate :: Maybe FastString -> JStat -> Sat.JStat
+jsSaturate str x = evalState (go x) (newIdentSupply str)
+  where
+    go :: JStat -> State [Ident] Sat.JStat
+    go  = \case
+      DeclStat i rhs        -> return $ Sat.DeclStat i (fmap satJExpr rhs)
+      ReturnStat e          -> return $ Sat.ReturnStat (satJExpr e)
+      IfStat c t e          -> Sat.IfStat (satJExpr c) <$> go t <*> go e
+      WhileStat is_do c e   -> Sat.WhileStat is_do (satJExpr c) <$> go e
+      ForInStat is_each i iter body -> Sat.ForInStat is_each i (satJExpr iter) <$> go body
+      SwitchStat struct ps def -> Sat.SwitchStat (satJExpr struct)
+                                                 <$> mapM (\(p1, p2) -> (satJExpr p1,) <$> go p2) ps
+                                                 <*> go def
+      TryStat t i c f       -> Sat.TryStat <$> go t <*> pure i <*> go c <*> go f
+      BlockStat bs          -> fmap Sat.BlockStat $! mapM go bs
+      ApplStat rator rand   -> return $ Sat.ApplStat (satJExpr rator) (satJExpr <$> rand)
+      UOpStat  rator rand   -> return $ Sat.UOpStat  (satJUOp rator) (satJExpr rand)
+      AssignStat lhs rhs    -> return $ Sat.AssignStat (satJExpr lhs) (satJExpr rhs)
+      LabelStat lbl stmt    -> Sat.LabelStat lbl <$> go stmt
+      BreakStat m_l         -> return $ Sat.BreakStat $! m_l
+      ContinueStat m_l      -> return $ Sat.ContinueStat $! m_l
+      -- Run the suspended generator against the shared supply, then saturate
+      -- the statement it yields (which may contain further 'UnsatBlock's).
+      UnsatBlock us         -> go =<< runIdentSupply us
 --                            Translation
 -- This will be moved after GHC.JS.Syntax is removed
-satJStat :: JStat -> Sat.JStat
-satJStat = witness . proof
-  where proof = jsSaturate Nothing
-        -- This is an Applicative but we can't use it because no type variables :(
-        witness :: JStat -> Sat.JStat
-        witness (DeclStat i rhs)      = Sat.DeclStat i (fmap satJExpr rhs)
-        witness (ReturnStat e)        = Sat.ReturnStat (satJExpr e)
-        witness (IfStat c t e)        = Sat.IfStat (satJExpr c) (witness t) (witness e)
-        witness (WhileStat is_do c e) = Sat.WhileStat is_do (satJExpr c) (witness e)
-        witness (ForInStat is_each i iter body) = Sat.ForInStat is_each i
-                                                  (satJExpr iter)
-                                                  (witness body)
-        witness (SwitchStat struct ps def) = Sat.SwitchStat
-                                             (satJExpr struct)
-                                             (map (satJExpr *** witness) ps)
-                                             (witness def)
-        witness (TryStat t i c f)     = Sat.TryStat (witness t) i (witness c) (witness f)
-        witness (BlockStat bs)        = Sat.BlockStat $! fmap witness bs
-        witness (ApplStat rator rand) = Sat.ApplStat (satJExpr rator) (satJExpr <$> rand)
-        witness (UOpStat rator rand)  = Sat.UOpStat  (satJUOp rator) (satJExpr rand)
-        witness (AssignStat lhs rhs)  = Sat.AssignStat (satJExpr lhs) (satJExpr rhs)
-        witness (LabelStat lbl stmt)  = Sat.LabelStat lbl (witness stmt)
-        witness (BreakStat Nothing)   = Sat.BreakStat Nothing
-        witness (BreakStat (Just l))  = Sat.BreakStat $! Just l
-        witness (ContinueStat Nothing)  = Sat.ContinueStat Nothing
-        witness (ContinueStat (Just l)) = Sat.ContinueStat $! Just l
-        witness UnsatBlock{}            = error "satJStat: discovered an Unsat...impossibly"
 satJExpr :: JExpr -> Sat.JExpr
 satJExpr = go
@@ -315,5 +294,5 @@ satJVal = go
     go (JStr f)    = Sat.JStr   f
     go (JRegEx f)  = Sat.JRegEx f
     go (JHash m)   = Sat.JHash (satJExpr <$> m)
-    go (JFunc args body) = Sat.JFunc args (satJStat body)
+    go (JFunc args body) = Sat.JFunc args (jsSaturate Nothing body)
     go UnsatVal{} = error "jvalToSatVar: discovered an Sat...impossibly"

@@ -134,7 +134,6 @@ genUnits m ss spt_entries foreign_stubs = do
         staticInit <-
           initStaticPtrs spt_entries
         let stat = ( -- O.optimize .
-                     satJStat .
                      jsSaturate (Just $ modulePrefix m 1)
                    $ mconcat (reverse glbl) <> staticInit)
         let syms = [moduleGlobalSymbol m]
@@ -208,7 +207,7 @@ genUnits m ss spt_entries foreign_stubs = do
               _extraTl   <- State.gets (ggsToplevelStats . gsGroup)
               si        <- State.gets (ggsStatic . gsGroup)
               let body = mempty -- mconcat (reverse extraTl) <> b1 ||= e1 <> b2 ||= e2
-              let stat =  satJStat $ jsSaturate (Just $ modulePrefix m n) body
+              let stat = jsSaturate (Just $ modulePrefix m n) body
               let ids = [bnd]
               syms <- (\(TxtI i) -> [i]) <$> identForId bnd
               let oi = ObjUnit
@@ -246,7 +245,6 @@ genUnits m ss spt_entries foreign_stubs = do
               topDeps  = collectTopIds decl
               required = hasExport decl
               stat     = -- Opt.optimize .
-                         satJStat .
                          jsSaturate (Just $ modulePrefix m n)
                        $ mconcat (reverse extraTl) <> tl
           syms <- mapM (fmap (\(TxtI i) -> i) . identForId) topDeps

@@ -14,6 +14,7 @@ import GHC.Prelude
 import GHC.JS.Unsat.Syntax
 import GHC.JS.Make
 import GHC.JS.Transform
+import qualified GHC.JS.Syntax as Sat
 import GHC.StgToJS.Arg
 import GHC.StgToJS.ExprCtx
@@ -176,7 +177,7 @@ genFFIArg isJavaScriptCc a@(StgVarArg i)
      arg_ty = stgArgType a
      r      = uTypeVt arg_ty
-saturateFFI :: JMacro a => Int -> a -> a
+-- | Saturate a 'JStat' into a 'Sat.JStat', drawing generated names from a
+-- supply whose prefix is @ghcjs_ffi_sat_@ followed by the given integer.
+saturateFFI :: Int -> JStat -> Sat.JStat
 saturateFFI u = jsSaturate (Just . mkFastString $ "ghcjs_ffi_sat_" ++ show u)
 genForeignCall :: HasDebugCallStack

@@ -332,7 +332,7 @@ renderLinker h mods jsFiles = do
     pure (mod_mod, mod_size)
   -- commoned up metadata
-  !meta_length <- fromIntegral <$> putJS (satJStat meta)
+  !meta_length <- fromIntegral <$> putJS (jsSaturate Nothing meta)
   -- module exports
   mapM_ (putBS . cmc_exports) compacted_mods

@@ -25,6 +25,7 @@ where
 import GHC.Prelude
 import GHC.JS.Unsat.Syntax
+import qualified GHC.JS.Syntax as Sat
 import GHC.JS.Transform
 import GHC.StgToJS.Types
@@ -160,7 +161,7 @@ data GlobalOcc = GlobalOcc
 -- | Return number of occurrences of every global id used in the given JStat.
 -- Sort by increasing occurrence count.
-globalOccs :: JStat -> G [GlobalOcc]
+globalOccs :: Sat.JStat -> G [GlobalOcc]
 globalOccs jst = do
   GlobalIdCache gidc <- getGlobalIdCache
   -- build a map form Ident Unique to (Ident, Id, Count)
@@ -180,4 +181,4 @@ globalOccs jst = do
               let g = GlobalOcc i gid 1
               in go (addToUFM_C inc gids i g) is
-  pure $ go emptyUFM (identsS $ satJStat jst)
+  pure $ go emptyUFM (identsS jst)

@@ -30,6 +30,7 @@ import GHC.Prelude
 import GHC.JS.Unsat.Syntax
 import GHC.JS.Make
 import GHC.JS.Transform
+import qualified GHC.JS.Syntax as Sat
 import GHC.StgToJS.Apply
 import GHC.StgToJS.Closure
@@ -298,7 +299,7 @@ closureTypes = mconcat (map mkClosureType (enumFromTo minBound maxBound)) <> clo
     ifCT arg ct = jwhenS (arg .===. toJExpr ct) (returnS (toJExpr (show ct)))
 -- | JS payload declaring the RTS functions.
-rtsDecls :: JStat
+rtsDecls :: Sat.JStat
 rtsDecls = jsSaturate (Just "h$RTSD") $
   mconcat [ TxtI "h$currentThread"   ||= null_                   -- thread state object for current thread
           , TxtI "h$stack"           ||= null_                   -- stack for the current thread
@@ -314,14 +315,14 @@ rtsDecls = jsSaturate (Just "h$RTSD") $
 -- | print the embedded RTS to a String
+--
+-- 'rts' already yields a saturated 'Sat.JStat', so it is pretty-printed
+-- directly without any further conversion.
 rtsText :: StgToJSConfig -> String
-rtsText = show . pretty . satJStat . rts
+rtsText = show . pretty . rts
 -- | print the RTS declarations to a String.
+--
+-- 'rtsDecls' is already a saturated 'Sat.JStat', so no conversion is needed
+-- before pretty-printing.
 rtsDeclsText :: String
-rtsDeclsText = show . pretty . satJStat $ rtsDecls
+rtsDeclsText = show . pretty $ rtsDecls
 -- | Wrapper over the RTS to guarentee saturation, see 'GHC.JS.Transform'
+--
+-- Builds the RTS payload with 'rts'' for the given config and saturates it
+-- with a name supply created from the prefix @h$RTS@, yielding the
+-- 'Sat.JStat' that 'rtsText' pretty-prints.
-rts :: StgToJSConfig -> JStat
+rts :: StgToJSConfig -> Sat.JStat
 rts = jsSaturate (Just "h$RTS") . rts'
 -- | JS Payload which defines the embedded RTS.

