1e3986b7 by Ryan Scott at 2023-06-07T13:42:20+02:00
Consistently use validity checks for TH conversion of data constructors

We were checking that TH-spliced data declarations only carry a kind signature when they use GADT syntax, i.e., that they do not look like this:

data D :: Type = MkD Int

But we were only performing this check on the data constructors of `data`
declarations, not on those of `newtype`s, `data instance`s, or `newtype
instance`s. This patch factors the necessary validity checks out into a new
`cvtDataDefnCons` function and uses it everywhere data constructors are
converted.

Fixes #22559.

- - - - -

9 changed files:

- compiler/GHC/ThToHs.hs
- compiler/Language/Haskell/Syntax/Decls.hs
- + testsuite/tests/th/T22559a.hs
- + testsuite/tests/th/T22559a.stderr
- + testsuite/tests/th/T22559b.hs
- + testsuite/tests/th/T22559b.stderr
- + testsuite/tests/th/T22559c.hs
- + testsuite/tests/th/T22559c.stderr
- testsuite/tests/th/all.T


@@ -276,17 +276,13 @@ cvtDec (DataD ctxt tc tvs ksig constrs derivs)
 cvtDec (NewtypeD ctxt tc tvs ksig constr derivs)
   = do  { (ctxt', tc', tvs') <- cvt_tycl_hdr ctxt tc tvs
         ; ksig' <- cvtKind `traverse` ksig
-        ; let first_datacon =
-                case get_cons_names constr of
-                  []  -> panic "cvtDec: empty list of constructors"
-                  c:_ -> c
-        ; con' <- cvtConstr first_datacon cNameN constr
+        ; con' <- cvtDataDefnCons False ksig $ NewTypeCon constr
         ; derivs' <- cvtDerivs derivs
         ; let defn = HsDataDefn { dd_ext = noExtField
                                 , dd_cType = Nothing
                                 , dd_ctxt = mkHsContextMaybe ctxt'
                                 , dd_kindSig = ksig'
-                                , dd_cons = NewTypeCon con'
+                                , dd_cons = con'
                                 , dd_derivs = derivs' }
         ; returnJustLA $ TyClD noExtField $
           DataDecl { tcdDExt = noAnn
@@ -352,17 +348,13 @@ cvtDec (DataFamilyD tc tvs kind)
 cvtDec (DataInstD ctxt bndrs tys ksig constrs derivs)
   = do { (ctxt', tc', bndrs', typats') <- cvt_datainst_hdr ctxt bndrs tys
        ; ksig' <- cvtKind `traverse` ksig
-       ; let first_datacon =
-                case get_cons_names $ head constrs of
-                  []  -> panic "cvtDec: empty list of constructors"
-                  c:_ -> c
-       ; cons' <- mapM (cvtConstr first_datacon cNameN) constrs
+       ; cons' <- cvtDataDefnCons False ksig $ DataTypeCons False constrs
        ; derivs' <- cvtDerivs derivs
        ; let defn = HsDataDefn { dd_ext = noExtField
                                , dd_cType = Nothing
                                , dd_ctxt = mkHsContextMaybe ctxt'
                                , dd_kindSig = ksig'
-                               , dd_cons = DataTypeCons False cons'
+                               , dd_cons = cons'
                                , dd_derivs = derivs' }
        ; returnJustLA $ InstD noExtField $ DataFamInstD
@@ -378,17 +370,14 @@ cvtDec (DataInstD ctxt bndrs tys ksig constrs derivs)
 cvtDec (NewtypeInstD ctxt bndrs tys ksig constr derivs)
   = do { (ctxt', tc', bndrs', typats') <- cvt_datainst_hdr ctxt bndrs tys
        ; ksig' <- cvtKind `traverse` ksig
-       ; let first_datacon =
-                case get_cons_names constr of
-                  []  -> panic "cvtDec: empty list of constructors"
-                  c:_ -> c
-       ; con' <- cvtConstr first_datacon cNameN constr
+       ; con' <- cvtDataDefnCons False ksig $ NewTypeCon constr
        ; derivs' <- cvtDerivs derivs
        ; let defn = HsDataDefn { dd_ext = noExtField
                                , dd_cType = Nothing
                                , dd_ctxt = mkHsContextMaybe ctxt'
                                , dd_kindSig = ksig'
-                               , dd_cons = NewTypeCon con', dd_derivs = derivs' }
+                               , dd_cons = con'
+                               , dd_derivs = derivs' }
        ; returnJustLA $ InstD noExtField $ DataFamInstD
            { dfid_ext = noExtField
            , dfid_inst = DataFamInstDecl { dfid_eqn =
@@ -497,6 +486,28 @@ cvtGenDataDec :: Bool -> TH.Cxt -> TH.Name -> [TH.TyVarBndr ()]
     -> Maybe TH.Kind -> [TH.Con] -> [TH.DerivClause]
     -> CvtM (Maybe (LHsDecl GhcPs))
 cvtGenDataDec type_data ctxt tc tvs ksig constrs derivs
+  = do  { (ctxt', tc', tvs') <- cvt_tycl_hdr ctxt tc tvs
+        ; ksig' <- cvtKind `traverse` ksig
+        ; cons' <- cvtDataDefnCons type_data ksig $
+                   DataTypeCons type_data constrs
+        ; derivs' <- cvtDerivs derivs
+        ; let defn = HsDataDefn { dd_ext = noExtField
+                                , dd_cType = Nothing
+                                , dd_ctxt = mkHsContextMaybe ctxt'
+                                , dd_kindSig = ksig'
+                                , dd_cons = cons'
+                                , dd_derivs = derivs' }
+        ; returnJustLA $ TyClD noExtField $
+          DataDecl { tcdDExt = noAnn
+                   , tcdLName = tc', tcdTyVars = tvs'
+                   , tcdFixity = Prefix
+                   , tcdDataDefn = defn } }
+-- Convert a set of data constructors.
+cvtDataDefnCons ::
+  Bool -> Maybe TH.Kind ->
+  DataDefnCons TH.Con -> CvtM (DataDefnCons (LConDecl GhcPs))
+cvtDataDefnCons type_data ksig constrs
   = do  { let isGadtCon (GadtC    _ _ _) = True
               isGadtCon (RecGadtC _ _ _) = True
               isGadtCon (ForallC  _ _ c) = isGadtCon c
@@ -514,27 +525,16 @@ cvtGenDataDec type_data ctxt tc tvs ksig constrs derivs
                  (failWith CannotMixGADTConsWith98Cons)
         ; unless (isNothing ksig || isGadtDecl)
                  (failWith KindSigsOnlyAllowedOnGADTs)
-        ; (ctxt', tc', tvs') <- cvt_tycl_hdr ctxt tc tvs
-        ; ksig' <- cvtKind `traverse` ksig
         ; let first_datacon =
-                case get_cons_names $ head constrs of
-                  []  -> panic "cvtGenDataDec: empty list of constructors"
+                case firstDataDefnCon constrs of
+                  Nothing -> panic "cvtDataDefnCons: empty list of constructors"
+                  Just con -> con
+              first_datacon_name =
+                case get_cons_names first_datacon of
+                  []  -> panic "cvtDataDefnCons: data constructor with no names"
                   c:_ -> c
-        ; cons' <- mapM (cvtConstr first_datacon con_name) constrs
-        ; derivs' <- cvtDerivs derivs
-        ; let defn = HsDataDefn { dd_ext = noExtField
-                                , dd_cType = Nothing
-                                , dd_ctxt = mkHsContextMaybe ctxt'
-                                , dd_kindSig = ksig'
-                                , dd_cons = DataTypeCons type_data cons'
-                                , dd_derivs = derivs' }
-        ; returnJustLA $ TyClD noExtField $
-          DataDecl { tcdDExt = noAnn
-                   , tcdLName = tc', tcdTyVars = tvs'
-                   , tcdFixity = Prefix
-                   , tcdDataDefn = defn } }
+        ; mapM (cvtConstr first_datacon_name con_name) constrs }
 cvtTySynEqn :: TySynEqn -> CvtM (LTyFamInstEqn GhcPs)

@@ -30,7 +30,7 @@ module Language.Haskell.Syntax.Decls (
   HsDecl(..), LHsDecl, HsDataDefn(..), HsDeriving, LHsFunDep, FunDep(..),
   HsDerivingClause(..), LHsDerivingClause, DerivClauseTys(..), LDerivClauseTys,
   NewOrData(..), DataDefnCons(..), dataDefnConsNewOrData,
-  isTypeDataDefnCons,
+  isTypeDataDefnCons, firstDataDefnCon,
   StandaloneKindSig(..), LStandaloneKindSig,
   -- ** Class or type declarations
@@ -1040,6 +1040,11 @@ isTypeDataDefnCons :: DataDefnCons a -> Bool
 isTypeDataDefnCons (NewTypeCon _) = False
 isTypeDataDefnCons (DataTypeCons is_type_data _) = is_type_data
+-- | Retrieve the first data constructor in a 'DataDefnCons' (if one exists).
+firstDataDefnCon :: DataDefnCons a -> Maybe a
+firstDataDefnCon (NewTypeCon con) = Just con
+firstDataDefnCon (DataTypeCons _ cons) = listToMaybe cons
 -- | Located data Constructor Declaration
 type LConDecl pass = XRec pass (ConDecl pass)
       -- ^ May have 'GHC.Parser.Annotation.AnnKeywordId' : 'GHC.Parser.Annotation.AnnSemi' when

@@ -0,0 +1,13 @@
+{-# LANGUAGE TemplateHaskell #-}
+module T22559a where
+import Language.Haskell.TH
+$(pure [NewtypeD
+         [] (mkName "D") [] (Just StarT)
+         (NormalC (mkName "MkD")
+                  [( Bang NoSourceUnpackedness NoSourceStrictness
+                   , ConT ''Int
+                   )])
+         []])

@@ -0,0 +1,4 @@
+T22559a.hs:7:2: error: [GHC-40746]
+    Kind signatures are only allowed on GADTs
+    When splicing a TH declaration: newtype D :: * = MkD GHC.Types.Int

@@ -0,0 +1,17 @@
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+module T22559b where
+import Language.Haskell.TH
+data family D
+$(pure [DataInstD
+         [] Nothing
+         (ConT (mkName "D")) (Just StarT)
+         [NormalC (mkName "MkD")
+                  [( Bang NoSourceUnpackedness NoSourceStrictness
+                   , ConT ''Int
+                   )]]
+         []])

@@ -0,0 +1,5 @@
+T22559b.hs:10:2: error: [GHC-40746]
+    Kind signatures are only allowed on GADTs
+    When splicing a TH declaration:
+      data instance D :: * = MkD GHC.Types.Int

@@ -0,0 +1,17 @@
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+module T22559c where
+import Language.Haskell.TH
+data family D
+$(pure [NewtypeInstD
+         [] Nothing
+         (ConT (mkName "D")) (Just StarT)
+         (NormalC (mkName "MkD")
+                  [( Bang NoSourceUnpackedness NoSourceStrictness
+                   , ConT ''Int
+                   )])
+         []])

@@ -0,0 +1,5 @@
+T22559c.hs:10:2: error: [GHC-40746]
+    Kind signatures are only allowed on GADTs
+    When splicing a TH declaration:
+      newtype instance D :: * = MkD GHC.Types.Int

@@ -573,3 +573,6 @@ test('TH_typed3', normal, compile, ['-v0 -ddump-splices -dsuppress-uniques'])
 test('TH_typed4', normal, compile, ['-v0 -ddump-splices -dsuppress-uniques'])
 test('TH_typed5', normal, compile_and_run, [''])
 test('T21050', normal, compile_fail, [''])
+test('T22559a', normal, compile_fail, [''])
+test('T22559b', normal, compile_fail, [''])
+test('T22559c', normal, compile_fail, [''])

View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/1e3986b7d601a16b33b4d99d7618fa9d8c3d224e

