Haskell/Continuation passing style


Continuation passing style (CPS) is a way of writing expressions such that no function ever returns; instead, each function passes control on to a continuation. Conceptually, a continuation is "what happens next": for example, the continuation for x in (x+1)*2 is "add one, then multiply by two".
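
As a minimal sketch of this idea, the continuation of x in (x+1)*2 can be written down as an ordinary function. The names rest, x_cps and result are purely illustrative, and the value 3 for x is an assumption made for the example.

```haskell
-- The "rest of the computation" that is waiting for the value of x
-- in the expression (x + 1) * 2: add one, then multiply by two.
rest :: Int -> Int
rest v = (v + 1) * 2

-- A CPS-style "x": instead of returning 3, it hands 3 to whatever
-- continuation it is given.
x_cps :: (Int -> r) -> r
x_cps k = k 3

-- Supplying the continuation runs the whole expression: (3 + 1) * 2.
result :: Int
result = x_cps rest
```

Here result evaluates to 8, the same as (3+1)*2 written directly.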

When do you need this?

Classically, continuations have been used primarily to dramatically alter the control flow of a program. For example, returning early from a procedure can be implemented with continuations. Exceptions and failure can also be signaled using continuations: pass in one continuation for success and another for failure, and simply invoke the appropriate one. If your function-that-can-fail itself calls another function-that-can-fail, you create a new success continuation, but pass along the failure continuation you were given.
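
The success/failure scheme just described might look like the following sketch. The functions safeDiv_cps and addQuotients_cps are hypothetical names invented for this illustration; they are not part of any library.

```haskell
-- Division that never "returns": it calls failK on a zero divisor
-- and okK with the quotient otherwise.
safeDiv_cps :: Int -> Int -> (String -> r) -> (Int -> r) -> r
safeDiv_cps _ 0 failK _   = failK "division by zero"
safeDiv_cps x y _     okK = okK (x `div` y)

-- A function-that-can-fail calling another function-that-can-fail:
-- it builds a new success continuation for each step, but passes
-- along the failure continuation it received unchanged.
addQuotients_cps :: Int -> Int -> Int -> (String -> r) -> (Int -> r) -> r
addQuotients_cps a b d failK okK =
  safeDiv_cps a d failK $ \q1 ->
  safeDiv_cps b d failK $ \q2 ->
  okK (q1 + q2)
```

For instance, addQuotients_cps 10 20 5 Left Right gives Right 6, while addQuotients_cps 10 20 0 Left Right gives Left "division by zero": the caller decides what success and failure mean by choosing the two continuations.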

Also, continuations allow you to "suspend" a computation, returning to that computation at another time. They can be used to implement simple forms of concurrency (notably, one Haskell implementation, Hugs, uses continuations to implement cooperative concurrency). Other control flows are possible.

In Haskell, continuations can be used in a similar fashion, for implementing interesting behavior in some monads. (Note that other techniques are usually available for this too, especially in tandem with laziness.) They can also be used to improve performance by eliminating certain construction/pattern-matching sequences (that is, cases where a function returns a complex structure which the caller will at some point deconstruct), though a sufficiently smart compiler "should" be able to eliminate this.



Starting simple

To begin with, we're going to explore two simple examples which illustrate what CPS and continuations are: first a 'first order' example (meaning there are no higher order functions to CPS transform), then a higher order one.

square

Example: A simple module, no continuations

-- We assume some primitives add and square for the example:

add :: Int -> Int -> Int
add x y = x + y

square :: Int -> Int
square x = x * x

pythagoras :: Int -> Int -> Int
pythagoras x y = add (square x) (square y)

The same function pythagoras, written in CPS, looks like this:

Example: A simple module, using continuations

-- We assume CPS versions of the add and square primitives,
-- (note: the actual definitions of add_cps and square_cps are not
-- in CPS form, they just have the correct type)

add_cps :: Int -> Int -> (Int -> r) -> r
add_cps x y k = k (add x y)

square_cps :: Int -> (Int -> r) -> r
square_cps x k = k (square x)

pythagoras_cps :: Int -> Int -> (Int -> r) -> r
pythagoras_cps x y k =
 square_cps x $ \x_squared ->
 square_cps y $ \y_squared ->
 add_cps x_squared y_squared $ \sum_of_squares ->
 k sum_of_squares

The pythagoras_cps example operates as follows:

  1. square x and throw the result into the (\x_squared -> ...) continuation
  2. square y and throw the result into the (\y_squared -> ...) continuation
  3. add x_squared and y_squared and throw the result into the (\sum_of_squares -> ...) continuation
  4. throw the sum_of_squares into the toplevel/program continuation

And one can try it out:

*Main> pythagoras_cps 3 4 print
25
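
The final continuation also determines the result type r. The sketch below repeats the CPS definitions (with the primitives inlined so that it stands alone) and runs pythagoras_cps with three different continuations; the names asInt, asString and isLarge are illustrative only.

```haskell
add_cps :: Int -> Int -> (Int -> r) -> r
add_cps x y k = k (x + y)

square_cps :: Int -> (Int -> r) -> r
square_cps x k = k (x * x)

pythagoras_cps :: Int -> Int -> (Int -> r) -> r
pythagoras_cps x y k =
  square_cps x $ \x_squared ->
  square_cps y $ \y_squared ->
  add_cps x_squared y_squared $ \sum_of_squares ->
  k sum_of_squares

-- The caller picks r by choosing the continuation.
asInt :: Int
asInt = pythagoras_cps 3 4 id         -- r = Int

asString :: String
asString = pythagoras_cps 3 4 show    -- r = String

isLarge :: Bool
isLarge = pythagoras_cps 3 4 (> 20)   -- r = Bool
```

These evaluate to 25, "25" and True respectively. Passing print, as in the session above, instantiates r to IO ().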

thrice

Example: A simple higher order function, no continuations

thrice :: (o -> o) -> o -> o
thrice f x = f (f (f x))

*Main> thrice tail "foobar"
"bar"

The first step in CPS-converting thrice is to work out the type of the CPS form. We can see that f :: o -> o, so in the CPS version f_cps :: o -> (o -> r) -> r, and the whole type will be thrice_cps :: (o -> (o -> r) -> r) -> o -> (o -> r) -> r. Once we have the new type, it helps direct how to write the function.

Example: A simple higher order function, with continuations

thrice_cps :: (o -> (o -> r) -> r) -> o -> (o -> r) -> r
thrice_cps f_cps x k =
 f_cps x $ \fx ->
 f_cps fx $ \ffx ->
 f_cps ffx $ \fffx ->
 k fffx
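
To try thrice_cps out, we need a CPS function to pass to it. The following sketch uses a hypothetical helper, tail_cps, which is not a library function; it is written with drop 1 so that it is total, which behaves like tail on non-empty lists.

```haskell
thrice_cps :: (o -> (o -> r) -> r) -> o -> (o -> r) -> r
thrice_cps f_cps x k =
  f_cps x   $ \fx ->
  f_cps fx  $ \ffx ->
  f_cps ffx $ \fffx ->
  k fffx

-- A CPS counterpart of tail (hypothetical helper for this example).
tail_cps :: [a] -> ([a] -> r) -> r
tail_cps xs k = k (drop 1 xs)

-- The CPS analogue of: thrice tail "foobar"
example :: String
example = thrice_cps tail_cps "foobar" id
```

As with the direct-style version, example evaluates to "bar".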

Using the Cont monad

By now, you should be used to the (meta-)pattern that whenever we find a pattern we like (here, using continuations) that makes our code a little ugly, we use a monad to encapsulate the 'plumbing'. Indeed, there is a monad for modelling computations which use CPS.

Example: The Cont monad

newtype Cont r a = Cont { runCont :: (a -> r) -> r }

Removing the newtype and record cruft, we see that Cont r a expands to (a -> r) -> r. How does this fit with the idea of continuations presented above? Remember that a function in CPS takes an extra parameter which represents 'what to do next'. A value of type Cont r a is a computation still waiting for that extra parameter: the continuation, a function from things of type a (what the result would have been, had we returned it normally instead of throwing it into the continuation) to things of type r, which becomes the final result type. (In current versions of the mtl and transformers libraries, Cont r a is a synonym for ContT r Identity a, and values are built with the function cont rather than a Cont data constructor; the simplified definition above behaves the same way.)

Example: The pythagoras example, using the Cont monad

import Control.Monad.Cont

add_cont :: Int -> Int -> Cont r Int
add_cont x y = return (add x y)

square_cont :: Int -> Cont r Int
square_cont x = return (square x)

pythagoras_cont :: Int -> Int -> Cont r Int
pythagoras_cont x y =
    do x_squared <- square_cont x
       y_squared <- square_cont y
       sum_of_squares <- add_cont x_squared y_squared
       return sum_of_squares

*Main> runCont (pythagoras_cont 3 4) print
25

Every function that returns a Cont-value actually takes an extra parameter, which is the continuation. Using return simply throws its argument into the continuation.

How does the Cont implementation of (>>=) work, then? It's easiest to see it at work:

Example: The (>>=) function for the Cont monad

square_C :: Int -> Cont r Int
square_C x = return (x ^ 2)

addThree_C :: Int -> Cont r Int
addThree_C x = return (x + 3)

main = runCont (square_C 4 >>= addThree_C) print
{- Result: 19 -}

The Monad instance for (Cont r) is given below:

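-- (GHC 7.10 and later also require Functor and Applicative instances,
-- omitted here for brevity)
instance Monad (Cont r) where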
  return n = Cont (\k -> k n)
  m >>= f  = Cont (\k -> runCont m (\a -> runCont (f a) k))

So return n is a Cont-value that throws n straight into whatever continuation it is applied to. m >>= f is a Cont-value that runs m with the continuation \a -> runCont (f a) k. If m invokes this continuation, the value it passes is bound to a, and f is applied to that value to get another Cont-value. This is then run with the continuation we got at the top level (bound to k). In essence, m >>= f is a Cont-value that takes the result from m, applies f to it, then throws the outcome into the continuation.
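
The following self-contained sketch re-declares the simplified Cont newtype locally (importing no continuation library) so that the instance above can be compiled and unfolded by hand. The Functor and Applicative instances are additions required by current GHC and are not part of the original presentation; nineteen is an illustrative name.

```haskell
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f m = Cont (\k -> runCont m (k . f))

instance Applicative (Cont r) where
  pure n    = Cont (\k -> k n)
  mf <*> mx = Cont (\k -> runCont mf (\f -> runCont mx (k . f)))

instance Monad (Cont r) where
  m >>= f = Cont (\k -> runCont m (\a -> runCont (f a) k))

square_C :: Int -> Cont r Int
square_C x = return (x ^ 2)

addThree_C :: Int -> Cont r Int
addThree_C x = return (x + 3)

-- Unfolding (>>=) by hand:
--   runCont (square_C 4 >>= addThree_C) k
-- = runCont (square_C 4) (\a -> runCont (addThree_C a) k)
-- = (\a -> runCont (addThree_C a) k) 16
-- = runCont (addThree_C 16) k
-- = k 19
nineteen :: Int
nineteen = runCont (square_C 4 >>= addThree_C) id
```

With id as the final continuation, nineteen evaluates to 19, matching the result printed by the example above.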


callCC

By now you should be fairly confident using the basic notions of continuations and Cont, so we're going to skip ahead to the next big concept in continuation-land. This is a function called callCC, which is short for "call with current continuation". We'll start with an easy example.

Example: square using callCC

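-- Two alternative definitions of the same function
-- (only one of them can be in scope at a time).

-- Without callCC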
square :: Int -> Cont r Int
square n = return (n ^ 2)

-- With callCC
square :: Int -> Cont r Int
square n = callCC $ \k -> k (n ^ 2)

We pass a function to callCC that accepts one parameter, which is in turn a function. This function (k in our example) is our tangible continuation: we can see here that we're throwing a value (in this case, n ^ 2) into our continuation. The callCC version is equivalent to the return version above, because return n is just a Cont-value that throws n into whatever continuation it is given. Here, we use callCC to bring the continuation 'into scope', and immediately throw a value into it, just like using return.

However, these versions look remarkably similar, so why should we bother using callCC at all? The power lies in that we now have precise control of exactly when we call our continuation, and with what values. Let's explore some of the surprising power that gives us.

Deciding when to use k

We mentioned above that the point of using callCC in the first place was that it gave us extra power over what we threw into our continuation, and when. The following example shows how we might want to use this extra flexibility.

Example: Our first proper callCC function

import Control.Monad (when)
import Control.Monad.Cont

foo :: Int -> Cont r String
foo n =
  callCC $ \k -> do
    let n' = n ^ 2 + 3
    when (n' > 20) $ k "over twenty"
    return (show $ n' - 4)

foo is a slightly pathological function that computes the square of its input and adds three; if the result of this computation is greater than 20, then we return from the function immediately, throwing the String value "over twenty" into the continuation that is passed to foo. If not, then we subtract four from our previous computation, show it, and throw it into the continuation. If you're used to imperative languages, you can think of k like the 'return' statement that immediately exits the function. Of course, the advantage of an expressive language like Haskell is that k is just an ordinary first-class function, so you can pass it to other functions like when, or store it in a Reader, etc.
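
To illustrate that k really is an ordinary first-class value, here is a sketch in which the escape function is handed to a separate helper. The names checkLimit, foo', small and large are hypothetical, and the code assumes the mtl package's Control.Monad.Cont module is available.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (Cont, callCC, runCont)

-- Hypothetical helper: it receives the escape function as a plain
-- argument and decides whether to call it.
checkLimit :: (String -> Cont r ()) -> Int -> Cont r ()
checkLimit exit v = when (v > 20) $ exit "over twenty"

-- Same behaviour as foo, but the early exit is triggered inside checkLimit.
foo' :: Int -> Cont r String
foo' n = callCC $ \k -> do
  let n' = n ^ 2 + 3
  checkLimit k n'
  return (show $ n' - 4)

small, large :: String
small = runCont (foo' 2) id   -- 2^2 + 3 = 7, not over 20, so show (7 - 4)
large = runCont (foo' 5) id   -- 5^2 + 3 = 28, over 20, so the exit fires
```

Here small is "3" and large is "over twenty": calling exit inside checkLimit still jumps out of the callCC block in foo'.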

Naturally, you can embed calls to callCC within do-blocks:

Example: More developed callCC example involving a do-block

bar :: Char -> String -> Cont r Int
bar c s = do
  msg <- callCC $ \k -> do
    let s' = c : s
    when (s' == "hello") $ k "They say hello."
    let s'' = show s'
    return ("They appear to be saying " ++ s'')
  return (length msg)

When you call k with a value, the entire callCC call takes that value. In other words, k is a bit like a 'goto' statement in other languages: when we call k in our example, it pops the execution out to where you first called callCC, the msg <- callCC $ ... line. No more of the argument to callCC (the inner do-block) is executed. Hence the following example contains a useless line:

Example: Popping out of a function, introducing a useless line

bar :: Cont r Int
bar = callCC $ \k -> do
  let n = 5
  k n
  return 25

bar will always return 5, and never 25, because we pop out of bar before getting to the return 25 line.

A note on typing

Why do we exit using return rather than k the second time within the foo example? It has to do with types. First, we need to think about the type of k. We mentioned that we can throw something into k, and nothing after that call will get run (unless k is run conditionally, as when it is wrapped in a when). So the result type of k doesn't matter; we can never do anything with the result of running k. More precisely, the Cont-value returned by k never invokes the continuation it is given. We say, therefore, that the type of k is:

k :: a -> Cont r b

In Cont r b, the type b is the argument type of the continuation that the Cont-value returned by k would normally call. Because k never calls that continuation, b can be any type at all, independent of a. This is advantageous because it lets us use k in whatever context we like, regardless of the result type that context expects. In our code above, we use it as part of a when construct:

when :: Monad m => Bool -> m () -> m ()

As soon as the compiler sees k being used in this when, it infers that the Cont-value returned by k has type Cont r (), that is, b is (). This type b is independent of the argument type a of k. [1] The Cont-value returned by k does not use the continuation that it is itself given; it uses the continuation given to the Cont-value that callCC returns. That is why the callCC expression can have type Cont r String even though k returns Cont r (). Since the type of k's result has been fixed to Cont r (), ending the block with k (show $ n' - 4) would give the block the type Cont r () rather than the required Cont r String; return has no such restriction. The final expression in the inner do-block has type Cont r String, so the inner do-block as a whole has that type too. There are two possible execution routes. If the condition for the when succeeds, k bypasses the continuation supplied by the rest of the inner do-block and directly invokes the continuation given to the callCC, so the remaining expressions in the do-block are never used (and, because Haskell is lazy, never evaluated). If the condition fails, the when evaluates to return (), which does use the continuation supplied by the inner do-block, so execution passes on.

If you didn't follow any of that, just make sure you use return at the end of a do-block inside a call to callCC, not k.

The type of callCC

We've deliberately broken a trend here: normally when we've introduced a function, we've given its type straight away, but in this case we haven't. The reason is simple: the type is rather complex, and it doesn't immediately give insight into what the function does or how it works. Nevertheless, you should be familiar with it, so now that you have hopefully understood the function itself, here is its type:

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

This seems like a really weird type to begin with, so let's use a contrived example.

callCC $ \k -> k 5

You pass a function to callCC. This in turn takes a parameter, k, which is another function. k, as we remarked above, has the type:

k :: a -> Cont r b

The entire argument to callCC, then, is a function that takes something of the above type and returns Cont r a, where a is the type of the argument to k. So, callCC's argument has type:

(a -> Cont r b) -> Cont r a

Finally, callCC is therefore a function which takes that argument and returns its result. So the type of callCC is:

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

The implementation of callCC

So far we have looked at the use of callCC and its type. This just leaves its implementation, which is:

 callCC f = Cont $ \k -> runCont (f (\a -> Cont $ \_ -> k a)) k

This code is far from obvious. However, the amazing fact is that the implementations for callCC f, return n and m >>= f can all be produced automatically from their type signatures - Lennart Augustsson's Djinn [1] is a program that will do this for you. See Phil Gossett's Google tech talk: [2] for background on the theory behind Djinn; and Dan Piponi's article: [3] which uses Djinn in deriving Continuation Passing Style.
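
Short of deriving the definition mechanically, one way to gain confidence in it is to unfold it by hand on small inputs. The sketch below re-declares the simplified Cont newtype locally (no continuation library is imported); viaExit, ignoringExit and discards are illustrative names.

```haskell
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
callCC f = Cont $ \k -> runCont (f (\a -> Cont $ \_ -> k a)) k

-- Unfolding the contrived example by hand:
--   callCC (\exit -> exit 5)
-- = Cont $ \k -> runCont ((\exit -> exit 5) (\a -> Cont $ \_ -> k a)) k
-- = Cont $ \k -> runCont (Cont $ \_ -> k 5) k
-- = Cont $ \k -> k 5            -- i.e. the same as return 5
viaExit :: Int
viaExit = runCont (callCC (\exit -> exit 5)) (+ 1)

-- If the body ignores the escape function, callCC changes nothing:
-- the body simply runs with the outer continuation k.
ignoringExit :: Int
ignoringExit = runCont (callCC (\_ -> Cont (\k -> k 7))) (+ 1)

-- The escape function discards the continuation it is given (the \_)
-- and jumps to the outer one, so undefined is never forced here.
discards :: Int
discards =
  runCont (callCC (\exit -> Cont (\_ -> runCont (exit 1) undefined))) (* 10)
```

These evaluate to 6, 8 and 10 respectively. The third case shows the role of the \_ in the definition: whatever continuation the escape function is handed locally is thrown away in favour of the one captured by callCC.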


Example: a complicated control structure

This example was originally taken from the 'The Continuation monad' section of the All about monads tutorial, used with permission.

Example: Using Cont for a complicated control structure

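import Control.Monad (when)
import Control.Monad.Cont (callCC, runCont)
import Data.Char (digitToInt, intToDigit)

{- We use the continuation monad to perform "escapes" from code blocks.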
   This function implements a complicated control structure to process
   numbers:

   Input (n)       Output                       List Shown
   =========       ======                       ==========
   0-9             n                            none
   10-199          number of digits in (n/2)    digits of (n/2)
   200-19999       n                            digits of (n/2)
   20000-1999999   (n/2) backwards              none
   >= 2000000      sum of digits of (n/2)       digits of (n/2)
-}  
fun :: Int -> String
fun n = (`runCont` id) $ do
        str <- callCC $ \exit1 -> do                        -- define "exit1"
          when (n < 10) (exit1 $ show n)
          let ns = map digitToInt (show $ n `div` 2)
          n' <- callCC $ \exit2 -> do                       -- define "exit2"
            when (length ns < 3) (exit2 $ length ns)
            when (length ns < 5) (exit2 n)
            when (length ns < 7) $ do 
              let ns' = map intToDigit (reverse ns)
              exit1 (dropWhile (=='0') ns')                 -- escape 2 levels
            return $ sum ns
          return $ "(ns = " ++ show ns ++ ") " ++ show n'
        return $ "Answer: " ++ str

Because it isn't initially clear what's going on, especially regarding the usage of callCC, we will explore this somewhat.

Analysis of the example

Firstly, we can see that fun is a function that takes an integer n. We basically implement a control structure using Cont and callCC that does different things based on the range that n falls in, as explained with the comment at the top of the function. Let's dive into the analysis of how it works.

  1. Firstly, the (`runCont` id) at the top just means that we run the Cont block that follows with a final continuation of id. This is necessary as the result type of fun doesn't mention Cont.
  2. We bind str to the result of the following callCC do-block:
    1. If n is less than 10, we exit straight away, just showing n.
    2. If not, we proceed. We construct a list, ns, of digits of n `div` 2.
    3. n' (an Int) gets bound to the result of the following inner callCC do-block.
      1. If length ns < 3, i.e., if n `div` 2 has less than 3 digits, we pop out of this inner do-block with the number of digits as the result.
      2. If n `div` 2 has less than 5 digits, we pop out of the inner do-block returning the original n.
      3. If n `div` 2 has less than 7 digits, we pop out of both the inner and outer do-blocks, with the result of the digits of n `div` 2 in reverse order (a String).
      4. Otherwise, we end the inner do-block, returning the sum of the digits of n `div` 2.
    4. We end this do-block, returning the String "(ns = X) Y", where X is ns, the digits of n `div` 2, and Y is the result from the inner do-block, n'.
  3. Finally, we return out of the entire function, with our result being the string "Answer: Z", where Z is the string we got from the callCC do-block.
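
The essential trick in fun, two nested callCC blocks where the inner block can escape from either level, can be seen in isolation in this smaller sketch. The function classify and its thresholds are hypothetical and not part of the original example; the code assumes the mtl package's Control.Monad.Cont module is available.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (callCC, runCont)

-- Hypothetical example: classify a number using two nested escape points.
classify :: Int -> String
classify n = (`runCont` id) $ do
  str <- callCC $ \exitOuter -> do
    when (n < 0) $ exitOuter "negative"          -- leaves the outer block
    size <- callCC $ \exitInner -> do
      when (n < 10) $ exitInner "small"          -- leaves only the inner block
      when (n > 1000) $ exitOuter "too large"    -- leaves both blocks at once
      return "medium"
    return ("size: " ++ size)
  return ("Answer: " ++ str)
```

For example, classify 5 gives "Answer: size: small" (the inner exit still passes through the "size: " line), whereas classify 5000 gives "Answer: too large" (the outer exit skips that line entirely, just as exit1 skips the "(ns = ...)" line in fun).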

Example: exceptions

One use of continuations is to model exceptions. To do this, we hold on to two continuations: one that takes us out to the handler in case of an exception, and one that takes us to the post-handler code in case of a success. Here's a simple function that takes two numbers and does integer division on them, failing when the denominator is zero.

Example: An exception-throwing div

 divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
 divExcpt x y handler =
   callCC $ \ok -> do
     err <- callCC $ \notOk -> do
       when (y == 0) $ notOk "Denominator 0"
       ok $ x `div` y
     handler err
 
 {- For example,
 runCont (divExcpt 10 2 error) id  -->  5
 runCont (divExcpt 10 0 error) id  -->  *** Exception: Denominator 0
 -}

How does it work? We use two nested calls to callCC. The first labels a continuation that will be used when there's no problem. The second labels a continuation that will be used when we wish to throw an exception. If the denominator isn't 0, x `div` y is thrown into the ok continuation, so the execution pops right back out to the top level of divExcpt. If, however, we were passed a zero denominator, we throw an error message into the notOk continuation, which pops us out of the inner do-block; that string gets bound to err and is given to handler.
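
The handler does not have to abort the program as error does. In the following sketch, divExcpt is repeated so that the block stands alone, and a hypothetical handler orZero recovers by substituting a default value. The code assumes the mtl package's Control.Monad.Cont module is available.

```haskell
import Control.Monad (when)
import Control.Monad.Cont (Cont, callCC, runCont)

divExcpt :: Int -> Int -> (String -> Cont r Int) -> Cont r Int
divExcpt x y handler =
  callCC $ \ok -> do
    err <- callCC $ \notOk -> do
      when (y == 0) $ notOk "Denominator 0"
      ok $ x `div` y
    handler err

-- Hypothetical handler: ignore the message and substitute a default.
orZero :: String -> Cont r Int
orZero _ = return 0

fine, recovered :: Int
fine      = runCont (divExcpt 10 2 orZero) id   -- handler is never called
recovered = runCont (divExcpt 10 0 orZero) id   -- handler supplies the result
```

Here fine is 5 and recovered is 0.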

A more general approach to handling exceptions can be seen with the following function. Pass a computation as the first parameter (it should be a function that takes the continuation leading to the error handler) and an error handler as the second parameter. This example takes advantage of the generic MonadCont class, which covers both Cont and ContT by default, plus any other continuation monads for which the user has defined an instance.

Example: General try using continuations.

 tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
 tryCont c h =
   callCC $ \ok -> do
     err <- callCC $ \notOk -> do x <- c notOk; ok x
     h err

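For an example using tryCont, see the following program.

Example: Using tryCont

-- Assumes tryCont from the previous example is in scope.
import Control.Monad (when)
import Control.Monad.Cont (ContT, runContT)
import Control.Monad.Trans.Class (lift)

data SqrtException = LessThanZero deriving (Show, Eq)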

sqrtIO :: (SqrtException -> ContT r IO ()) -> ContT r IO ()
sqrtIO throw = do 
  ln <- lift (putStr "Enter a number to sqrt: " >> readLn)
  when (ln < 0) (throw LessThanZero)
  lift $ print (sqrt ln)

main = runContT (tryCont sqrtIO (lift . print)) return
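
Since tryCont works for any MonadCont instance, it can also be used without IO. The sketch below repeats tryCont so that it stands alone and applies it in the plain Cont monad; safeSqrt and recover are hypothetical names, and the sentinel value -1 is an arbitrary choice for the illustration. The code assumes the mtl package's Control.Monad.Cont module is available.

```haskell
import Control.Monad.Cont (Cont, MonadCont, callCC, runCont)

tryCont :: MonadCont m => ((err -> m a) -> m a) -> (err -> m a) -> m a
tryCont c h =
  callCC $ \ok -> do
    err <- callCC $ \notOk -> do x <- c notOk; ok x
    h err

-- Hypothetical computation: a square root that "throws" on negative input.
safeSqrt :: Double -> (String -> Cont r Double) -> Cont r Double
safeSqrt x throw
  | x < 0     = throw "negative input"
  | otherwise = return (sqrt x)

-- Hypothetical handler: recover with a sentinel value instead of failing.
recover :: String -> Cont r Double
recover _ = return (-1)

good, bad :: Double
good = runCont (tryCont (safeSqrt 16) recover) id
bad  = runCont (tryCont (safeSqrt (-4)) recover) id
```

Here good is 4.0, and bad is -1.0 because the handler's result replaces that of the failed computation.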

Example: coroutines



Notes

  1. The type of k is monomorphic within a given callCC block because k is bound by a lambda expression, and things bound by lambdas always have monomorphic types. See Polymorphism.
Last modified on 21 May 2012, at 07:35