First off, keep in mind that the `State st a` type is a function type (`st -> (st, a)`) and that `returnState a` wraps `a` into a function that returns a (state-value, value) pair when it's applied to a state-value.
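Here is a minimal sketch of these definitions. It assumes `State` is a plain type synonym rather than a `newtype`, which is what the function-type description above suggests:

```haskell
-- Assumption: State is a type synonym, so a State value is just a function.
type State st a = st -> (st, a)

-- Wrap a value into a computation that hands the state back unchanged,
-- paired with the value.
returnState :: a -> State st a
returnState a = \st -> (st, a)
```

For example, `returnState 'x' 0` evaluates to `(0, 'x')`: the state-value passes through untouched.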
`bindState` is a bit complicated to think about. It takes two arguments, both functions. The first (`m`) is of type `st -> (st, a)` (i.e. `State st a`). The second (`k`) is of type `a -> (st -> (st, b))` (i.e. `a -> State st b`). In other words, this second argument takes a value and returns a function that we can apply to a state; that returned function has the end result of the computation "wrapped up" in it. Looking more closely at the implementation of
`bindState`, we see that it returns a function taking one argument (`st`). `m` is used to calculate a new (state-value, value) pair (`(st', a)`) based on `st`, and `k` is used to "wrap" `a` into a function (`m'`) that we can apply a state-value to. The last step is to apply the new state (`st'`) to the function wrapping the new value (`m'`).

Looking at `mapTreeStateM f (Leaf a)`, we see that `f a` plays the role of `m` and `\b -> returnState (Leaf b)` plays the role of `k`. From our look at `bindState` we can see that `b` here will be our newly calculated value (i.e. the new value coming out of `f a` will be wrapped into a leaf again).
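To make the roles of `m` and `k` concrete, here is a sketch of `bindState` written with the names used above (`st`, `st'`, `a`, `m'`), together with the leaf step on its own. The exact definitions are my reconstruction from the description, not necessarily the original code; `leafStep` and `doubleAndCount` are hypothetical names used only for illustration.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

-- Run m on the incoming state st to get (st', a), let k wrap a into a
-- new computation m', then run m' on the new state st'.
bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in  m' st'

-- Assumed tree type matching the Leaf/Branch constructors in the text.
data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- Hypothetical: the leaf case in isolation. f a is m,
-- and (\b -> returnState (Leaf b)) is k.
leafStep :: (a -> State st b) -> a -> State st (Tree b)
leafStep f a = f a `bindState` (\b -> returnState (Leaf b))

-- Hypothetical: doubles the value and bumps the state by one.
doubleAndCount :: Integer -> State Integer Integer
doubleAndCount a = \st -> (st + 1, a * 2)
```

`leafStep doubleAndCount 5 10` evaluates to `(11, Leaf 10)`: `m` turns the state `10` into `11` and the value `5` into `10`, and `k` then wraps `10` in a `Leaf` without touching the state.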
I found it easier to understand when I applied an “identity function” to a tree:
```haskell
doNothing :: Integer -> State Integer Integer
doNothing a = returnState a
```
Working through on paper what happens when applying `mapTreeStateM` to `doNothing`, a tree of one leaf, and an initial state of 0 made it quite clear what is going on.
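The paper exercise can be replayed in code. The definitions below are assumptions reconstructed from the description (a type-synonym `State`, a `Leaf`/`Branch` tree, and a two-case `mapTreeStateM`); the comment block walks through the same reduction by hand for a hypothetical one-leaf tree `Leaf 7`.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in  m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- Assumed definition: thread the state through the leaves, left to right.
mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` (\b -> returnState (Leaf b))
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` (\lhs' ->
  mapTreeStateM f rhs `bindState` (\rhs' ->
  returnState (Branch lhs' rhs')))

doNothing :: Integer -> State Integer Integer
doNothing a = returnState a

-- Hand reduction of mapTreeStateM doNothing (Leaf 7) 0,
-- with k = \b -> returnState (Leaf b):
--   (doNothing 7 `bindState` k) 0
--   = let (st', a) = doNothing 7 0   -- = (0, 7): state 0 passes through
--         m'       = k a             -- = returnState (Leaf 7)
--     in  m' st'
--   = returnState (Leaf 7) 0
--   = (0, Leaf 7)
```

So `mapTreeStateM doNothing (Leaf 7) 0` comes back as `(0, Leaf 7)`: both the tree and the state are unchanged, which is exactly what an identity function should give.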
Analysing the recursive step after this is fairly straightforward. Adding parentheses makes it a bit easier to read, and I found it helpful to rewrite it to consider only the left branches:
```haskell
mapTreeStateMLO :: (a -> State st Integer) -> Tree a -> State st (Tree Integer)
mapTreeStateMLO f (Leaf a) =
  (f a) `bindState` (\b -> returnState (Leaf b))
-- Only the left branch is mapped; the right branch becomes a placeholder leaf.
mapTreeStateMLO f (Branch lhs rhs) =
  (mapTreeStateMLO f lhs) `bindState` (\lhs' -> returnState (Branch lhs' (Leaf (-1))))
```
The extension to right branches comes fairly naturally after that.
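Extending the left-only version to both branches gives something like the sketch below (an assumption about the final form, using the same type-synonym `State`). To show that the state really flows from the left subtree into the right one, it uses a hypothetical `runningSum` that replaces each leaf with the total of all leaves seen so far.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in  m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` (\b -> returnState (Leaf b))
mapTreeStateM f (Branch lhs rhs) =
  -- Map the left branch first; its final state is the right branch's start state.
  mapTreeStateM f lhs `bindState` (\lhs' ->
  mapTreeStateM f rhs `bindState` (\rhs' ->
  returnState (Branch lhs' rhs')))

-- Hypothetical: the state is a running total; each leaf becomes that total.
runningSum :: Integer -> State Integer Integer
runningSum a = \st -> (st + a, st + a)
```

With this, `mapTreeStateM runningSum (Branch (Leaf 1) (Branch (Leaf 2) (Leaf 3))) 0` evaluates to `(6, Branch (Leaf 1) (Branch (Leaf 3) (Leaf 6)))`. Swapping the order of the two `bindState` calls would still thread the state, but would visit the right branch first.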
I wrote a few functions to play with `mapTreeStateM`. All of them use a single integer to represent their state.
State is the number of leaves in the tree:
```haskell
countLeaf :: Integer -> State Integer Integer
countLeaf a = \st -> (st + 1, a)
```
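A quick way to check `countLeaf` is to run it over a small tree. This sketch repeats the assumed `State`, `bindState`, `Tree` and `mapTreeStateM` definitions so it stands alone; `sampleTree` is a hypothetical three-leaf tree.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
  in  k a st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- Assumed two-case definition, threading state left to right.
mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` (\b -> returnState (Leaf b))
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` (\lhs' ->
  mapTreeStateM f rhs `bindState` (\rhs' ->
  returnState (Branch lhs' rhs')))

-- Increment the state once per leaf; leave the value alone.
countLeaf :: Integer -> State Integer Integer
countLeaf a = \st -> (st + 1, a)

-- Hypothetical test input with three leaves.
sampleTree :: Tree Integer
sampleTree = Branch (Leaf 10) (Branch (Leaf 20) (Leaf 30))
```

`mapTreeStateM countLeaf sampleTree 0` gives `(3, sampleTree)`: the final state is the leaf count and the tree itself comes back unchanged.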
State is the maximum value of a leaf:
```haskell
maxLeaf :: Integer -> State Integer Integer
maxLeaf a = \st -> (max st a, a)
```
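The same kind of check works for `maxLeaf`. As before, the supporting definitions are assumptions repeated for self-containment, and `sampleTree` is a hypothetical input.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
  in  k a st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` (\b -> returnState (Leaf b))
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` (\lhs' ->
  mapTreeStateM f rhs `bindState` (\rhs' ->
  returnState (Branch lhs' rhs')))

-- Keep the largest leaf value seen so far as the state.
maxLeaf :: Integer -> State Integer Integer
maxLeaf a = \st -> (max st a, a)

-- Hypothetical test input; the largest leaf is in the middle.
sampleTree :: Tree Integer
sampleTree = Branch (Leaf 4) (Branch (Leaf 9) (Leaf 2))
```

`mapTreeStateM maxLeaf sampleTree 0` gives `(9, sampleTree)`. Because the state starts at `0`, a tree whose leaves are all negative would report `0` rather than its true maximum.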
This function numbers each leaf in the tree:
```haskell
numberLeaf :: Integer -> State Integer (Integer, Integer)
numberLeaf a = \st -> (st + 1, (st, a))
```
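Running `numberLeaf` over a tree shows the left-to-right order in which the state is threaded. This sketch again repeats the assumed supporting definitions; `sampleTree` is a hypothetical input.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
  in  k a st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` (\b -> returnState (Leaf b))
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` (\lhs' ->
  mapTreeStateM f rhs `bindState` (\rhs' ->
  returnState (Branch lhs' rhs')))

-- Pair each leaf with the current counter, then advance the counter.
numberLeaf :: Integer -> State Integer (Integer, Integer)
numberLeaf a = \st -> (st + 1, (st, a))

-- Hypothetical test input with three leaves.
sampleTree :: Tree Integer
sampleTree = Branch (Leaf 10) (Branch (Leaf 20) (Leaf 30))
```

`mapTreeStateM numberLeaf sampleTree 0` gives `(3, Branch (Leaf (0, 10)) (Branch (Leaf (1, 20)) (Leaf (2, 30))))`: leaves are numbered left to right starting from the initial state, and the final state is the next unused number.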