[Haskell-cafe] Unifcation and matching in Abelian groups
John D. Ramsdell
ramsdell0 at gmail.com
Wed Aug 19 05:43:43 EDT 2009
I've been studying equational unification. I decided to test my
understanding of it by implementing unification and matching in
Abelian groups. I am quite surprised by how little code it takes.
Let me share it with you.
John
Test cases:
2x+y=3z
2x=x+y
64x-41y=a
Code:
> -- Unification and matching in Abelian groups
> -- John D. Ramsdell -- August 2009
>
> module Main (main, test) where
>
> import Data.Char (isSpace, isAlpha, isAlphaNum, isDigit)
> import Data.List (sort)
> import System.IO (isEOF)
>
> -- Chapter 8, Section 5 of the Handbook of Automated Reasoning by
> -- Franz Baader and Wayne Snyder describes unification and matching in
> -- communtative/monoidal theories. This module refines the described
> -- algorithms for the special case of Abelian groups.
>
> -- In this module, an Abelian group is a free algebra over a signature
> -- with three function symbols,
> --
> -- * the binary symbol +, the group operator,
> -- * a constant 0, the identity element, and
> -- * the unary symbol -, the inverse operator.
> --
> -- The algebra is generated by a set of variables. Syntactically, a
> -- variable is an identifer such as x and y.
>
> -- The axioms associated with the algebra are:
> --
> -- * x + y = y + x Commutativity
> -- * (x + y) + z = x + (y + z) Associativity
> -- * x + 0 = x Group identity
> -- * x + -x = 0 Cancellation
>
> -- A substitution maps variables to terms. A substitution s is
> -- extended to a term as follows.
> --
> -- s(0) = 0
> -- s(-t) = -s(t)
> -- s(t + t') = s(t) + s(t')
>
> -- The unification problem is given the problem statement t =? t',
> -- find a substitution s such that s(t) = s(t') modulo the axioms of
> -- the algebra. The matching problem is to find substitution s such
> -- that s(t) = t' modulo the axioms.
>
> -- A term is represented as the sum of factors, and a factor is the
> -- product of an integer coeficient and a variable or the group
> -- identity, zero. In this representation, every coeficient is
> -- non-zero, and no variable occurs twice.
>
> -- A term can be represented by a finite map from variables to
> -- non-negative integers. To make the code easier to understand,
> -- association lists are used instead of Data.Map.
>
> newtype Lin = Lin [(String, Int)]
>
> -- Constructors
>
> -- Identity element (zero)
> ide :: Lin
> ide = Lin []
>
> -- Factors
> var :: Int -> String -> Lin
> var 0 _ = Lin []
> var c x = Lin [(x, c)]
>
> -- Invert by negating coefficients.
> neg :: Lin -> Lin
> neg (Lin t) =
> Lin $ map (\(x, c) -> (x, negate c)) t
>
> -- Join terms ensuring that coefficients are non-zero, and no variable
> -- occurs twice.
> add :: Lin -> Lin -> Lin
> add (Lin t) (Lin t') =
> Lin $ foldr f t' t
> where
> f (x, c) t =
> case lookup x t of
> Just c' | c + c' == 0 -> remove x t
> | otherwise -> (x, c + c') : remove x t
> Nothing -> (x, c) : t
>
> -- Remove the first pair in an association list that matches the key.
> remove :: Eq a => a -> [(a, b)] -> [(a, b)]
> remove _ [] = []
> remove x (y@(z, _) : ys)
> | x == z = ys
> | otherwise = y : remove x ys
>
> canonicalize :: Lin -> Lin
> canonicalize (Lin t) =
> Lin (sort t)
>
> -- Convert a linearized term into an association list.
> assocs :: Lin -> [(String, Int)]
> assocs (Lin t) = t
>
> term :: [(String, Int)] -> Lin
> term assoc =
> foldr f ide assoc
> where
> f (x, c) t = add t $ var c x
>
> -- Unification and Matching
>
> newtype Equation = Equation (Lin, Lin)
>
> newtype Maplet = Maplet (String, Lin)
>
> -- Unification is the same as matching when there are no constants
> unify :: Monad m => Equation -> m [Maplet]
> unify (Equation (t0, t1)) =
> match $ Equation (add t0 (neg t1), ide)
>
> -- Matching in Abelian groups is performed by finding integer
> -- solutions to linear equations, and then using the solutions to
> -- construct a most general unifier.
> match :: Monad m => Equation -> m [Maplet]
> match (Equation (t0, t1)) =
> case (assocs t0, assocs t1) of
> ([], []) -> return []
> ([], _) -> fail "no solution"
> (t0, t1) ->
> do
> subst <- intLinEq (map snd t0) (map snd t1)
> return $ mgu (map fst t0) (map fst t1) subst
>
> -- Construct a most general unifier from a solution to a linear
> -- equation. The function adds the variables back into terms, and
> -- generates fresh variables as needed.
> mgu :: [String] -> [String] -> Subst -> [Maplet]
> mgu vars syms subst =
> foldr f [] (zip vars [0..])
> where
> f (x, n) maplets =
> case lookup n subst of
> Just (factors, consts) ->
> Maplet (x, g factors consts) : maplets
> Nothing ->
> Maplet (x, var 1 $ genSym n) : maplets
> g factors consts =
> term (zip genSyms factors ++ zip syms consts)
> genSyms = map genSym [0..]
>
> -- Generated variables start with this character.
> genChar :: Char
> genChar = 'g'
>
> genSym :: Int -> String
> genSym i = genChar : show i
>
> -- So why solve linear equations? Consider the matching problem
> --
> -- c[0]*x[0] + c[1]*x[1] + ... + c[n-1]*x[n-1] =?
> -- d[0]*a[0] + d[1]*a[1] + ... + d[m-1]*a[m-1]
> --
> -- with n variables and m constants. We seek a most general unifier s
> -- such that
> --
> -- s(c[0]*x[0] + c[1]*x[1] + ... + c[n-1]*x[n-1]) =
> -- d[0]*a[0] + d[1]*a[1] + ... + d[m-1]*a[m-1]
> --
> -- which is the same as
> --
> -- c[0]*s(x[0]) + c[1]*s(x[1]) + ... + c[n-1]*s(x[n-1]) =
> -- d[0]*a[0] + d[1]*a[1] + ... + d[m-1]*a[m-1]
> --
> -- Notice that the number of occurrences of constant a[0] in s(x[0])
> -- plus s(x[1]) ... s(x[n-1]) must equal d[0]. Thus the mappings of
> -- the unifier that involve constant a[0] respect integer solutions of
> -- the following linear equation.
> --
> -- c[0]*x[0] + c[1]*x[1] + ... + c[n-1]*x[n-1] = d[0]
> --
> -- To compute a most general unifier, a most general integer solution
> -- to a linear equation must be found.
>
> -- Integer Solutions of Linear Inhomogeneous Equations
>
> type LinEq = ([Int], [Int])
>
> -- A linear equation with integer coefficients is represented as a
> -- pair of lists of integers, the coefficients and the constants. If
> -- there are no constants, the linear equation represented by (c, [])
> -- is the homogeneous equation:
> --
> -- c[0]*x[0] + c[1]*x[1] + ... + c[n-1]*x[n-1] = 0
> --
> -- where n is the length of c. Otherwise, (c, d) represents a
> -- sequence of inhomogeneous linear equations with the same
> -- left-hand-side:
> --
> -- c[0]*x[0] + c[1]*x[1] + ... + c[n-1]*x[n-1] = d[0]
> -- c[0]*x[0] + c[1]*x[1] + ... + c[n-1]*x[n-1] = d[1]
> -- ...
> -- c[0]*x[0] + c[1]*x[1] + ... + c[n-1]*x[n-1] = d[m-1]
> --
> -- where m is the length of d.
>
> type Subst = [(Int, LinEq)]
>
> -- A solution is a partial map from variables to terms, and a term is
> -- a pair of lists of integers, the variable part of the term followed
> -- by the constant part. The variable part may specify variables not
> -- in the input. For example, the solution of
> --
> -- 64x = 41y + 1
> --
> -- is x = -41z - 16 and y = -64z - 25. The computed solution is read
> -- off the list returned as an answer.
> --
> -- intLinEq [64,-41] [1] =
> -- [(0,([0,0,0,0,0,0,-41],[-16])),
> -- (1,([0,0,0,0,0,0,-64],[-25]))]
>
> -- Find integer solutions to linear equations
> intLinEq :: Monad m => [Int] -> [Int] -> m Subst
> intLinEq coefficients constants =
> intLinEqLoop (length coefficients) (coefficients, constants) []
>
> -- The algorithm used to find solutions is described in Vol. 2 of The
> -- Art of Computer Programming / Seminumerical Alorithms, 2nd Ed.,
> -- 1981, by Donald E. Knuth, pg. 327.
>
> -- On input, n is the number of variables in the original problem, c
> -- is the coefficients, d is the constants, and subst is a list of
> -- eliminated variables.
> intLinEqLoop :: Monad m => Int -> LinEq -> Subst -> m Subst
> intLinEqLoop n (c, d) subst =
> -- Find the smallest coefficient in absolute value
> let (i, ci) = smallest c in
> case () of
> _ | ci < 0 -> intLinEqLoop n (invert c, invert d) subst
> -- Ensure the smallest coefficient is positive
> | ci == 0 -> fail "bad problem"
> -- Lack of non-zero coefficients is an error
> | ci == 1 ->
> -- A general solution of the following form has been found:
> -- x[i] = sum[j] -c'[j]*x[j] + d[k] for all k
> -- where c' is c with c'[i] = 0.
> return $ eliminate n (i, (invert (zero i c), d)) subst
> | divisible ci c ->
> -- If all the coefficients are divisible by c[i], a solution is
> -- immediate if all the constants are divisible by c[i],
> -- otherwise there is no solution.
> if divisible ci d then
> let c' = divide ci c
> d' = divide ci d in
> return $ eliminate n (i, (invert (zero i c'), d')) subst
> else
> fail "no solution"
> | otherwise ->
> -- Eliminate x[i] in favor of freshly created variable x[n],
> -- where n is the length of c.
> -- x[n] = sum[j] (c[j] div c[i] * x[j])
> -- The new equation to be solved is:
> -- c[i]*x[n] + sum[j] c'[j]*x[j] = d[k] for all k
> -- where c'[j] = c[j] mod c[i] for j /= i and c'[i] = 0.
> let c' = map (\x -> mod x ci) (zero i c)
> c'' = divide ci (zero i c)
> subst' = eliminate n (i, (invert c'' ++ [1], [])) subst in
> intLinEqLoop n (c' ++ [ci], d) subst'
>
> -- Find the smallest coefficient in absolute value
> smallest :: [Int] -> (Int, Int)
> smallest xs =
> foldl f (-1, 0) (zip [0..] xs)
> where
> f (i, n) (j, x)
> | n == 0 = (j, x)
> | x == 0 || abs n <= abs x = (i, n)
> | otherwise = (j, x)
>
> invert :: [Int] -> [Int]
> invert t = map negate t
>
> -- Zero the ith position in a list
> zero :: Int -> [Int] -> [Int]
> zero _ [] = []
> zero 0 (_:xs) = 0 : xs
> zero i (x:xs) = x : zero (i - 1) xs
>
> -- Eliminate a variable from the existing substitution. If the
> -- variable is in the original problem, add it to the substitution.
> eliminate :: Int -> (Int, LinEq) -> Subst -> Subst
> eliminate n m@(i, (c, d)) subst =
> if i < n then
> m : map f subst
> else
> map f subst
> where
> f m'@(i', (c', d')) = -- Eliminate i in c' if it occurs in c'
> case get i c' of
> 0 -> m' -- i is not in c'
> ci -> (i', (addmul ci (zero i c') c, addmul ci d' d))
> -- Find ith coefficient
> get _ [] = 0
> get 0 (x:_) = x
> get i (_:xs) = get (i - 1) xs
> -- addnum n xs ys sums xs and ys after multiplying ys by n
> addmul 1 [] ys = ys
> addmul n [] ys = map (* n) ys
> addmul _ xs [] = xs
> addmul n (x:xs) (y:ys) = (x + n * y) : addmul n xs ys
>
> divisible :: Int -> [Int] -> Bool
> divisible small t =
> all (\x -> mod x small == 0) t
>
> divide :: Int -> [Int] -> [Int]
> divide small t =
> map (\x -> div x small) t
>
> -- Input and Output
>
> instance Show Lin where
> showsPrec _ (Lin []) =
> showString "0"
> showsPrec _ x =
> showFactor t . showl ts
> where
> Lin (t:ts) = canonicalize x
> showFactor (x, 1) = showString x
> showFactor (x, -1) = showChar '-' . showString x
> showFactor (x, c) = shows c . showString x
> showl [] = id
> showl ((s,n):ts)
> | n < 0 =
> showString " - " . showFactor (s, negate n) . showl ts
> showl (t:ts) = showString " + " . showFactor t . showl ts
>
> instance Read Lin where
> readsPrec _ s0 =
> [ (t1, s2) | (t0, s1) <- readFactor s0,
> (t1, s2) <- readRest t0 s1 ]
> where
> readPrimary s0 =
> [ (t0, s1) | (x, s1) <- scan s0, isVar x,
> let t0 = var 1 x ] ++
> [ (t0, s1) | ("0", s1) <- scan s0,
> (s, _) <- scan s1, not (isVar s),
> let t0 = ide ] ++
> [ (t0, s2) | (n, s1) <- scan s0, isNum n,
> (x, s2) <- scan s1, isVar x,
> let t0 = var (read n) x ] ++
> [ (t0, s3) | ("(", s1) <- scan s0,
> (t0, s2) <- reads s1,
> (")", s3) <- scan s2 ]
> readFactor s0 =
> [ (t1, s2) | ("-", s1) <- scan s0,
> (t0, s2) <- readPrimary s1,
> let t1 = neg t0 ] ++
> [ (t0, s1) | (s, _) <- scan s0, s /= "-",
> (t0, s1) <- readPrimary s0 ]
> readRest t0 s0 =
> [ (t2, s3) | ("+", s1) <- scan s0,
> (t1, s2) <- readFactor s1,
> (t2, s3) <- readRest (add t0 t1) s2 ] ++
> [ (t2, s3) | ("-", s1) <- scan s0,
> (t1, s2) <- readPrimary s1,
> (t2, s3) <- readRest (add t0 (neg t1)) s2 ] ++
> [ (t0, s0) | (s, _) <- scan s0, s /= "+" && s /= "-" ]
>
> isNum :: String -> Bool
> isNum (c:_) = isDigit c
> isNum _ = False
>
> isVar :: String -> Bool
> isVar (c:_) = isAlpha c && c /= genChar
> isVar _ = False
>
> scan :: ReadS String
> scan "" = [("", "")]
> scan (c:s)
> | isSpace c = scan s
> | isAlpha c = [ (c:part, t) | (part,t) <- [span isAlphaNum s] ]
> | isDigit c = [ (c:part, t) | (part,t) <- [span isDigit s] ]
> | otherwise = [([c], s)]
>
> instance Show Equation where
> showsPrec _ (Equation (t0, t1)) =
> shows t0 . showString " = " . shows t1
>
> instance Read Equation where
> readsPrec _ s0 =
> [ (Equation (t0, t1), s3) | (t0, s1) <- reads s0,
> ("=", s2) <- scan s1,
> (t1, s3) <- reads s2 ]
>
> instance Show Maplet where
> showsPrec _ (Maplet (x, t)) =
> showString x . showString " -> " . shows t
>
> -- Test Routine
>
> -- Given an equation, display a unifier and a matcher.
> test :: String -> IO ()
> test prob =
> case readM prob of
> Err err -> putStrLn err
> Ans (Equation (t0, t1)) ->
> do
> putStr "Problem: "
> print $ Equation (canonicalize t0, canonicalize t1)
> subst <- unify $ Equation (t0, t1)
> putStr "Unifier: "
> print subst
> putStr "Matcher: "
> case match $ Equation (t0, t1) of
> Err err -> putStrLn err
> Ans subst -> print subst
> putStrLn ""
>
> readM :: (Read a, Monad m) => String -> m a
> readM s =
> case [ x | (x, t) <- reads s, ("", "") <- lex t ] of
> [x] -> return x
> [] -> fail "no parse"
> _ -> fail "ambiguous parse"
>
> data AnsErr a
> = Ans a
> | Err String
>
> instance Monad AnsErr where
> (Ans x) >>= k = k x
> (Err s) >>= _ = Err s
> return = Ans
> fail = Err
>
> main :: IO ()
> main =
> do
> done <- isEOF
> case done of
> True -> return ()
> False ->
> do
> prob <- getLine
> test prob
> main
More information about the Haskell-Cafe
mailing list