Joachim Breitner

Don’t think, just defunctionalize

Published 2020-12-22 in sections English, Haskell.

TL;DR: CPS-conversion and defunctionalization can help you to come up with a constant-stack algorithm.

Update: Turns out I inadvertently plagiarized the talk The Best Refactoring You’ve Never Heard Of by James Koppel. Please consider this a form of sincere flattery.

The starting point

Today, I’ll take you on another little walk through the land of program transformations. Let’s begin with a simple binary tree, with values of unknown type in the leaves, as well as the canonical map function:

{-# LANGUAGE ScopedTypeVariables #-}
data T a = L a | B (T a) (T a)

map1 :: (a -> b) -> T a -> T b
map1 f (L x) = L (f x)
map1 f (B t1 t2) = B (map1 f t1) (map1 f t2)

As you can see, this map function is using the program stack as it traverses the tree. Our goal is now to come up with a map function that does not use the stack!

Why? Good question! In Haskell, there wouldn’t be a strong need for this, as the Haskell stack is allocated on the heap, just like your normal data, so there is plenty of stack space. But in other languages or environments, the stack space may have a hard limit, and it may be advisable not to use unbounded stack space.

That aside, it’s a fun exercise, and that’s sufficient reason for me.

(In the following, I assume that tail-calls, i.e. calls where a function ends with another function call without modifying its result, do not actually use stack space. Once all recursive function calls are tail calls, the code is equivalent to an imperative loop, as we will see.)
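To make the distinction concrete, here is a small illustration of my own (sumRec and sumAcc are made-up names, not part of the derivation): sumRec still has to add x after the recursive call returns, so each pending addition needs stack space, whereas in sumAcc the recursive call is the last thing that happens and its result is returned unchanged, so it behaves like a loop whose only state is the accumulator.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Not a tail call: the addition happens after the recursive call
-- returns, so every element leaves a pending addition on the stack.
sumRec :: [Int] -> Int
sumRec []       = 0
sumRec (x : xs) = x + sumRec xs

-- Tail call: the recursive call is the last thing go does, and its
-- result is returned unchanged. The bang pattern forces the
-- accumulator so that no chain of thunks builds up instead.
sumAcc :: [Int] -> Int
sumAcc = go 0
  where
    go !acc []       = acc
    go !acc (x : xs) = go (acc + x) xs
```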

Think?

We could now just stare at the problem (or rather, the code) and try to come up with a solution directly. We’d probably think “ok, as I go through the tree, I have to remember all the nodes above me… so I need a list of those nodes… and for each of these nodes, I also need to remember whether I am currently processing the left child and still have to look at the right one, or whether I am done with the left child… so what do I have to remember about the current node…?”

… ah, my brain spins already. Maybe eventually I figure it out, but why think when we can derive the solution? So let’s start with the above map1 and rewrite it, in several mechanical steps, into a stack-less, tail-recursive solution.

Go!

Before we set out, let me rewrite the map function using a local go helper, as follows:

map2 :: forall a b. (a -> b) -> T a -> T b
map2 f t = go t
  where
    go :: T a -> T b
    go (L x) = L (f x)
    go (B t1 t2) = B (go t1) (go t2)

This transformation (effectively the “static argument transformation”) has the nice advantage that we do not have to pass f around all the time, and that when I copy the function, I only have to change the top-level name, but not the names of the inner functions.

Also, I find it more aesthetically pleasing.

CPS

A blunt but effective tool to turn code that is not yet using tail-calls into code that only uses tail-calls is to use continuation-passing style. If we have a function of type … -> t, we turn it into a function of type … -> (t -> r) -> r, where r is the type of the result we want at the very end. This means the function now receives an extra argument, often named k for continuation, and instead of returning some x, the function calls k x.
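As a warm-up, here is the same recipe applied to a function simpler than go, namely factorial (fact, factK and factCPS are illustrative names). The direct version has type Integer -> Integer, the CPS version Integer -> (Integer -> r) -> r, and the multiplication that used to wait for the recursive call moves into the continuation:

```haskell
-- Direct style: the multiplication happens after the recursive call returns.
fact :: Integer -> Integer
fact 0 = 1
fact n = n * fact (n - 1)

-- CPS: instead of returning a value, pass it to the continuation k.
-- The pending multiplication now lives inside the new continuation,
-- and every call is a tail call.
factK :: Integer -> (Integer -> r) -> r
factK 0 k = k 1
factK n k = factK (n - 1) (\r -> k (n * r))

-- Starting with the identity continuation gives back the plain result.
factCPS :: Integer -> Integer
factCPS n = factK n (\r -> r)
```

Passing the identity continuation, as factCPS does, recovers the ordinary result; passing another continuation (say show) post-processes it.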

We can apply this to our go function. Here, both t and r happen to be T b; the type of finished trees:

map3 :: forall a b. (a -> b) -> T a -> T b
map3 f t = go t (\r -> r)
  where
    go :: T a -> (T b -> T b) -> T b
    go (L x) k  = k (L (f x))
    go (B t1 t2) k  = go t1 (\r1 -> go t2 (\r2 -> k (B r1 r2)))

Note that when we initially call go, we pass the identity function (\r -> r) as the initial continuation.

And suddenly all function calls are in tail position, and this code does not use stack space! Technically, we are done, although it is not quite satisfying: all these lambdas floating around obscure the meaning of the code, are maybe a bit slow to execute, and also, we didn’t really learn much yet. This is certainly not the code we would have written after “thinking hard”.

Defunctionalization

So let’s continue rewriting the code to something prettier, simpler. Something that does not use lambdas like this.

Again, there is a mechanical technique that can help here. It likely won’t make the code prettier, but it will get rid of the lambdas, so let’s do that and clean up later.

The technique is called defunctionalization (because it replaces functional values by plain data values), and can be seen as a form of refinement.

Note that we pass around values of type (T b -> T b), but we certainly don’t mean the full type (T b -> T b). Instead, only very specific values of that type occur in our program. So let us replace (T b -> T b) with a data type that contains representatives of just the values we actually use.
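To see the idea in isolation first, consider a hypothetical toy program that only ever uses functions of the shape (+ n) and (* n) on Int. In this minimal sketch (Fun, Add, Mul, apply, runFuns and runFunsD are all made-up names), each function value that actually occurs gets a constructor storing its free variable n, and apply plays the role that eval plays in the steps below:

```haskell
-- Higher-order version: a pipeline of arbitrary Int -> Int functions.
runFuns :: [Int -> Int] -> Int -> Int
runFuns []       x = x
runFuns (g : gs) x = runFuns gs (g x)

-- Defunctionalized version: one constructor per kind of function value
-- that occurs in the program, carrying that value's free variable.
data Fun = Add Int | Mul Int
  deriving Show

-- The interpretation function turns the plain data back into a function.
apply :: Fun -> Int -> Int
apply (Add n) x = x + n
apply (Mul n) x = x * n

runFunsD :: [Fun] -> Int -> Int
runFunsD []       x = x
runFunsD (g : gs) x = runFunsD gs (apply g x)
```

Both runFuns [(+ 1), (* 2)] 3 and runFunsD [Add 1, Mul 2] 3 compute (3 + 1) * 2, but only the second one passes around plain, inspectable data.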

  1. We look at all values of type (T b -> T b). These are:

    • (\r -> r)
    • (\r1 -> go t2 (\r2 -> k (B r1 r2)))
    • (\r2 -> k (B r1 r2))
  2. We create a datatype with one constructor for each of these:

     data K = I | K1 | K2

    (This is not complete yet.)

  3. We introduce an interpretation function that turns a K back into a (T b -> T b):

    eval :: K -> (T b -> T b)
    eval = undefined -- TBD
  4. In the function go, instead of taking a parameter of type (T b -> T b), we take a K. And when we actually use the continuation, we have to turn the K back into the function using eval:

    go :: T a -> K -> T b
    go (L x) k  = eval k (L (f x))
    go (B t1 t2) k = go t1 K1

    We also do this to the code fragments identified in the first step; these become:

    • (\r -> r)
    • (\r1 -> go t2 K2)
    • (\r2 -> eval k (B r1 r2))
  5. Now we complete the eval function: For each constructor, we simply map it to the corresponding lambda from step 1:

    eval :: K -> (T b -> T b)
    eval I = (\r -> r)
    eval K1 = (\r1 -> go t2 K2)
    eval K2 = (\r2 -> eval k (B r1 r2))
  6. This doesn’t quite work yet: we have variables on the right-hand side that are not bound (t2, r1, k). So let’s add them to the constructors K1 and K2 as needed. This also changes the type K itself; it now needs to take type parameters.

This leads us to the following code:

data K a b
  = I
  | K1 (T a) (K a b)
  | K2 (T b) (K a b)

map4 :: forall a b. (a -> b) -> T a -> T b
map4 f t = go t I
  where
    go :: T a -> K a b -> T b
    go (L x) k  = eval k (L (f x))
    go (B t1 t2) k  = go t1 (K1 t2 k)

    eval :: K a b -> (T b -> T b)
    eval I = (\r -> r)
    eval (K1 t2 k) = (\r1 -> go t2 (K2 r1 k))
    eval (K2 r1 k) = (\r2 -> eval k (B r1 r2))

Not really cleaner or prettier, but everything is still tail-recursive, and we are now working with plain data.

We like lists

To clean it up a little bit, we can notice that the K data type really is just a list of values, where the values are either T a or T b. We do not need a custom data type for this! Instead of our K, we can just use the following, built from standard data types:

type K' a b = [Either (T a) (T b)]
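That K a b really is just such a list can be checked by writing the two conversions. In this sketch kToList and listToK are hypothetical helper names; they are mutually inverse, so switching from K a b to K' a b loses nothing:

```haskell
data T a = L a | B (T a) (T a)
  deriving (Show, Eq)

-- The defunctionalized continuation, as in map4.
data K a b
  = I
  | K1 (T a) (K a b)
  | K2 (T b) (K a b)
  deriving (Show, Eq)

type K' a b = [Either (T a) (T b)]

-- I is the empty list, K1 conses a Left, K2 conses a Right.
kToList :: K a b -> K' a b
kToList I         = []
kToList (K1 t2 k) = Left t2 : kToList k
kToList (K2 r1 k) = Right r1 : kToList k

-- The inverse direction.
listToK :: K' a b -> K a b
listToK []             = I
listToK (Left t2 : k)  = K1 t2 (listToK k)
listToK (Right r1 : k) = K2 r1 (listToK k)
```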

Now I replace I with [], K1 t2 k with Left t2 : k and K2 r1 k with Right r1 : k. I also, very suggestively, rename go to down and eval to up:

map5 :: forall a b. (a -> b) -> T a -> T b
map5 f t = down t []
  where
    down :: T a -> K' a b -> T b
    down (L x) k  = up k (L (f x))
    down (B t1 t2) k  = down t1 (Left t2 : k)

    up :: K' a b -> T b -> T b
    up [] r = r
    up (Left  t2 : k) r1 = down t2 (Right r1 : k)
    up (Right r1 : k) r2 = up k (B r1 r2)

At this point, the code suddenly makes more sense again. In fact, I can try to verbalize it:

As we traverse the tree, we have to remember for all parent nodes, whether there is still something Left to do when we come back to it (so we remember a T a), or if we are done with that (so we have a T b). This is the list K' a b.

We begin to go down the left of the tree (noting that the right siblings are still left to do), until we hit a leaf. We transform the leaf, and then go up.

If we go up and hit the root, we are done. Else, if we go up and there is something Left to do, we remember the subtree that we just processed (as that is already in the Right form), and go down the other subtree. But if we go up and there is nothing Left to do, we put the two subtrees together and continue going up.

Quite neat!
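To watch this happen, one can instrument map5 so that it also logs every down and up step. This is a debugging aid sketched for this purpose only (Step, Down, Up and map5Trace are made-up names); the extra bookkeeping means it is no longer tail-recursive itself:

```haskell
data T a = L a | B (T a) (T a)
  deriving (Show, Eq)

type K' a b = [Either (T a) (T b)]

-- One log entry per call of down or up in map5, with its arguments.
data Step a b
  = Down (T a) (K' a b)  -- about to descend into a subtree
  | Up (T b) (K' a b)    -- carrying a finished subtree upwards
  deriving Show

-- Same control flow as map5, but also returns the list of steps taken.
map5Trace :: (a -> b) -> T a -> (T b, [Step a b])
map5Trace f t0 = down t0 []
  where
    down s k = addStep (Down s k) $ case s of
      L x     -> up k (L (f x))
      B t1 t2 -> down t1 (Left t2 : k)
    up k r = addStep (Up r k) $ case k of
      []            -> (r, [])
      Left t2 : k'  -> down t2 (Right r : k')
      Right r1 : k' -> up k' (B r1 r)
    addStep st (r, steps) = (r, st : steps)
```

For map5Trace (+ 1) (B (L 1) (L 2)), the log contains six steps: two Down steps to reach the left leaf, an Up step carrying L 2, a Down step into the right leaf with Right (L 2) on the stack, and two Up steps that first carry L 3 and finally assemble B (L 2) (L 3) at the root.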

The imperative loop

At this point we could stop: the code is pretty, makes sense, and has the properties we want. But let’s turn the dial a bit further and try to make it an imperative loop.

We know that if we have a single tail-recursive function, then that’s equivalent to a loop, with the function’s parameters turning into mutable variables. But we have two functions!

It turns out that if you have two functions a -> r and b -> r that have the same return type (which they necessarily have here, since we CPS-converted them further up), then those two functions are equivalent to a single function taking “a or b”, i.e. Either a b -> r. This is really nothing else than the high-school algebra rule $r^a \cdot r^b = r^{a+b}$, reading the function type a -> r as $r^a$, a pair as a product and Either as a sum.
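Spelled out in Haskell, the equivalence is witnessed by two small functions (fuse and split are names made up for this sketch; fuse is just the Prelude’s either, uncurried). Applied to down and a version of up with its arguments flipped, taking r to be K' a b -> T b, fuse produces a function of exactly the type of go below:

```haskell
-- A pair of functions with a common result type becomes one function on Either...
fuse :: (a -> r, b -> r) -> (Either a b -> r)
fuse (g, h) = either g h

-- ...and back again, by pre-composing with the two constructors.
split :: (Either a b -> r) -> (a -> r, b -> r)
split k = (k . Left, k . Right)
```

The two directions are inverse to each other: split (fuse (g, h)) is (g, h) again, and fuse (split k) behaves like k on every Left and Right.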

So (after reordering the arguments of up to put the T b first) we can rewrite the code as

map6 :: forall a b. (a -> b) -> T a -> T b
map6 f t = go (Left t) []
  where
    go :: Either (T a) (T b) -> K' a b -> T b
    go (Left (L x))     k        = go (Right (L (f x))) k
    go (Left (B t1 t2)) k        = go (Left t1) (Left t2 : k)
    go (Right r)  []             = r
    go (Right r1) (Left  t2 : k) = go (Left t2) (Right r1 : k)
    go (Right r2) (Right r1 : k) = go (Right (B r1 r2)) k

Do you see the loop yet? If not, maybe it helps to compare it with the following equivalent imperative-looking pseudocode:

mapLoop :: forall a b. (a -> b) -> T a -> T b
mapLoop f t {
  var node = Left t;
  var parents = [];
  while (true) {
    switch (node) {
      Left (L x) -> node := Right (L (f x))
      Left (B t1 t2) -> node := Left t1; parents.push(Left t2)
      Right r -> {
        if (parents.len() == 0) {
          return r;
        } else {
          switch (parents.pop()) {
            Left t2  -> node := Left t2; parents.push(Right r);
            Right r1 -> node := Right (B r1 r)
          }
        }
      }
    }
  }
}
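For readers who want to run the loop, here is a sketch that transcribes it into Haskell’s ST monad, with node and parents as STRefs and loop playing the role of the while (true) loop. The name mapLoopST is made up, and this is one possible transcription, not necessarily the most efficient one. Note that a finished subtree popped from parents is always the left sibling, so it goes first in B:

```haskell
import Control.Monad.ST (runST)
import Data.STRef (newSTRef, readSTRef, writeSTRef, modifySTRef')

data T a = L a | B (T a) (T a)
  deriving (Show, Eq)

mapLoopST :: (a -> b) -> T a -> T b
mapLoopST f t = runST $ do
  node    <- newSTRef (Left t)  -- Left: still to do, Right: done
  parents <- newSTRef []        -- the explicit stack
  let loop = do
        n <- readSTRef node
        case n of
          -- A leaf: transform it and start going up.
          Left (L x) -> writeSTRef node (Right (L (f x))) >> loop
          -- An inner node: go left, remember the right subtree.
          Left (B t1 t2) -> do
            writeSTRef node (Left t1)
            modifySTRef' parents (Left t2 :)
            loop
          -- A finished subtree: look at the parent entry.
          Right r -> do
            ps <- readSTRef parents
            case ps of
              [] -> return r  -- back at the root
              Left t2 : rest -> do
                -- r is a left result: keep it, process the right sibling.
                writeSTRef node (Left t2)
                writeSTRef parents (Right r : rest)
                loop
              Right l : rest -> do
                -- r is a right result: combine it with the stored left one.
                writeSTRef node (Right (B l r))
                writeSTRef parents rest
                loop
  loop
```

Each branch either returns or ends by re-entering loop, which is the Haskell counterpart of falling through to the next iteration of the while loop.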

Conclusion

I find it enlightening to see how apparently very different approaches to a problem (recursive, lazy functions and imperative loops) are connected by a series of rather mechanical transformations. When refactoring code, it is helpful to see if one can conceptualize the refactoring as one of those mechanical steps (refinement, type equivalences, defunctionalization, CPS conversion, etc.).

If you liked this post, you might enjoy my talk The many faces of isOrderedTree, which I have presented at MuniHac 2019 and Haskell Love 2020.

Comments

You probably are already familiar with this, but this idea goes back to “Definitional Interpreters for Higher-Order Languages” by Reynolds. His purpose was to avoid the defined interpreter depending on details of the language it's defined within, e.g. evaluation order.

It was then taken up by Olivier Danvy, Dariusz Biernacki and others in the dual papers “A Functional Correspondence between Evaluators and Abstract Machines” and “From Interpreter to Compiler and Virtual Machine: a Functional Derivation”.

There were several papers using the same idea in different contexts throughout the years afterward which I collected here: http://lambda-the-ultimate.org/node/2423#comment-38384

I use this technique in a similar way to explain a GCC optimization, and here I go through an almost identical derivation as yours for the same reason. “Defunctionalization at Work” also by Danvy (and Nielsen) shows similar examples and others including the bidirectional nature of these transforms.

Finally, the result isn't “stack-less”. It's just that we've made the stack explicit, namely K' is the stack. It's, of course, well-known that recursion can be eliminated with an explicit stack, and CPS transforming followed by defunctionalization is one of the nicer ways of showing this. Defunctionalizing CPSed code will always produce a type isomorphic to a list, i.e. a stack (or something simpler if there is no recursion). This won't be true if the original code had control operators like callCC.

At any rate, this transformation, which turns control structure into an explicit data structure, is indeed often very enlightening. It's also reasonably reversible allowing you to understand and/or reimplement low-level code in a high-level manner.

#1 Derek Elkins on 2020-12-22

I think it's instructive to at least briefly mention possible performance implications in the case of GHC/Haskell. One would expect that by playing clever tricks we could save at least something in running time. In reality, map6 performs more than 2x worse than map1 on my (rather old) laptop with a (rather recent) GHC 8.10.2 with -O2 on a complete tree of height 23 and (+ 1) as the function. I learned it the hard way when trying to optimize Haskell's binary-trees in the Benchmarks Game once…

#2 Artem Pelenitsyn on 2020-12-26

Thanks for your defunctionalising blog post, linked below.

Four thoughts:

  1. You might want to mention The Zipper (due to Huet). It’s a standard, repeatable way of doing the bit that made your head spin under “Think?”.

  2. Under “The imperative loop” I think you actually make the program worse. Before you have up and down, two loops that jump to each other – they are both join points. Super efficient. (I have not actually compiled it to check, but I’m reasonably sure.) But by using an Either as the parameter you add extra allocation and pattern matching to every iteration.

  3. Indeed SpecConstr may well precisely undo the change you have just made.

  4. If memory serves, Martin C. Henson’s book, Elements of Functional Programming, Blackwell Scientific Publications (1987) takes defunctionalisation as its central theme. It’s barely known but it’s an excellent book. I have a copy but I don’t know how available it is.

#3 Simon Peyton Jones on 2020-12-28

Thanks for these points!

I guess I should have mentioned the Zipper (but several commenters have since :-)).

I am not sure if it would really involve no thinking – a zipper allows the program to navigate efficiently in all directions, and (typically) has the same type of values everywhere. But the tree map function has to change value types, and there is a type-based distinction between “visited” and “to be visited” nodes. So that somehow needs to be taken into account here…

#4 Joachim Breitner on 2020-12-28

Have something to say? You can post a comment by sending an e-Mail to me at <mail@joachim-breitner.de>, and I will include it here.