Haskell/Continuation passing style

Continuation Passing Style (CPS for short) is a style of programming in which functions do not return values; rather, they pass control onto a continuation, which specifies what happens next. In this chapter, we are going to consider how that plays out in Haskell and, in particular, how CPS can be expressed with a monad.

What are continuations?

To dispel puzzlement, we will have a second look at an example from way back in the book, when we introduced the ($) operator:

> map ($ 2) [(2*), (4*), (8*)]
[4,8,16]

There is nothing out of the ordinary about the expression above, except that it is a little quaint to write that instead of map (*2) [2, 4, 8]. The ($) section makes the code appear backwards, as if we are applying a value to the functions rather than the other way around. And now, the catch: such an innocent-looking reversal is at the heart of continuation passing style!

From a CPS perspective, ($ 2) is a suspended computation: a function with general type (a -> r) -> r which, given another function as argument, produces a final result. The a -> r argument is the continuation; it specifies how the computation will be brought to a conclusion. In the example, the functions in the list are supplied as continuations via map, producing three distinct results. Note that suspended computations are largely interchangeable with plain values: flip ($) [1] converts any value into a suspended computation, and passing id as its continuation gives back the original value.
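
A small sketch of this correspondence follows. The names suspend and resume are our own illustrative labels (not library functions) for flip ($) and for running a suspended computation with id:

```haskell
-- 'suspend' and 'resume' are illustrative names, not library functions.

-- Turn a plain value into a suspended computation: suspend x behaves like ($ x).
suspend :: a -> ((a -> r) -> r)
suspend = flip ($)

-- Run a suspended computation with id as its continuation to get the value back.
resume :: ((a -> a) -> a) -> a
resume s = s id

-- ($ 2) is the same thing as suspend 2; map supplies each function as a continuation.
results :: [Int]
results = map (suspend 2) [(2*), (4*), (8*)]
```

Here results is [4,8,16], exactly as in the GHCi session above, and resume (suspend x) gives back x for any x.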

What are they good for?

There is more to continuations than just a parlour trick to impress Haskell newbies. They make it possible to explicitly manipulate, and dramatically alter, the control flow of a program. For instance, returning early from a procedure can be implemented with continuations. Exceptions and failure can also be handled with continuations: pass in one continuation for success and another for failure, and invoke the appropriate one. Other possibilities include "suspending" a computation and returning to it at another time, and implementing simple forms of concurrency (notably, one Haskell implementation, Hugs, uses continuations to implement cooperative concurrency).
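
As a minimal, hypothetical illustration of the success/failure idea (safeDivCPS and describe are made-up names), a division function can take one continuation for each outcome and invoke exactly one of them:

```haskell
-- Hypothetical example: division with separate failure and success continuations.
-- Exactly one of the two continuations is invoked.
safeDivCPS :: Int -> Int -> (String -> r) -> (Int -> r) -> r
safeDivCPS _ 0 failK _  = failK "division by zero"
safeDivCPS x y _ succK  = succK (x `div` y)

-- The caller decides what each outcome means; here both become a String.
describe :: Int -> Int -> String
describe x y = safeDivCPS x y ("error: " ++) (\q -> "quotient: " ++ show q)
```

Passing const Nothing and Just as the two continuations turns the same function into a Maybe-returning one, which hints at how CPS can encode familiar error-handling types.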

In Haskell, continuations can be used in a similar fashion, for implementing interesting control flow in monads. Note that there are usually alternative techniques for such use cases, especially in tandem with laziness. In some circumstances, CPS can be used to improve performance by eliminating certain construction/pattern-matching sequences (i.e. a function builds a complex structure which the caller will at some point deconstruct), though a sufficiently smart compiler should be able to do the elimination [2].

Passing continuations

An elementary way to take advantage of continuations is to modify our functions so that they return suspended computations rather than ordinary values. We will illustrate how that is done with two simple examples.

pythagoras

Example: A simple module, no continuations

-- We assume some primitives add and square for the example:

add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

Modified to return a suspended computation, pythagoras looks like this:

Example: A simple module, using continuations

-- We assume CPS versions of the add and square primitives,
-- (note: the actual definitions of add_cps and square_cps are not
-- in CPS form, they just have the correct type)

add_cps :: Int -> Int -> ((Int -> r) -> r)
add_cps x y = \k -> k (add x y)

square_cps :: Int -> ((Int -> r) -> r)
square_cps x = \k -> k (square x)

pythagoras_cps :: Int -> Int -> ((Int -> r) -> r)
pythagoras_cps x y = \k ->
 square_cps x $ \x_squared ->
 square_cps y $ \y_squared ->
 add_cps x_squared y_squared $ k

How the pythagoras_cps example works:

  1. square x and throw the result into the (\x_squared -> ...) continuation
  2. square y and throw the result into the (\y_squared -> ...) continuation
  3. add x_squared and y_squared and throw the result into the top level/program continuation k.

We can try it out in GHCi by passing print as the program continuation:

*Main> pythagoras_cps 3 4 print
25

If we look at the type of pythagoras_cps without the optional parentheses around (Int -> r) -> r and compare it with the original type of pythagoras, we note that the continuation was in effect added as an extra argument, thus justifying the "continuation passing style" moniker.
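
The extra-argument reading can be made explicit by naming the continuation as an ordinary third parameter. The sketch below assumes the same add_cps and square_cps as above, with add and square inlined so that the snippet stands alone:

```haskell
-- CPS primitives as in the text (add and square are inlined here for brevity).
add_cps :: Int -> Int -> ((Int -> r) -> r)
add_cps x y = \k -> k (x + y)

square_cps :: Int -> ((Int -> r) -> r)
square_cps x = \k -> k (x * x)

-- pythagoras_cps with the continuation k written as an ordinary third argument.
-- Int -> Int -> (Int -> r) -> r is the text's type with the optional parentheses dropped.
pythagoras_cps :: Int -> Int -> (Int -> r) -> r
pythagoras_cps x y k =
  square_cps x $ \x_squared ->
  square_cps y $ \y_squared ->
  add_cps x_squared y_squared k
```

pythagoras_cps 3 4 id gives back the plain result 25, while pythagoras_cps 3 4 show post-processes it into the string "25".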

thrice

Example: A simple higher order function, no continuations

thrice :: (a -> a) -> a -> a
thrice f x = f (f (f x))
*Main> thrice tail "foobar"
"bar"

A higher order function such as thrice, when converted to CPS, takes as arguments functions in CPS form as well. Therefore, f :: a -> a will become f_cps :: a -> ((a -> r) -> r), and the final type will be thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r). The rest of the definition follows quite naturally from the types: we replace f by the CPS version, passing along the continuation at hand.

Example: A simple higher order function, with continuations

thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)
thrice_cps f_cps x = \k ->
 f_cps x $ \fx ->
 f_cps fx $ \ffx ->
 f_cps ffx $ k
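
To try thrice_cps out we need a CPS function to pass to it. The tail_cps below is a hypothetical helper; it uses drop 1 rather than tail so that it does not crash on an empty list, a small deviation from the non-CPS example:

```haskell
thrice_cps :: (a -> ((a -> r) -> r)) -> a -> ((a -> r) -> r)
thrice_cps f_cps x = \k ->
  f_cps x $ \fx ->
  f_cps fx $ \ffx ->
  f_cps ffx $ k

-- Hypothetical CPS counterpart of tail; drop 1 is used so that "" does not crash.
tail_cps :: [a] -> (([a] -> r) -> r)
tail_cps xs = \k -> k (drop 1 xs)
```

thrice_cps tail_cps "foobar" id evaluates to "bar", matching thrice tail "foobar"; with length as the continuation it evaluates to 3.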


The Cont monad

Now that we have continuation-passing functions, the next step is to provide a neat way of composing them, preferably one which does not require the long chains of nested lambdas we have just seen. A good start would be a combinator for applying a CPS function to a suspended computation. A possible type for it would be:

chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)

(You may want to try implementing it before reading on. Hint: start by stating that the result is a function which takes a b -> r continuation; then, let the types guide you.)

And here is the implementation:

chainCPS s f = \k -> s $ \x -> f x $ k

We supply the original suspended computation s with a continuation which makes a new suspended computation (produced by f) and passes the final continuation k to it. Unsurprisingly, it closely mirrors the nested lambda pattern of the previous examples.
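
As a sketch of how chainCPS removes the explicit threading of k, here is pythagoras rewritten with it. returnCPS is our own name for flip ($), used to wrap the final sum; with add and square inlined, the chain no longer mentions the final continuation at all:

```haskell
chainCPS :: ((a -> r) -> r) -> (a -> ((b -> r) -> r)) -> ((b -> r) -> r)
chainCPS s f = \k -> s $ \x -> f x $ k

-- Our own stand-in for return: flip ($) makes a suspended computation out of a value.
returnCPS :: a -> ((a -> r) -> r)
returnCPS = flip ($)

square_cps :: Int -> ((Int -> r) -> r)
square_cps x = returnCPS (x * x)

-- No explicit \k -> ...: chainCPS passes the final continuation along for us.
pythagoras_chain :: Int -> Int -> ((Int -> r) -> r)
pythagoras_chain x y =
  square_cps x `chainCPS` \x2 ->
  square_cps y `chainCPS` \y2 ->
  returnCPS (x2 + y2)
```

pythagoras_chain 3 4 id evaluates to 25, just like pythagoras_cps 3 4 id.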

Doesn't the type of chainCPS look familiar? If we replace (a -> r) -> r with (Monad m) => m a and (b -> r) -> r with (Monad m) => m b we get the (>>=) signature. Furthermore, our old friend flip ($) plays a return-like role, in that it makes a suspended computation out of a value in a trivial way. Lo and behold, we have a monad! All we need now [3] is a Cont r a type to wrap suspended computations, with the usual wrapper and unwrapper functions.

cont :: ((a -> r) -> r) -> Cont r a
runCont :: Cont r a -> (a -> r) -> r

The monad instance for Cont follows directly from our presentation, the only difference being the wrapping and unwrapping cruft:

instance Monad (Cont r) where
    return x = cont ($ x)
    s >>= f  = cont $ \c -> runCont s $ \x -> runCont (f x) c
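
On current GHC versions the instance above does not compile on its own: every Monad also needs Functor and Applicative instances, and the library's Cont is a type synonym for ContT r Identity rather than a type we could give instances to. The following self-contained sketch defines its own newtype (named Cont purely for illustration) so that the instance can be tried out:

```haskell
-- Our own minimal Cont, standing in for the library version.
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

cont :: ((a -> r) -> r) -> Cont r a
cont = Cont

instance Functor (Cont r) where
  -- Post-compose f onto whatever continuation we are eventually given.
  fmap f s = cont $ \c -> runCont s (c . f)

instance Applicative (Cont r) where
  pure x = cont ($ x)
  sf <*> sx = cont $ \c -> runCont sf $ \f -> runCont sx (c . f)

instance Monad (Cont r) where
  s >>= f = cont $ \c -> runCont s $ \x -> runCont (f x) c

pythagorasC :: Int -> Int -> Cont r Int
pythagorasC x y = do
  x2 <- pure (x * x)
  y2 <- pure (y * y)
  return (x2 + y2)
```

Here pure plays the role that return plays in the instance above; return then defaults to pure. runCont (pythagorasC 3 4) id evaluates to 25.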

The end result is that the monad instance makes the continuation passing (and thus the lambda chains) implicit. The monadic bind applies a CPS function to a suspended computation, and runCont is used to provide the final continuation. For a simple example, the Pythagoras example becomes:

Example: The pythagoras example, using the Cont monad

-- Using the Cont monad from the transformers package.
import Control.Monad.Trans.Cont

add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)

square_cont :: Int -> Cont r Int
square_cont x = return (square x)

pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y = do
    x_squared <- square_cont x
    y_squared <- square_cont y
    add_cont x_squared y_squared
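
To run pythagoras_cont we supply the final continuation with runCont, exactly as we passed print or id to pythagoras_cps. A self-contained version, assuming the transformers package is available:

```haskell
import Control.Monad.Trans.Cont (Cont, runCont)

add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)

square_cont :: Int -> Cont r Int
square_cont x = return (square x)

pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y = do
  x_squared <- square_cont x
  y_squared <- square_cont y
  add_cont x_squared y_squared
```

runCont (pythagoras_cont 3 4) id evaluates to 25, and runCont (pythagoras_cont 3 4) show to "25".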

callCC

While it is always pleasant to see a monad coming forth naturally, a hint of disappointment might linger at this point. One of the promises of CPS was precise control flow manipulation through continuations. And yet, after converting our functions to CPS we promptly hid the continuations behind a monad. To rectify that, we shall introduce callCC, a function which gives us back explicit control of continuations - but only where we want it.

callCC is a very peculiar function; one that is best introduced with examples. Let us start with a trivial one:

Example: square using callCC

-- Without callCC
square :: Int -> Cont r Int
square n = return (n ^ 2)

-- With callCC
squareCCC :: Int -> Cont r Int
squareCCC n = callCC $ \k -> k (n ^ 2)

The argument passed to callCC is a function, whose result is a suspended computation (general type Cont r a) which we will refer to as "the callCC computation". In principle, the callCC computation is what the whole callCC expression evaluates to. The caveat, and what makes callCC so special, is due to k, the argument to the argument. It is a function which acts as an eject button: calling it anywhere will lead to the value passed to it being made into a suspended computation, which then is inserted into control flow at the point of the callCC invocation. That happens unconditionally; in particular, whatever follows a k invocation in the callCC computation is summarily discarded. From another perspective, k captures the rest of the computation following the callCC; calling it throws a value into the continuation at that particular point ("callCC" stands for "call with current continuation"). While in this simple example the effect is merely that of a plain return, callCC opens up a number of possibilities, which we are now going to explore.

Deciding when to use k

callCC gives us extra power over what is thrown into a continuation, and when that is done. The following example begins to show how we can use this extra power.

Example: Our first proper callCC function

foo :: Int -> Cont r String
foo x = callCC $ \k -> do
    let y = x ^ 2 + 3
    when (y > 20) $ k "over twenty"
    return (show $ y - 4)

foo is a slightly pathological function that computes the square of its input and adds three; if the result of this computation is greater than 20, then we return from the callCC computation (and, in this case, from the whole function) immediately, throwing the string "over twenty" into the continuation that will be passed to foo. If not, then we subtract four from our previous computation, show it, and throw it into the continuation. Remarkably, k here is used just like the 'return' statement from an imperative language, which immediately exits the function. And yet, this being Haskell, k is just an ordinary first-class function, so you can pass it to other functions like when, store it in a Reader, etc.
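
The following sketch runs foo and also illustrates the first-class nature of k: bailIfOver is a hypothetical helper that receives k as an ordinary argument and may use it to escape from foo'. It assumes the transformers package:

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (Cont, callCC, runCont)

foo :: Int -> Cont r String
foo x = callCC $ \k -> do
  let y = x ^ 2 + 3
  when (y > 20) $ k "over twenty"
  return (show $ y - 4)

-- Hypothetical helper: k arrives as a plain function argument (here named exit),
-- and calling it still escapes from the callCC in foo'.
bailIfOver :: Int -> (String -> Cont r ()) -> Int -> Cont r ()
bailIfOver limit exit y = when (y > limit) $ exit "over twenty"

foo' :: Int -> Cont r String
foo' x = callCC $ \k -> do
  let y = x ^ 2 + 3
  bailIfOver 20 k y
  return (show $ y - 4)
```

With these definitions, runCont (foo 2) id is "3" (since 2 ^ 2 + 3 = 7 and 7 - 4 = 3), whereas runCont (foo 5) id is "over twenty"; foo' behaves identically.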

Naturally, you can embed calls to callCC within do-blocks:

Example: More developed callCC example involving a do-block

bar :: Char -> String -> Cont r Int
bar c s = do
    msg <- callCC $ \k -> do
        let s0 = c : s
        when (s0 == "hello") $ k "They say hello."
        let s1 = show s0
        return ("They appear to be saying " ++ s1)
    return (length msg)

When you call k with a value, the entire callCC call takes that value. In effect, that makes k a lot like a 'goto' statement in other languages: when we call k in our example, it pops the execution out to where we first called callCC, the msg <- callCC $ ... line. No more of the argument to callCC (the inner do-block) is executed. Hence the following example contains a useless line:

Example: Popping out a function, introducing a useless line

quux :: Cont r Int
quux = callCC $ \k -> do
    let n = 5
    k n
    return 25

quux will return 5, and not 25, because we pop out of quux before getting to the return 25 line.
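
Both bar and quux can be run with id as the final continuation. A self-contained sketch, assuming transformers:

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (Cont, callCC, runCont)

bar :: Char -> String -> Cont r Int
bar c s = do
  msg <- callCC $ \k -> do
    let s0 = c : s
    -- Calling k here makes the whole callCC evaluate to "They say hello."
    when (s0 == "hello") $ k "They say hello."
    let s1 = show s0
    return ("They appear to be saying " ++ s1)
  return (length msg)

quux :: Cont r Int
quux = callCC $ \k -> do
  let n = 5
  k n
  return 25  -- never reached: k n has already escaped
```

runCont quux id is 5, and runCont (bar 'h' "ello") id is 15, the length of "They say hello.".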

Behind the scenes

We have deliberately broken a trend here: normally when we introduce a function we give its type straight away, but in this case we chose not to. The reason is simple: the type is pretty complex, and it does not immediately give insight into what the function does, or how it works. After the initial presentation of callCC, however, we are in a better position to tackle it. Take a deep breath...

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

We can make sense of that based on what we already know about callCC. The overall result type and the result type of the argument have to be the same (i.e. Cont r a), as in the absence of an invocation of k the corresponding result values are one and the same. Now, what about the type of k? As mentioned above, k's argument is made into a suspended computation inserted at the point of the callCC invocation; therefore, if the latter has type Cont r a, k's argument must have type a. As for k's result type, interestingly enough it doesn't matter as long as it is wrapped in the same Cont r monad; in other words, the b stands for an arbitrary type. That happens because the suspended computation made out of the a argument will receive whatever continuation follows the callCC, and so the continuation taken by k's result is irrelevant.

Note

Although k's result type can be anything, it is fixed once within a given callCC, since k is an ordinary lambda-bound variable. That explains why the following variant of the useless line example leads to a type error:

quux :: Cont r Int
quux = callCC $ \k -> do
   let n = 5
   when True $ k n
   k 25

k's result type could be anything of form Cont r b; however, the when constrains it to Cont r (), and so the closing k 25 does not match the result type of quux. The solution is very simple: replace the final k by a plain old return.
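
Applying that fix gives the following variant (quuxFixed is our own name), which type-checks and, like the original quux, evaluates to 5:

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (Cont, callCC, runCont)

-- k is only ever used at result type Cont r (), and the block ends with
-- return, whose type Cont r Int matches the declared result.
quuxFixed :: Cont r Int
quuxFixed = callCC $ \k -> do
  let n = 5
  when True $ k n
  return 25
```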


To conclude this section, here is the implementation of callCC. Can you identify k in it?

callCC f = cont $ \h -> runCont (f (\a -> cont $ \_ -> h a)) h

Though the code is far from obvious, an amazing fact is that the implementations of callCC, return and (>>=) for Cont can be produced automatically from their type signatures: Lennart Augustsson's Djinn is a program that will do this for you. See Phil Gossett's Google tech talk for background on the theory behind Djinn, and Dan Piponi's article, which uses Djinn in deriving continuation passing style.
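
To check that the implementation above really behaves like the library's callCC, we can transcribe it against the cont and runCont functions from transformers. myCallCC and fooWith are our own names; the lambda passed to f is k, which discards the continuation it is given (the _) and resumes h, the continuation of the whole callCC expression:

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (Cont, callCC, cont, runCont)

-- The callCC implementation quoted above, written against transformers' cont/runCont.
-- k is (\a -> cont $ \_ -> h a): it ignores its own continuation and jumps to h.
myCallCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
myCallCC f = cont $ \h -> runCont (f (\a -> cont $ \_ -> h a)) h

-- Same shape as foo from earlier, parameterised over the callCC in use.
fooWith :: (((String -> Cont r ()) -> Cont r String) -> Cont r String)
        -> Int -> Cont r String
fooWith cc x = cc $ \k -> do
  let y = x * x + 3
  when (y > 20) $ k "over twenty"
  return (show (y - 4))
```

For every input, fooWith myCallCC and fooWith callCC produce the same result once run with id.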

Example: a complicated control structure

We will now look at some more realistic examples of control flow manipulation. The first one, presented below, was originally taken from "The Continuation monad" section of the All about monads tutorial, used with permission.

Example: Using Cont for a complicated control structure

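import Control.Monad (when)
import Control.Monad.Trans.Cont (callCC, runCont)
import Data.Char (digitToInt, intToDigit)

{- We use the continuation monad to perform "escapes" from code blocks.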
This function implements a complicated control structure to process
numbers:

Input (n)     Output                    List Shown
=========     ======                    ==========
0-9           n                         none
10-199        number of digits in (n/2) digits of (n/2)
200-19999     n                         digits of (n/2)
20000-1999999 (n/2) backwards           none
>= 2000000    sum of digits of (n/2)    digits of (n/2)
-} 
fun :: Int -> String
fun n = (`runCont` id) $ do
    str <- callCC $ \exit1 -> do                            -- define "exit1"
        when (n < 10) (exit1 (show n))
        let ns = map digitToInt (show (n `div` 2))
        n' <- callCC $ \exit2 -> do                         -- define "exit2"
            when ((length ns) < 3) (exit2 (length ns))
            when ((length ns) < 5) (exit2 n)
            when ((length ns) < 7) $ do
                let ns' = map intToDigit (reverse ns)
                exit1 (dropWhile (=='0') ns')               --escape 2 levels
            return $ sum ns
        return $ "(ns = " ++ (show ns) ++ ") " ++ (show n')
    return $ "Answer: " ++ str

fun is a function that takes an integer n. The implementation uses Cont and callCC to set up a control structure that does different things based on the range that n falls in, as stated by the comment at the top. Let us dissect it:

  1. Firstly, the (`runCont` id) at the top just means that we run the Cont block that follows with a final continuation of id (or, in other words, we extract the value from the suspended computation unchanged). That is necessary as the result type of fun doesn't mention Cont.
  2. We bind str to the result of the following callCC do-block:
    1. If n is less than 10, we exit straight away, just showing n.
    2. If not, we proceed. We construct a list, ns, of digits of n `div` 2.
    3. n' (an Int) gets bound to the result of the following inner callCC do-block.
      1. If length ns < 3, i.e., if n `div` 2 has fewer than 3 digits, we pop out of this inner do-block with the number of digits as the result.
      2. If n `div` 2 has fewer than 5 digits, we pop out of the inner do-block returning the original n.
      3. If n `div` 2 has fewer than 7 digits, we pop out of both the inner and outer do-blocks, with the result of the digits of n `div` 2 in reverse order, leading zeros dropped (a String).
      4. Otherwise, we end the inner do-block, returning the sum of the digits of n `div` 2.
    4. We end this do-block, returning the String "(ns = X) Y", where X is ns, the digits of n `div` 2, and Y is the result from the inner do-block, n'.
  3. Finally, we return out of the entire function, with our result being the string "Answer: Z", where Z is the string we got from the callCC do-block.
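
The key trick in fun is that the inner block can escape through the outer exit. A smaller, hypothetical example (classify is not from the original tutorial) isolates that behaviour with two nested callCC calls:

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (callCC, runCont)

-- Hypothetical example: two nested escape points, as in fun.
classify :: Int -> String
classify n = flip runCont id $ do
  outer <- callCC $ \exitOuter -> do
    inner <- callCC $ \exitInner -> do
      when (n < 0) $ exitOuter "negative"  -- skips both blocks
      when (n == 0) $ exitInner "zero"     -- skips only the rest of the inner block
      return "positive"
    return ("inner said: " ++ inner)
  return ("Result: " ++ outer)
```

classify (-1) yields "Result: negative" (both blocks skipped), classify 0 yields "Result: inner said: zero", and classify 3 yields "Result: inner said: positive".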

Example: exceptions

One use of continuations is to model exceptions. To do this, we hold on to two continuations: one that takes us out to the handler in case of an exception, and one that takes us to the post-handler code in case of a success. Here's a simple function that takes two numbers and does integer division on them, failing when the denominator is zero.

Example: An exception-throwing div

divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        when (y == 0) $ notOk "Denominator 0"
        ok $ x `div` y
    handler err

{- For example,
runCont (divExcpt 10 2 error) id --> 5
runCont (divExcpt 10 0 error) id --> *** Exception: Denominator 0
-}

How does it work? We use two nested calls to callCC. The first labels a continuation that will be used when there's no problem. The second labels a continuation that will be used when we wish to throw an exception. If the denominator isn't 0, x `div` y is thrown into the ok continuation, so the execution pops right back out to the top level of divExcpt. If, however, we were passed a zero denominator, we throw an error message into the notOk continuation, which pops us out of the inner do-block; that string then gets bound to err and given to handler.
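
Since handler is itself a Cont computation, it does not have to crash as error does. In the sketch below, orZero is a hypothetical handler that recovers with a default value:

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (Cont, callCC, runCont)

divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler = callCC $ \ok -> do
  err <- callCC $ \notOk -> do
    when (y == 0) $ notOk "Denominator 0"
    ok $ x `div` y
  handler err

-- Hypothetical handler: ignore the message and recover with 0.
orZero :: String -> Cont r Int
orZero _ = return 0
```

runCont (divExcpt 10 0 orZero) id evaluates to 0, while runCont (divExcpt 10 2 orZero) id still gives 5.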

A more general approach to handling exceptions can be seen with the following function. Pass a computation as the first parameter (more precisely, a function which takes an error-throwing function and returns the computation) and an error handler as the second parameter. This example takes advantage of the generic MonadCont class [4], which covers both Cont and the corresponding ContT transformer by default, as well as any other continuation monad which instantiates it.

Example: General try using continuations.

import Control.Monad (when)
import Control.Monad.Cont
import Control.Monad.Trans.Class (lift)

tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h = callCC $ \ok -> do
    err <- callCC $ \notOk -> do
        x <- c notOk
        ok x
    h err

And here is our try in action:

Example: Using try

data SqrtException = LessThanZero deriving (Show, Eq)

sqrtIO :: (SqrtException -> ContT r IO ()) -> ContT r IO ()
sqrtIO throw = do 
    ln <- lift (putStr "Enter a number to sqrt: " >> readLn)
    when (ln < 0) (throw LessThanZero)
    lift $ print (sqrt ln)

main = runContT (tryCont sqrtIO (lift . print)) return

In this example, error throwing means escaping from an enclosing callCC. The throw in sqrtIO jumps out of tryCont's inner callCC.
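
tryCont does not depend on IO. As a sketch, here is a pure counterpart of sqrtIO (sqrtPure and runSqrt are our own names) using the plain Cont monad from mtl's Control.Monad.Cont. Here throw is used directly as the result of a branch rather than inside when, so its result type can be Cont r String:

```haskell
import Control.Monad.Cont (Cont, MonadCont, callCC, runCont)

tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h = callCC $ \ok -> do
  err <- callCC $ \notOk -> do
    x <- c notOk
    ok x
  h err

data SqrtException = LessThanZero deriving (Show, Eq)

-- Pure counterpart of sqrtIO: throw escapes to tryCont's handler.
sqrtPure :: Double -> (SqrtException -> Cont r String) -> Cont r String
sqrtPure x throw
  | x < 0     = throw LessThanZero
  | otherwise = return (show (sqrt x))

runSqrt :: Double -> String
runSqrt x = runCont (tryCont (sqrtPure x) (\e -> return ("caught " ++ show e))) id
```

runSqrt 2.25 gives "1.5", and runSqrt (-1) gives "caught LessThanZero".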

Example: coroutines

In this section we make a CoroutineT monad transformer providing two operations: fork, which starts a new coroutine (suspending the current one by putting it in the queue), and yield, which suspends the current coroutine and resumes the next one in the queue.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- We use GeneralizedNewtypeDeriving to avoid boilerplate. As of GHC 7.8, it is safe.

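import Control.Applicative
import Control.Monad
import Control.Monad.Cont
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.State
import Control.Monad.Trans.Class (lift)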

-- The CoroutineT monad is just ContT stacked with a StateT containing the suspended coroutines.
newtype CoroutineT r m a = CoroutineT {runCoroutineT' :: ContT r (StateT [CoroutineT r m ()] m) a}
    deriving (Functor,Applicative,Monad,MonadCont,MonadIO)

-- Used to manipulate the coroutine queue.
getCCs :: Monad m => CoroutineT r m [CoroutineT r m ()]
getCCs = CoroutineT $ lift get

putCCs :: Monad m => [CoroutineT r m ()] -> CoroutineT r m ()
putCCs = CoroutineT . lift . put

-- Pop and push coroutines to the queue.
dequeue :: Monad m => CoroutineT r m ()
dequeue = do
    current_ccs <- getCCs
    case current_ccs of
        [] -> return ()
        (p:ps) -> do
            putCCs ps
            p

queue :: Monad m => CoroutineT r m () -> CoroutineT r m ()
queue p = do
    ccs <- getCCs
    putCCs (ccs++[p])

-- The interface.
yield :: Monad m => CoroutineT r m ()
yield = callCC $ \k -> do
    queue (k ())
    dequeue

fork :: Monad m => CoroutineT r m () -> CoroutineT r m ()
fork p = callCC $ \k -> do
    queue (k ())
    p
    dequeue

-- Exhaust passes control to suspended coroutines repeatedly until there aren't any left.
exhaust :: Monad m => CoroutineT r m ()
exhaust = do
    exhausted <- null <$> getCCs
    if not exhausted
        then yield >> exhaust
        else return ()

-- Runs the coroutines in the base monad.
runCoroutineT :: Monad m => CoroutineT r m r -> m r
runCoroutineT = flip evalStateT [] . flip runContT return . runCoroutineT' . (<* exhaust)

Some example usage:

printOne n = do
    liftIO (print n)
    yield

example = runCoroutineT $ do
    fork $ replicateM_ 3 (printOne 3)
    fork $ replicateM_ 4 (printOne 4)
    replicateM_ 2 (printOne 2)

Outputting:

3
4
3
2
4
3
2
4
4

Example: Implementing pattern matching

An interesting usage of CPS functions is to implement our own pattern matching. We will illustrate how this can be done with some examples.

Example: Built-in pattern matching

check :: Bool -> String
check b = case b of
    True  -> "It's True"
    False -> "It's False"

Now that we have learnt CPS, we can refactor the code like this:

Example: Pattern matching in CPS

type BoolCPS r = r -> r -> r

true :: BoolCPS r
true x _ = x

false :: BoolCPS r
false _ x = x

check :: BoolCPS String -> String
check b = b "It's True" "It's False"
*Main> check true
"It's True"
*Main> check false
"It's False"

What happens here is that, instead of plain values, we represent True and False by functions that would choose either the first or second argument they are passed. Since true and false behave differently, we can achieve the same effect as pattern matching. Furthermore, True, False and true, false can be converted back and forth by \b -> b True False and \b -> if b then true else false.
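
The two conversions can be checked directly. notCPS is a hypothetical extra showing that negation amounts to swapping which continuation is chosen:

```haskell
type BoolCPS r = r -> r -> r

true :: BoolCPS r
true x _ = x

false :: BoolCPS r
false _ x = x

-- The two conversions from the text.
toBool :: BoolCPS Bool -> Bool
toBool b = b True False

fromBool :: Bool -> BoolCPS r
fromBool b = if b then true else false

-- Hypothetical extra: negation swaps the two continuations.
notCPS :: BoolCPS r -> BoolCPS r
notCPS b x y = b y x
```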

We can see how this relates to CPS in the following, more complicated example.

Example: More complicated pattern matching and its CPS equivalence

data Foobar = Zero | One Int | Two Int Int

type FoobarCPS r = r -> (Int -> r) -> (Int -> Int -> r) -> r

zero :: FoobarCPS r
zero x _ _ = x

one :: Int -> FoobarCPS r
one x _ f _ = f x

two :: Int -> Int -> FoobarCPS r
two x y _ _ f = f x y

fun :: Foobar -> Int
fun x = case x of
    Zero -> 0
    One a -> a + 1
    Two a b -> a + b + 2

funCPS :: FoobarCPS Int -> Int
funCPS x = x 0 (+1) (\a b -> a + b + 2)
*Main> fun Zero
0
*Main> fun $ One 3
4
*Main> fun $ Two 3 4
9
*Main> funCPS zero
0
*Main> funCPS $ one 3
4
*Main> funCPS $ two 3 4
9

Similar to the former example, we represent values by functions. These function-values pick the matching continuation among those they are passed and hand it the values they store. An interesting thing is that this process involves no comparison. As we know, pattern matching can work on types that are not instances of Eq: the function-values "know" what their patterns are and would automatically pick the right continuations. If this were done from outside, say, by a pattern_match :: [(pattern, result)] -> value -> result function, it would have to inspect and compare the patterns and the values to see if they match, and thus would need Eq instances.
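
The sketch below makes the correspondence explicit with conversions in both directions (toCPS and fromCPS are our own names, mirroring the Bool case). It also defines a hypothetical CPS Maybe, whose two continuations are exactly the first two arguments of the Prelude's maybe function, and stores a function in it, a type with no Eq instance, to show that matching never needs comparison. Eq is derived for Foobar only so that the round trip can be tested:

```haskell
data Foobar = Zero | One Int | Two Int Int
  deriving (Show, Eq)  -- Eq is only for testing; matching below never uses it

type FoobarCPS r = r -> (Int -> r) -> (Int -> Int -> r) -> r

-- Hypothetical conversions in both directions, analogous to the Bool case.
toCPS :: Foobar -> FoobarCPS r
toCPS Zero      z _ _ = z
toCPS (One a)   _ o _ = o a
toCPS (Two a b) _ _ t = t a b

fromCPS :: FoobarCPS Foobar -> Foobar
fromCPS x = x Zero One Two

-- A CPS Maybe: the first continuation handles "nothing", the second "just".
type MaybeCPS a r = r -> (a -> r) -> r

nothingCPS :: MaybeCPS a r
nothingCPS n _ = n

justCPS :: a -> MaybeCPS a r
justCPS v _ j = j v

-- Functions have no Eq instance, yet we can still "match" on a stored function.
applyOr :: Int -> MaybeCPS (Int -> Int) Int -> Int
applyOr d m = m d (\f -> f d)
```

applyOr 10 (justCPS (* 2)) evaluates to 20, and applyOr 10 nothingCPS to 10.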

Notes

  1. That is, \x -> ($ x), fully spelled out as \x -> \k -> k x
  2. attoparsec is an example of performance-driven usage of CPS.
  3. Beyond verifying that the monad laws hold, which is left as an exercise to the reader.
  4. Found in the mtl package, module Control.Monad.Cont.