# [Haskell-cafe] My Unification Sorrows

Paul Berg procyon112 at gmail.com
Wed Apr 4 16:37:44 EDT 2007

```Ok, so I decided to implement an algorithm to build Strongly Typed Genetic
Programming trees in Haskell, in an effort to learn Haskell,
and I'm way over my head on this unification problem.

The unification seems to work well, as long as I include the occurs check.
growFullTree returns a lazy list of all possible syntax trees in random
order that are well typed. So, with the occurs check:

head $ growFullTree 4 TInt 0 [] (mkStdGen 5)

Gives me a nice, random, well typed tree.  So far so good, although
I am wary of the efficiency... perhaps I need some applySubst calls
in there somewhere to keep the constraint list down?

Anyway now, we comment out the occurs check in the unify function and
do the same thing again.  Now we get what appears to be an infinite
constraint!  The unify is supposed to be terminating without the occurs
check, so something is very wrong here!

Worse, now we run:

head $ growFullTree 5 TInt 0 [] (mkStdGen 7)

This one never even begins printing!  Although it appears the bug is
tail recursive in this case.  Other random number seeds will cause it
to blow stack.

I have gone over and over this code and cannot find the issue due to my
lack of experience with both Haskell and Unification algorithms in general.

I was hoping someone here could give pointers on where the bug might lie, e.g.
what my algorithm is doing wrong that makes it a mostly, but not completely,
correct unification algorithm, as well as pointers on how this code could be
made cleaner and/or more concise, as I'm trying very hard to get my brain
around this language, and this problem.

Here's the basic code, which should be fully runnable:

module Evaluator
where

import Maybe
import Random

-- | Expressions of the combinator language.
data Exp
  = Prim String        -- ^ A primitive function, referred to by name.
  | App Exp Exp        -- ^ An un-evaluated application.
  | Func (Exp -> Exp)  -- ^ A partially applied function.
  | LitInt Integer     -- ^ An integer literal.
  | LitReal Double     -- ^ A real literal.

-- | Written by hand because @deriving Show@ is rejected for a type with a
-- function-valued field: the Prelude provides no @Show (Exp -> Exp)@
-- instance.  Every constructor except 'Func' is rendered exactly as a
-- derived instance would render it; the function inside a 'Func' cannot be
-- inspected, so it is shown as the placeholder @\<function\>@.
instance Show Exp where
  showsPrec d (Prim s) =
    showParen (d > 10) $ showString "Prim " . showsPrec 11 s
  showsPrec d (App f x) =
    showParen (d > 10) $
      showString "App " . showsPrec 11 f . showChar ' ' . showsPrec 11 x
  showsPrec d (Func _) =
    showParen (d > 10) $ showString "Func <function>"
  showsPrec d (LitInt n) =
    showParen (d > 10) $ showString "LitInt " . showsPrec 11 n
  showsPrec d (LitReal r) =
    showParen (d > 10) $ showString "LitReal " . showsPrec 11 r

-- | The name of a type variable.  Distinct type variables are told apart by
-- picking distinct values of this type.
type TVName = Int

-- | Function arrows associate to the right, mirroring Haskell's own @->@.
infixr 9 :->

-- | The types available in our language.
data Type
  = TInt           -- ^ The type of integers.
  | TReal          -- ^ The type of reals.
  | TVar TVName    -- ^ A type variable, identified by its name.
  | Type :-> Type  -- ^ A function type: argument type on the left, result on the right.
  deriving (Eq, Show)

-- | A constraint is a pair of types that are required to be equal.
type Constraint = (Type, Type)

-- | Produce the type-variable name that follows the given one.
getNextName :: TVName -> TVName
getNextName = (+ 1)

-- | Evaluate an expression.
--
-- * A 'Prim' is replaced by its implementation from 'primitives'.  A name
--   that is not in the table is reported with an error that mentions the
--   name, rather than an anonymous @fromJust@ failure.
-- * An 'App' is handed to 'apply'.
-- * Anything else is already a value and is returned unchanged.
eval :: Exp -> Exp
eval (Prim s) =
  case lookup s primitives of
    Just (impl, _) -> impl
    Nothing        -> error $ "eval: unknown primitive " ++ show s
eval (App f x) = apply f x
eval x         = x

-- | Apply the first expression to the second.
--
-- Applications and primitives in function position are evaluated first and
-- the application is retried; a 'Func' is applied directly and its result
-- evaluated.  Any other expression in function position is an error.
apply :: Exp -> Exp -> Exp
apply fun arg = case fun of
  App _ _ -> retry
  Prim _  -> retry
  Func f  -> eval (f arg)
  _       -> error (show fun ++ " was applied to " ++ show arg
                      ++ " and is not a valid function.")
  where
    -- Reduce the function position, then evaluate the application again.
    retry = eval (App (eval fun) arg)

-- | Given a set of constraints (possibly empty), the next available
-- variable name and an expression, return a triple which consists of the
-- inferred type of the expression, the next available variable name, and a
-- list of constraints for the new type.
--
-- The incoming constraint list is only threaded through the recursive
-- calls; the returned list contains just the constraints generated while
-- walking the expression.
--
-- Calls 'error' for an unknown primitive name (naming the primitive) and
-- for 'Func' expressions, whose types cannot be inferred.
infer :: [Constraint] -> TVName -> Exp -> (Type, TVName, [Constraint])
infer _ nvar (LitInt _)  = (TInt, nvar, [])
infer _ nvar (LitReal _) = (TReal, nvar, [])
infer _ nvar (Prim s)    = (t, nvar', [])
  where
    -- Rename the primitive's type variables to fresh ones so they cannot
    -- clash with variables already handed out.
    (nvar', t, _) =
      case lookup s primitives of
        Just (_, ty) -> substvars nvar ty []
        Nothing      -> error $ "infer: unknown primitive " ++ show s
infer ctx nvar (App t1 t2) =
  (TVar nvar'', getNextName nvar'', newconstr : (ctx' ++ ctx''))
  where
    (tyT1, nvar', ctx')   = infer ctx nvar t1
    (tyT2, nvar'', ctx'') = infer ctx nvar' t2
    -- The function's type must map the argument's type to a fresh result
    -- variable, which is also the type of the whole application.
    newconstr             = (tyT1, tyT2 :-> TVar nvar'')
infer _ _ (Func _) =
  error "Type Inferrence of partially applied functions not supported"

-- | Simplify a type using a unified constraint list.
--
-- Each constraint is read as "replace the left type by the right type".
-- The list is consumed from its last element towards its first, so the
-- last entry is substituted into the type first.
applySubst :: Type -> [Constraint] -> Type
applySubst tyT ctx = foldr step tyT ctx
  where
    step (var, replacement) acc = substinty var replacement acc

-- | Rename the type variables of a (primitive's) type to fresh ones, so
-- that they do not clash with variables already defined in our constraint
-- set.
--
-- Takes the next available variable name, the type to rename, and the
-- renamings made so far (old name, new name).  Returns the next available
-- name afterwards, the renamed type, and the extended list of renamings.
-- A variable that was already renamed reuses its earlier new name.
substvars :: TVName -> Type -> [(TVName, TVName)]
          -> (TVName, Type, [(TVName, TVName)])
substvars nvar ty ctx = case ty of
  TVar old ->
    maybe (getNextName nvar, TVar nvar, (old, nvar) : ctx)
          (\new -> (nvar, TVar new, ctx))
          (lookup old ctx)
  lhs :-> rhs ->
    let (afterLhs, lhs', ctxLhs) = substvars nvar lhs ctx
        (afterRhs, rhs', ctxRhs) = substvars afterLhs rhs ctxLhs
    in  (afterRhs, lhs' :-> rhs', ctxRhs)
  _ -> (nvar, ty, ctx)

-- | The type of an expression, or 'Nothing' when the constraints produced
-- by 'infer' cannot be unified.
typeof :: Exp -> Maybe Type
typeof e = do
  solved <- unify constraints
  return (applySubst ty solved)
  where
    (ty, _, constraints) = infer [] 0 e

-- | @substinty tyX tyT tyS@ replaces every type variable in @tyS@ that is
-- equal to @tyX@ by @tyT@.  Function types are traversed on both sides;
-- all other types are returned untouched.
substinty :: Type -> Type -> Type -> Type
substinty tyX tyT = go
  where
    go (arg :-> res) = go arg :-> go res
    go var@(TVar _)  = if var == tyX then tyT else var
    go other         = other

-- | Apply 'substinty' with the given variable and replacement to both
-- sides of every constraint in a constraint set.
substinconstr :: Type -> Type -> [Constraint] -> [Constraint]
substinconstr tyX tyT = map onBoth
  where
    rewrite       = substinty tyX tyT
    onBoth (l, r) = (rewrite l, rewrite r)

-- | @tx `occursIn` ty@ is 'True' when @tx@ appears somewhere inside @ty@.
-- Function types are searched on both sides; any other type matches only
-- if it is equal to @tx@.  This is used to identify infinite types.
occursIn :: Type -> Type -> Bool
occursIn tx = go
  where
    go (arg :-> res) = go arg || go res
    go ty            = tx == ty

-- | Unify a list of constraints, finding values for the type variables.
--
-- On success the result is a list of @(variable, type)@ bindings.  Each new
-- binding is appended after the bindings found for the remaining
-- constraints, so the binding solved first ends up last in the list.
-- Failure is signalled through the monad's 'fail' with either
-- @"Infinite Type"@ or @"Unsolvable"@.
unify :: (Monad m, Functor m) => [Constraint] -> m [Constraint]
unify []            = return []
unify (c : pending) = case c of
  -- Trivially satisfied constraint: drop it.
  (lhs, rhs)
    | lhs == rhs -> unify pending
  -- A variable on the right: bind it, after the occurs check.
  (ty, var@(TVar _))
    -- Problems start by commenting out this occurs check.
    | var `occursIn` ty -> fail "Infinite Type"
    -- Note that the remaining constraints are reversed before the new
    -- binding is substituted into them.
    | otherwise ->
        fmap (++ [(var, ty)])
             (unify (substinconstr var ty (reverse pending)))
  -- A variable on the left: flip the constraint and try again.
  (var@(TVar _), ty) -> unify ((ty, var) : pending)
  -- Two function types: unify arguments and results pairwise.
  (arg1 :-> res1, arg2 :-> res2) ->
    unify ((arg1, arg2) : (res1, res2) : pending)
  _ -> fail "Unsolvable"

-- | The table of primitives used to construct expressions: each name is
-- paired with its implementation and its type.
primitives :: [(String, (Exp, Type))]
primitives =
  [ ("B",   (b_,  (v0 :-> v1) :-> (v2 :-> v0) :-> v2 :-> v1))
  , ("I",   (i_,  v0 :-> v0))
  , ("K",   (k_,  v0 :-> v1 :-> v0))
  , ("S",   (s_,  (v0 :-> v1 :-> v2) :-> (v0 :-> v1) :-> v0 :-> v2))
  , ("Add", (add, TInt :-> TInt :-> TInt))
  -- , ("Y",   (y_,  (v0 :-> v0) :-> v0))
  , ("Uno", (LitInt 1, TInt))
  ]
  where
    -- Shorthand for the three type variables used in the signatures.
    v0 = TVar 0
    v1 = TVar 1
    v2 = TVar 2

-- Some primitives to play with.

-- | @S x y z = x z (y z)@
s_ :: Exp
s_ = Func $ \x -> Func $ \y -> Func $ \z -> App (App x z) (App y z)

-- | @K x _ = x@
k_ :: Exp
k_ = Func $ \x -> Func $ \_ -> x

-- | @I x = x@
i_ :: Exp
i_ = Func $ \x -> x

-- | @Y f@ is the infinitely nested application @f (f (f ...))@, built
-- lazily by tying the knot.
y_ :: Exp
y_ = Func $ \f -> let fixed = App f fixed in fixed

-- | @B f g x = f (g x)@
b_ :: Exp
b_ = Func $ \f -> Func $ \g -> Func $ \x -> App f (App g x)

-- | Integer addition.  Both arguments are evaluated, the first one first;
-- anything other than two integer literals is an error.
add :: Exp
add = Func $ \x -> Func $ \y -> plus (eval x) (eval y)
  where
    plus (LitInt m) (LitInt n) = LitInt (m + n)
    plus _          _          =
      error "Addition applied to non integers not implemented"

-- | Given a type and a list of constraints for that type, and another
-- (primitive) type and the next available type variable, return a new next
-- available type variable, the type of the primitive and a new list of
-- constraints, if the type is unifiable.
--
-- The primitive's type variables are first renamed to fresh ones with
-- 'substvars'; the renamed type is then required to equal the target type
-- on top of the existing constraints.
primitiveUnifies :: (Monad m, Functor m)
                 => Type -> [Constraint] -> Type -> TVName
                 -> m (TVName, Type, [Constraint])
primitiveUnifies t1 cs t2 nvar =
  fmap (\solved -> (nvar', t2', solved)) (unify ((t2', t1) : cs))
  where
    (nvar', t2', _) = substvars nvar t2 []

-- | Given a type, the next available type variable, and a list of
-- constraints, return every primitive which unifies to that type.  Each
-- result holds the primitive's name, the type of the primitive, a new next
-- available type variable, and a new list of constraints.  Results appear
-- in the same order as the entries of 'primitives'.
findTypeMatch :: Type -> TVName -> [Constraint]
              -> [(String, Type, TVName, [Constraint])]
findTypeMatch t nvar c =
  [ (name, primTy, nvar', solved)
  | (name, (_, ty)) <- primitives
    -- Primitives whose type does not unify yield Nothing and are skipped.
  , Just (nvar', primTy, solved) <- [primitiveUnifies t c ty nvar]
  ]

-- | Given the depth of the tree, the type to return, the next available
-- type variable, a list of constraints on type variables, and a random
-- generator, return a lazy list of full-depth trees meeting those
-- constraints, built from the primitives.
--
-- Each result carries the tree, the next available type variable, the
-- updated constraints and the generator to continue with.
--
-- At depth 0 (or below, so that a negative depth cannot recurse forever)
-- the candidates are the primitives matching the requested type, in
-- shuffled order; when nothing matches, the result is simply the empty
-- list.  All leaves produced by one call share the generator left over
-- from that shuffle.
growFullTree :: Int -> Type -> TVName -> [Constraint] -> StdGen
             -> [(Exp, TVName, [Constraint], StdGen)]
growFullTree depth ty nvar c g
  | depth <= 0 =
      [ (Prim func, leafVar, leafCs, leafGen)
      | (func, _, leafVar, leafCs) <- optionlist
      ]
  | otherwise = do
      -- The function position must accept a fresh variable and return the
      -- requested type; the argument position must have that variable's
      -- type, under the constraints accumulated so far.
      (func,  nvar',  c',  g')  <-
        growFullTree (depth - 1) (TVar nvar :-> ty) (getNextName nvar) c g
      (param, nvar'', c'', g'') <-
        growFullTree (depth - 1) (TVar nvar) nvar' c' g'
      return (App func param, nvar'', c'', g'')
  where
    (optionlist, leafGen) = shuffle (findTypeMatch ty nvar c) g

-- | Insert a value into a list at the given index.
--
-- Index 0 puts the value at the front and an index equal to the list's
-- length puts it at the end.  Out-of-range indices are clamped: a negative
-- index inserts at the front, and an index past the end appends.
insert_ix :: Int -> a -> [a] -> [a]
insert_ix k a lst = before ++ a : after
  where
    (before, after) = splitAt k lst

-- | Insert a value into a list at a random index, drawn from 0 up to and
-- including the list's length.  Returns the new list together with the
-- advanced generator.
insert_rand :: (RandomGen g) => a -> [a] -> g -> ([a], g)
insert_rand a lst g1 = (insert_ix ix a lst, g2)
  where
    (ix, g2) = randomR (0, length lst) g1

-- | Shuffle a list.  The tail is shuffled first, then the head is inserted
-- at a random position of the shuffled tail; the generator is threaded
-- through every insertion and returned alongside the result.
shuffle :: (RandomGen g) => [a] -> g -> ([a], g)
shuffle items g1 = foldr place ([], g1) items
  where
    -- 'uncurry' keeps the match on the recursive result lazy.
    place x shuffledRest = uncurry (insert_rand x) shuffledRest
```