Monads

July 18, 2015

I know that everyone has already written an introduction to Monads at some point. For a technical blog, it is either a rite of passage, a sign of egomania, or perhaps even borderline spam. My goal in writing this post is not to claim that it is in some way superior to other, similar posts out there. Instead, it's to solidify the concepts in my own mind and to motivate the concept of a Monad from a perspective that is most effective to me. The goal of this post is to describe what a Monad is and why it's a useful thing to think about, and to do so for a person who doesn't program in a language like Haskell (where Monads are a first-class concept).

A Monad is really just a programming pattern that may be useful or may produce better code in certain situations. That's it, and there's nothing particularly magical about it. In Haskell, Monads have dedicated language syntax that makes them particularly useful, but such syntax magic isn't necessary to motivate or understand the Monad pattern. I've found that many people who write about Monads do so from the perspective of Haskell, whose syntax and emphasis on purity make them core to the language, and so their usefulness is clearer. My goal is to avoid that perspective and motivate the concept of Monads as a useful programming pattern in and of itself.

Okay, having said all that, let's now forget the word Monad. It's not helpful here.

Imagine instead that you're writing a program in a language like Python or Java. In your language of choice, you write a function f. Let's suppose that the function f takes an object of one type and returns an object of that same type. (In a statically typed language like Java, one can enforce this constraint using the compiler. In a dynamic language like Python, this isn't enforced by the language, but certainly one can write such a function and keep its signature in mind.) To be concrete, let's imagine that f takes a floating point value and returns a floating point value (for example, it may square that value):

Python:

def f(x):
    return x*x

Java:

public static float f(float x) {
    return x*x;
}

Simple. What can we do with f? It's clear that, because f takes and returns objects of the same type, we can call f twice in a row, or compose f with itself:

Python:

x = f(f(10.0))

Java:

float x = f(f(10.0f));

In fact, we can do this as many times as we like, since we're always taking floats and returning floats:

Python:

x = f(f(f(f(f(f(10))))))

Java:

float x = f(f(f(f(f(f(10))))));

One can do this until one's index finger gets tired of pressing "f". More generally, if I had several functions that all took floats and returned floats, I could freely compose those with each other and not worry about anything breaking (meaning, it would compile in Java and I wouldn't have any errors in Python). At this point, one could introduce the concept of a Monoid (not a spelling mistake, it's a different thing than a Monad), but adding fancy names isn't necessary at this point (feel free to read about it yourself, though).
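To make "freely compose" concrete, here is a minimal Python sketch. The helper compose and the second function half are my own illustrative names and are not part of the examples above; the only assumption is that every function takes a float and returns a float.

```python
from functools import reduce

def f(x):
    return x * x

def half(x):
    # hypothetical second float -> float function, for illustration only
    return x / 2.0

def compose(*funcs):
    """Return a function that applies funcs from right to left, like f(g(x))."""
    def composed(x):
        return reduce(lambda acc, fn: fn(acc), reversed(funcs), x)
    return composed

square_then_half = compose(half, f)   # behaves like half(f(x))
print(square_then_half(10.0))         # 50.0
```

Because every function has the same input and output type, compose never has to inspect or unpack anything between calls.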

Now, let's imagine a slightly more complicated case. Imagine we have a function g that takes one type but returns another type (again, in Python, type alignment isn't enforced by a compiler, but it's all there). One can think about this generally, but to be concrete, let's imagine a function g that takes a float and returns a list of floating point numbers. To be really simple, let's just imagine it wraps the input float in a list:

Python:

def g(x):
    return [x]

Java:

// needs java.util.ArrayList, java.util.Arrays and java.util.List
public static List<Float> g(float x) {
    return new ArrayList<Float>(Arrays.asList(x));
}

This is just as simple as before. However, because the input and output types don't line up, we can't simply compose g with itself. The following won't work:

Python:

x = g(g(10.0))

Java:

List<Float> x = g(g(10.0f));  // g expects a float, not a List<Float>

The Python case will actually run, but it will create a list of lists, which isn't what we want in this case. The Java version won't even compile.
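A short Python sketch of what that accidental nesting looks like, using the same g as above:

```python
def g(x):
    return [x]

nested = g(g(10.0))
print(nested)         # [[10.0]]
print(nested[0])      # [10.0]
print(nested[0][0])   # 10.0
```

Each extra call to g adds another layer of list, so the float we care about ends up buried one level deeper every time.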

It's not entirely clear why this is interesting, but the goal is to start simple and see how certain patterns emerge, so bear with me.

If we did want to call g many times in a row, we could do the following:

Python:

tmp = g(10.0)
x = g(tmp[0])

Java:

List<Float> tmp = g(10.0f);
List<Float> x = g(tmp.get(0));

To use g twice in a row, we have to add an extra step where we get the value out of the list before passing it to the function g. More generally, if one had many functions with the same signature as g that they wanted to compose, one would have to similarly take values out of lists and pass them to these various functions. To be concrete about that, let's extend our example by introducing the function h. h, like g, takes a floating point value and returns a list. If that value is greater than or equal to 0, it does what g does and puts that value into a list. But, if the value is less than 0, it returns an empty list:

Python:

def h(x):
    if x >= 0:
        return [x]
    else:
        return []

Java:

public static List<Float> h(float x) {
    if (x >= 0) {
        return new ArrayList<Float>(Arrays.asList(x));
    } else {
        return new ArrayList<Float>();
    }
}

Let's say our goal was to call h on a value and then call g on the result. We know that we can't simply do this:

g(h(10.0))

From before, we know that we have to get the result out of the return value of h before putting it into g. Following our example before, we can try writing:

Python:

tmp = h(10.0)
x = g(tmp[0])

Java:

List<Float> tmp = h(10.0f);
List<Float> x = g(tmp.get(0));

This works too, but it's clearly broken when we try to use a negative value:

Python:

tmp = h(-10.0)
x = g(tmp[0])

Java:

List<Float> tmp = h(-10.0f);
List<Float> x = g(tmp.get(0));

The function h will return an empty list, and we'll get an exception in both Python and Java when we try to get the first element from that list. So, if we want to call h and then g on a value, we need to add more logic to make it work every time without throwing an exception. We can write the following:

Python:

tmp = h(-10.0)
if len(tmp) > 0:
    x = g(tmp[0])
else:
    x = []

Java:

List<Float> tmp = h(-10.0f);
List<Float> x;
if (tmp.size() > 0) {
    x = g(tmp.get(0));
} else {
    x = new ArrayList<Float>();
}

And that works, though it's a bit cumbersome. One thing to note is that it makes us explicitly state what we want to do when the tmp list is empty (here we assume that we want x to be the empty list). What we really wanted to do was write g(h(-10.0)). Instead, we had to write boilerplate to make that work.
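To see both behaviors side by side, here is a small Python sketch. The two wrapper function names are hypothetical and exist only for this comparison; g and h are as defined earlier.

```python
def g(x):
    return [x]

def h(x):
    return [x] if x >= 0 else []

def g_after_h_unguarded(value):
    tmp = h(value)
    return g(tmp[0])          # raises IndexError when tmp is empty

def g_after_h_guarded(value):
    tmp = h(value)
    if len(tmp) > 0:
        return g(tmp[0])
    return []                 # our chosen result for the "no value" case

try:
    g_after_h_unguarded(-10.0)
except IndexError as err:
    print("unguarded call failed:", err)

print(g_after_h_guarded(10.0))    # [10.0]
print(g_after_h_guarded(-10.0))   # []
```

The guarded version is correct, but the check would have to be repeated for every additional function in the chain.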

That's not the end of the world, but as programmers, we like to look for patterns that can help us avoid boilerplate whenever possible. Perhaps this is an opportunity to decompose the above code into an abstract pattern.

Before doing so, let's quickly look at another simple example so we don't over-optimize for our specific functions g and h.

Let's instead imagine that we have a function that takes a String and returns both a lowercase version of the original string and an integer that represents the length of the input string:

Python:

def p(s):
    return (s.lower(), len(s))

Java:

public static class StringPair {
    public final String s;
    public final Integer length;
    public StringPair(String s, Integer length) {
        this.s = s;
        this.length = length;
    }
}

public static StringPair p(String s) {
    return new StringPair(s.toLowerCase(), s.length());
}

Because Java doesn't have native tuples, we created a simple class to hold the string and the integer. But, aside from that, both versions of p do the same thing.

Like before, we can't call p twice in a row:

p(p("foo"))

The function p returns a string pair, so we have to unpack the pair for this to work:

Python:

tmp = p("foo")
x = p(tmp[0])

Java:

StringPair tmp = p("foo");
StringPair x = p(tmp.s);

And, like our examples with g and h, this works. However, it seems that we lost information, as we're throwing out the length returned by the first call to p. What if we want to keep track of the total size of all strings that have gone through the p function? After one call to "foo" it should be 3, but if we call p twice in a row on "foo", it should return 6 (3+3). How can we accomplish this?

Python:

tmp = p("foo")
x = p(tmp[0])

my_string = x[0]
my_count = tmp[1] + x[1]

Java:

StringPair tmp = p("foo");
StringPair x = p(tmp.s);

String myString = x.s;
Integer myCount = tmp.length + x.length;

This works, but again, we added a lot of boilerplate. All we wanted was to do something that sort of feels like p(p("foo")). But we had to create all these intermediate variables and keep unpacking data to actually get that to work. Similar to the example with g and h, we'd like to come up with a pattern that makes the above simpler and clearer.
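The bookkeeping only grows as the chain gets longer. As a rough Python sketch, here is the same manual unpacking written as a loop; the helper name call_p_repeatedly is hypothetical and is not part of the examples above.

```python
def p(s):
    return (s.lower(), len(s))

def call_p_repeatedly(s, times):
    """Hypothetical helper: call p `times` times, summing the lengths by hand."""
    total = 0
    current = s
    for _ in range(times):
        current, length = p(current)   # unpack the pair on every step
        total += length                # and carry the count along ourselves
    return (current, total)

print(call_p_repeatedly("Foo", 1))   # ('foo', 3)
print(call_p_repeatedly("Foo", 2))   # ('foo', 6)
```

The unpacking and the running total are tangled up with the loop itself, which is exactly the part we would like to factor out.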

So, what do these two examples have in common, and how can we try to clean them both up to better express their underlying intentions? In both cases, we want to do something that is conceptually similar to composing functions, but we needed to do extra work to unpack types and handle some additional data. What we really ended up doing was defining a special type of composition for each of our examples that has additional rules which define how we pass the data into the functions when composing them.

To compose our functions, we essentially had to answer the question:

Imagine we have an object of type B. How do we pass that object to a function f that takes an object of type A and returns an object of type B?

With our g, h example, we did this by taking the first (and only) float from the list and passing that float to a function. In our second example, we took a StringPair, passed the string component to the function, got back a new string pair, and then added their length fields.

If, for a group of functions with the same signature, we were able to come up with such a rule, then we could freely compose those functions. One can think of this process as defining a function that takes an object of type B and a function of type A to B and returns an object of type B. There are a number of conventional names for this type of function. Here, we're going to call it "bind". (We will later learn that defining a "bind" operation is one of the requirements for satisfying the Monad pattern, but again, that isn't relevant right now).

Let's define bind functions for the previous examples. Recall that bind takes (a function that takes an A and returns a B) and (an object of class B) and returns a B.

For our g, h example:

Python:

def bind(g, x):
    if len(x) > 0:
        return g(x[0])
    else:
        return []

Java:

public interface Function<T, U> {
    U call(T t);
}

public static List<Float> bind(Function<Float, List<Float>> g, List<Float> x) {
    if (x.size() > 0) {
        return g.call(x.get(0));
    } else {
        return new ArrayList<Float>();
    }
}

For our p example:

Python:

def bind(p, x):
    s, length = x
    y = p(s)
    return (y[0], y[1] + length)

Java:

public static StringPair bind(Function<String, StringPair> p, StringPair x) {
    StringPair y = p.call(x.s);
    return new StringPair(y.s, x.length + y.length);
}
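As a quick sanity check, here is a minimal Python sketch that exercises both versions. The names bind_list and bind_pair are my own renaming for this sketch so that the two can live side by side; the logic is the same as in the two bind functions just described.

```python
def g(x):
    return [x]

def h(x):
    return [x] if x >= 0 else []

def p(s):
    return (s.lower(), len(s))

def bind_list(func, x):
    # x is a list holding zero or one floats
    if len(x) > 0:
        return func(x[0])
    return []

def bind_pair(func, x):
    # x is a (string, running_length) tuple
    s, length = x
    new_s, new_length = func(s)
    return (new_s, new_length + length)

print(bind_list(g, h(10.0)))     # [10.0]
print(bind_list(g, h(-10.0)))    # []
print(bind_pair(p, p("Foo")))    # ('foo', 6)
```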

We've successfully moved our boilerplate code into our new "bind" functions. This is a step in the right direction: encapsulating messy functionality and boilerplate is usually a good thing. We can now compose our g, h, and p functions in the following way:

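Python:

bind(g, bind(h, bind(g, bind(h, [10.0]))))

Java:

// here g and h stand for Function<Float, List<Float>> objects wrapping the methods above
bind(g, bind(h, bind(g, bind(h, new ArrayList<Float>(Arrays.asList(10.0f))))));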

This looks a lot like the composition that we really want to write! The only somewhat annoying part is the expression in the innermost nested function call, where we have to wrap our raw value, 10.0, in a List. This is necessary because, as we saw before, our bind function's second parameter is a List, and not a raw float.

We can wrap that action into another function that we'll call "unit". ("unit" is not a good or clear name at all; I admit that I'm using it only by convention. I think "wrap" would be a better name for what we want it to do.) Essentially, our new "unit" function takes an object of type A and returns an object of type B. For our first example, we can define unit as:

Python:

def unit(x):
    return [x]

Java:

public static List<Float> unit(float x) {
    return new ArrayList<Float>(Arrays.asList(x));
}

With these definitions, our composition looks like:

Python:

bind(g, bind(h, bind(g, bind(h, unit(10.0)))))

Java:

bind(g, bind(h, bind(g, bind(h, unit(10.0f)))));

(It's satisfying that, at this point, the Python and the Java are almost exactly the same: they differ only by a semicolon, which Python would also accept, and by the f suffix that Java needs on a float literal. That's a good sign that our abstractions are moving in the right direction.)
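The p example can be given a "unit" as well. The post only defines unit for the list case, so the version below is my own assumption: a freshly wrapped string is paired with a count of 0, since no call to p has measured it yet. A minimal Python sketch:

```python
def p(s):
    return (s.lower(), len(s))

def bind(func, x):
    s, length = x
    new_s, new_length = func(s)
    return (new_s, new_length + length)

def unit(s):
    # assumption: a freshly wrapped string has contributed no length yet
    return (s, 0)

print(bind(p, unit("Foo")))             # ('foo', 3)
print(bind(p, bind(p, unit("Foo"))))    # ('foo', 6)
```

With that choice, the running total matches the hand-written version from earlier: 3 after one call and 6 after two.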

We're nearly done. The next step is to encapsulate our "bind" and "unit" functions into a class (since we may want to define many different "bind" and "unit" functions for different examples, such as g, h vs p). Let's go ahead and do it:

Python:

class M:
    @staticmethod
    def bind(g, x): ...

    @staticmethod
    def unit(x): ...

Java:

public static class M {
    public static List<Float> bind(Function<Float, List<Float>> f, List<Float> x) {
        ...
    }
    public static List<Float> unit(Float f) {
        ...
    }
}

As one can guess by its name, "M" is a Monad. (I'm not going to go over the technical definition of a Monad here. For the purposes of this discussion, a class with the "bind" and "unit" functions is a Monad.)
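For reference, here is one way the elided bodies could be filled in for the g, h example in Python. The class name ListM is hypothetical (chosen so that a second class for the p example could sit beside it), and using static methods is just one reasonable choice.

```python
class ListM:
    """Hypothetical name: the list-based bind and unit gathered into one class."""

    @staticmethod
    def unit(x):
        return [x]

    @staticmethod
    def bind(func, x):
        if len(x) > 0:
            return func(x[0])
        return []

def g(x):
    return [x]

def h(x):
    return [x] if x >= 0 else []

print(ListM.bind(g, ListM.bind(h, ListM.unit(10.0))))    # [10.0]
print(ListM.bind(g, ListM.bind(h, ListM.unit(-10.0))))   # []
```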

This next step isn't strictly necessary, but it allows me to implement some final syntactic sugar. In addition to "bind" and "unit", I'm going to define a function called "comp". All "comp" will do is call "bind", applying "unit" first if the value is an A rather than a B. For our first example, this means checking whether the value is a list or a float:

Python:

def comp(g, x):
    if isinstance(x, list):
        return bind(g, x)
    else:
        return bind(g, unit(x))

Java:

public static List<Float> comp(Function<Float, List<Float>> f, List<Float> x) {
    return bind(f, x);
}
public static List<Float> comp(Function<Float, List<Float>> f, Float x) {
    return bind(f, unit(x));
}

Finally, I can now write the full composition with this comp function:

Python:

comp(h, comp(g, comp(h, comp(g, 10.0))))

Java:

comp(h, comp(g, comp(h, comp(g, 10.0f))));

While some of what we did may seem arbitrary, it ended up allowing us to write the composition we wanted in the above style. One can probably do better syntactically to make the composition prettier, but this is as far as I'll go, as I think it makes the point clear. By encoding our "special composition" in the "bind" and "unit" functions, we are able to compose our g and h functions freely together (leveraging our "comp" function). With this syntax, it's easy to see that we're just composing functions together, which is more difficult to see with all the boilerplate code that we managed to remove.
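As one possible take on "prettier", here is a Python sketch that flattens the nesting into a left-to-right list of steps. The helper name pipeline is hypothetical; it simply wraps the value once with unit and then applies bind for each function in turn.

```python
from functools import reduce

def unit(x):
    return [x]

def bind(func, x):
    return func(x[0]) if len(x) > 0 else []

def g(x):
    return [x]

def h(x):
    return [x] if x >= 0 else []

def pipeline(value, *funcs):
    """Hypothetical helper: wrap value once, then bind each function in order."""
    return reduce(lambda wrapped, func: bind(func, wrapped), funcs, unit(value))

# Same work as comp(h, comp(g, comp(h, comp(g, 10.0)))), read left to right.
print(pipeline(10.0, g, h, g, h))    # [10.0]
print(pipeline(-10.0, g, h, g, h))   # []
```

Whether this reads better than the nested calls is a matter of taste; the underlying bind and unit are unchanged.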

So, once again, why is this a useful thing to have done, or to think about? If you ask me, I think the most useful part of this exercise is in flexing our abstraction powers: learning how to take a few examples of messy, boilerplate code and merge them into simpler constructs that better express the intent of what we want. But, more tangibly, we've managed to separate the "what" (composing multiple functions) from the "how". We can write the "how" in one place and then freely compose any functions that have our desired signature. This all seems like a lot of work in the case where g and h are so simple, but in more complicated cases, being able to write g, h, and M separately and later combine them allows us to write business logic more clearly, and that's really one of the primary goals of any abstraction.

So, what then is a Monad? It's a set of rules for composing functions that can't normally be composed because they return a different type than they take. Okay, there's a lot more complexity to Monads, but I think that definition gets to the heart of what they are and why they're useful. Monads are useful because function composition is easy to write and understand. Monads take potentially complicated interactions between functions and make them simpler and easier to reason about.
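One way to make that definition concrete is a helper that turns two such functions into a single function with the same signature. This is only a sketch in Python; the name compose_with is hypothetical, and it takes the bind function as a parameter so that the composing rule stays separate from the functions being composed.

```python
def bind(func, x):
    return func(x[0]) if len(x) > 0 else []

def g(x):
    return [x]

def h(x):
    return [x] if x >= 0 else []

def compose_with(bind_fn, second, first):
    """Hypothetical helper: returns the function x -> bind_fn(second, first(x))."""
    def composed(x):
        return bind_fn(second, first(x))
    return composed

g_after_h = compose_with(bind, g, h)   # takes a float, returns a list, like g and h
print(g_after_h(10.0))     # [10.0]
print(g_after_h(-10.0))    # []
```

Since g_after_h has the same shape as g and h, it can be fed back into compose_with, which is the sense in which these functions now compose freely.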