[Haskell-beginners] Can we make this program better (and shorter)?

Zachary Turner divisortheory at gmail.com
Thu Apr 2 02:34:04 EDT 2009


Be warned, this code is really ugly.  I need some help to make it better.

This week I had a sample of numbers.  I knew the valid range that the
numbers could lie in, and I knew about 30% of the actual values from the
sample.  I also knew the original average and the original standard
deviation.  I wanted more information about the other values in the sample,
so I sat down and derived two formulas.  The first, given an average, the
number of items in the average, and one number that you know was used
somewhere in the calculation of the average, calculates what the average
would be if that number had not been there.  The second does the same for
standard deviation.  The second formula is fairly insane; if someone can
figure out how to simplify it, let me know.  It took forever to derive it
without making any mistakes.  Actually, I didn't even know it was possible
to do a running standard deviation like that until I sat down and derived
it.
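
For illustration, the two removal formulas described above can be sketched
as pure functions.  This is only a sketch under an assumption: it treats
the standard deviation as the population standard deviation (dividing by n),
which is what the expression in newStdev further down appears to compute.
The names removeFromMean and removeFromStdev are made up for this example.
Under that assumption, the long expression seems to collapse to
sigma_new^2 = n/(n-1) * (sigma^2 - (a - mu)^2/(n-1)).

```haskell
-- Mean of the remaining n-1 values after removing x from a sample of
-- n values whose mean is mu.
removeFromMean :: Int -> Double -> Double -> Double
removeFromMean n mu x = (fromIntegral n * mu - x) / fromIntegral (n - 1)

-- Standard deviation of the remaining n-1 values.
-- ASSUMPTION: sigma is the population standard deviation (divide by n).
removeFromStdev :: Int -> Double -> Double -> Double -> Double
removeFromStdev n mu sigma x =
    sqrt (n' / (n' - 1) * (sigma ^ 2 - (x - mu) ^ 2 / (n' - 1)))
  where
    n' = fromIntegral n
```

For the sample [2,4,4,4,5,5,7,9] (mean 5, population standard deviation 2),
removing 9 should give a mean of 31/7 and a variance of 96/49.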

So anyway, the program asks the user how large the original sample was, and
for the average and standard deviation.  Then it runs in a loop, asking the
user for a value to remove from the sample set and calculating the updated
values.  Certain values, however, even if they lie within the valid range of
the sample space, might be impossible to achieve given a certain sample
size, average, and stdev (for example, if you have two numbers in the range
[0..100] and the average is 100, and you remove one number, the remaining
number obviously cannot be 0).  So I wanted to detect when the user inputs
such a value; that way, for one thing, I can calculate exact upper and lower
bounds for the actual range of the numbers just by trying all possible
values.
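
The "try all possible values" idea can be sketched roughly like this.  It
is a sketch, not the program's actual code: Stats, removeValue, feasible and
feasibleBounds are hypothetical names, the standard deviation is assumed to
be the population one, and a value counts as impossible only when removing
it yields NaN or an infinite result (the same check the program uses).

```haskell
-- (count, stdev, mean) of a sample; a hypothetical summary type.
type Stats = (Int, Double, Double)

-- Statistics after removing x.
-- ASSUMPTION: the stdev is the population standard deviation.
removeValue :: Int -> Stats -> Stats
removeValue x (n, sigma, mu) = (n - 1, sigma', mu')
  where
    n'     = fromIntegral n
    x'     = fromIntegral x
    mu'    = (n' * mu - x') / (n' - 1)
    sigma' = sqrt (n' / (n' - 1) * (sigma ^ 2 - (x' - mu) ^ 2 / (n' - 1)))

-- A value is feasible if removing it leaves finite, non-NaN statistics.
feasible :: Stats -> Int -> Bool
feasible stats x = all ok [sigma', mu']
  where
    (_, sigma', mu') = removeValue x stats
    ok v = not (isNaN v || isInfinite v)

-- Smallest and largest feasible values in [lo .. hi], if there are any.
feasibleBounds :: (Int, Int) -> Stats -> Maybe (Int, Int)
feasibleBounds (lo, hi) stats =
    case filter (feasible stats) [lo .. hi] of
        [] -> Nothing
        xs -> Just (minimum xs, maximum xs)
```

With two numbers in [0..100], an average of 100 and a stdev of 0,
feasibleBounds (0, 100) (2, 0, 100) should come out as Just (100, 100).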

I know that the nature of the program mandates that there's going to be a
lot of IO Monadic code, but I still feel like my code sucks.

For one thing, I really tried to make the removePoints function pure, so
that it would just accept a list as the argument and leave the I/O to
another function.  I imagined some kind of lazy list comprehension, where
each time removePoints asked for the next item, it would prompt for the
input and do all that stuff.  I couldn't figure out how to make this work.
Perhaps it's not even possible, since by definition a list comprehension
that asks for input when forced isn't pure.  But I still think there's a
better way to do it.
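
One possible shape for the pure half, as a rough sketch: a step function
that removes one value, and scanl to run it over a list.  The names Stats,
removeValue, step and removeAll are hypothetical, the stdev is assumed to be
the population one, and impossible values are simply skipped here instead of
reported.

```haskell
-- (count, stdev, mean) of a sample; a hypothetical summary type.
type Stats = (Int, Double, Double)

-- Statistics after removing x.
-- ASSUMPTION: the stdev is the population standard deviation.
removeValue :: Int -> Stats -> Stats
removeValue x (n, sigma, mu) = (n - 1, sigma', mu')
  where
    n'     = fromIntegral n
    x'     = fromIntegral x
    mu'    = (n' * mu - x') / (n' - 1)
    sigma' = sqrt (n' / (n' - 1) * (sigma ^ 2 - (x' - mu) ^ 2 / (n' - 1)))

-- Remove one value; if the result is NaN or infinite, keep the old stats.
step :: Stats -> Int -> Stats
step stats x
    | valid     = stats'
    | otherwise = stats
  where
    stats'@(_, sigma', mu') = removeValue x stats
    valid = not (any (\v -> isNaN v || isInfinite v) [sigma', mu'])

-- All intermediate statistics, starting with the initial ones.
removeAll :: Stats -> [Int] -> [Stats]
removeAll = scanl step
```

Since scanl is lazy, the list of values could in principle come from
anywhere, including input that is read on demand by a separate I/O function.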

I also really dislike having to do manual recursion for getting input.
There's got to be a way to use lists and built-in functions to do it for
you, using some kind of takeWhile perhaps.
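
A takeWhile-based version might look something like the sketch below.  The
names parseLine, parseInputs and readValues are hypothetical, and the sketch
relies on getContents reading standard input lazily, which is an assumption
about how the input is supplied.

```haskell
-- Parse one line as an Int, using reads the same way the program does;
-- Nothing marks a line that did not parse.
parseLine :: String -> Maybe Int
parseLine s = case reads s of
    [(v, _)] -> Just v
    _        -> Nothing

-- Pure: keep the lines up to the first empty one and parse each of them.
parseInputs :: [String] -> [Maybe Int]
parseInputs = map parseLine . takeWhile (not . null)

-- I/O shell: getContents is lazy, so lines are consumed as they are demanded.
readValues :: IO [Maybe Int]
readValues = fmap (parseInputs . lines) getContents
```

Lazy input like this makes it harder to interleave prompts with reads, so it
probably fits best when the values are piped in.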

Anyway if anyone has any suggestions for making this code look better, more
elegant, or more Haskelly, let me know :)


import Control.Arrow

promptUntilValid :: (Read a, Show a, Eq a) => String -> IO a
promptUntilValid prompt = do
    putStr prompt
    putStr " "
    result <- getLine
    case (reads result) of
        [(parsed,_)] -> return parsed
        invalid -> do
            putStrLn $ "Invalid input!  You entered " ++ result
                ++ ", and results length is " ++ (show (length invalid))
            promptUntilValid prompt

newStdev :: Int -> (Int,Int,Double,Double,Double) -> Double
newStdev removeValue (old_count, new_count, old_stdev, old_avg, new_avg) =
    sqrt.(/ (n-1.0)) $
        n * σ_x^2 + 2.0*old_avg*new_avg*(n-1.0) - (n-1.0)*old_avg^2
            - (a_n - old_avg)^2 - (n-1.0)*new_avg^2
    where
        σ_x = old_stdev
        n = fromIntegral old_count
        a_n = fromIntegral removeValue

removeSinglePoint :: Int -> (Int,Double,Double) -> (Int,Double,Double)
removeSinglePoint value (count,stdev,avg) =
    (new_count,new_stdev,new_avg)
    where new_count = count-1
          new_avg = (avg*(fromIntegral count) - (fromIntegral value))
                    / (fromIntegral new_count)
          new_stdev = newStdev value (count, new_count, stdev, avg, new_avg)

removePoints :: (Int,Double,Double) -> IO ()
removePoints (count,stdev,avg) = do
    putStr "Enter a number to remove from the sample set (Enter to stop): "
    putStr " "
    result <- getLine
    case (result,reads result) of
        ("",_) -> return ()
        (_, [(parsed,_)]) -> do
            if (any (uncurry (||) . (isInfinite &&& isNaN)) [new_stdev, new_avg])
             then do
                putStrLn "That number could not have been there!  The new values are NaN!"
                removePoints (count,stdev,avg)
             else do
                putStrLn $ "New count: " ++ (show new_count) ++ ", new stdev: "
                    ++ (show new_stdev) ++ ", new avg: " ++ (show new_avg)
                removePoints (new_count,new_stdev,new_avg)
            where (new_count,new_stdev,new_avg) = removeSinglePoint parsed (count,stdev,avg)
        (_,invalid) -> do
            putStrLn $ "Invalid input!  You entered " ++ result
                ++ ", and results length is " ++ (show (length invalid))
            removePoints (count,stdev,avg)

main = do
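    -- Annotating the expression avoids needing the ScopedTypeVariables
    -- extension that a pattern signature such as "x::Int <- ..." requires.
    num_values <- promptUntilValid "Enter the initial number of values: " :: IO Int
    stdev <- promptUntilValid "Enter the initial standard deviation: " :: IO Double
    average <- promptUntilValid "Enter the initial average: " :: IO Double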

    putStrLn ("num_values = " ++ (show num_values))
    putStrLn ("stdev = " ++ (show stdev))
    putStrLn ("average = " ++ (show average))

    removePoints (num_values, stdev, average)

    return ()