Monadic stream functions are a really useful abstraction for writing interactive systems in a purely functional language. I’ve been playing around with them with the intent of programming robots in Haskell, but they’re also useful for simulations, video games, signal processing, and more. This post is a short write-up of what I’ve discovered while trying various ways to encode MSFs. By switching from a recursive, final encoding to a non-recursive, initial encoding, we can drastically reduce code size and improve execution times. Most importantly, we can allow GHC’s optimizer to work across our combinators.

Background

Conceptually, a stream function SF a b (or signal function, depending on your mood), takes a stream of a’s, and produces a stream of b’s. More precisely, it alternates between first reading an a, then producing a b, then reading an a, and so on, forever.

In Functional reactive programming, refactored, the authors define stream functions using a final encoding. That is, they give a function, step, that runs the stream function for one tick:

step :: SF a b -> a -> (b, SF a b)

This takes the first input a, produces the first output b, and produces a continuation SF a b, to use on the next tick. They use this as the definition of the SF data type:

newtype SF a b = SF { step :: a -> (b, SF a b) }
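To make the "one tick at a time" reading concrete, here is a minimal sketch of a stream function written directly against this definition. The names runningSum and runSF are mine, made up for illustration, and are not taken from the paper.

```haskell
-- Final encoding from the text: a stream function is its own step function.
newtype SF a b = SF { step :: a -> (b, SF a b) }

-- Hypothetical example: output the sum of all inputs seen so far.
-- The accumulator lives in the continuation returned by each tick.
runningSum :: Int -> SF Int Int
runningSum acc = SF $ \x ->
  let acc' = acc + x
  in (acc', runningSum acc')

-- Hypothetical driver: feed a finite list of inputs through, one tick each.
runSF :: SF a b -> [a] -> [b]
runSF _ [] = []
runSF sf (x : xs) =
  let (y, sf') = step sf x
  in y : runSF sf' xs
```

For instance, runSF (runningSum 0) [1, 2, 3, 4] evaluates to [1, 3, 6, 10].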

A monadic stream function, or MSF, is like an SF, except we allow step to have side effects in some monad m:

newtype MSF m a b = MSF { step :: a -> m (b, MSF m a b) }

The monad parameter here is extremely useful for all sorts of things. My favorite one is setting m to Either e, so that at each step, we either produce an output and continuation as usual, or we “terminate early” with an e. That’s a topic for another post though, and for today, we won’t care what exactly m is.

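Once we have that type definition, we can define some combinators for making MSFs: the usual Category and Arrow combinators (id, (.), arr, and first), plus liftK for lifting Kleisli arrows and feedback for threading state through an MSF.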

These are all defined in the dunai package, by the same authors.
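As a rough sketch of what such combinators can look like on the recursive encoding, here are arr, liftK, and feedback written by hand. The signatures follow the way the combinators are used later in this post. The names arrMSF, runMSF, counter, and firstThree are my own, and none of this is meant as a copy of dunai's actual source.

```haskell
import Control.Arrow (Kleisli (..))
import Data.Functor.Identity (Identity (..))

newtype MSF m a b = MSF { step :: a -> m (b, MSF m a b) }

-- Lift a pure function; the continuation is the same MSF again.
arrMSF :: Monad m => (a -> b) -> MSF m a b
arrMSF f = MSF $ \x -> pure (f x, arrMSF f)

-- Lift a Kleisli arrow, running its effect on every tick.
liftK :: Monad m => Kleisli m a b -> MSF m a b
liftK k = MSF $ \x -> do
  y <- runKleisli k x
  pure (y, liftK k)

-- Thread a piece of state through an MSF, starting from an initial value.
feedback :: Monad m => c -> MSF m (a, c) (b, c) -> MSF m a b
feedback c f = MSF $ \x -> do
  ((y, c'), f') <- step f (x, c)
  pure (y, feedback c' f')

-- Hypothetical driver: run an MSF over a finite list of inputs.
runMSF :: Monad m => MSF m a b -> [a] -> m [b]
runMSF _ [] = pure []
runMSF f (x : xs) = do
  (y, f') <- step f x
  ys <- runMSF f' xs
  pure (y : ys)

-- Hypothetical example: count ticks, ignoring the input.
counter :: Monad m => MSF m a Int
counter = feedback 0 $ arrMSF $ \(_, c) -> let c' = c + 1 in (c', c')

firstThree :: [Int]
firstThree = runIdentity (runMSF counter "abc")
```

Note that every combinator here calls itself in its continuation, which is exactly the property that causes trouble below.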

The problem

Since our data definition is recursive, all of these combinators have recursive definitions as well. For example, composition is defined as

f . g = MSF $ \x -> do
  (y, g') <- step g x
  (z, f') <- step f y
  pure (z, f' . g')

While recursive definitions are elegant and simple, they are terrible for performance when used in a DSL like this. This is because GHC does not inline recursive definitions. Inlining is the main way GHC discovers optimizations, so any MSF combinators stop the optimizer dead in its tracks. This is pretty bad, since even trivial combinators like id need to be defined recursively:

id = MSF $ \x -> pure (x, id)

To see exactly what goes wrong, consider the following code snippet that computes sums of squares:

countS :: SF a Int
countS = feedback 0 $ arr $ \(_, c) -> let c' = c + 1 in (c', c')

sumS :: SF Int Int
sumS = feedback 0 $ arr $ \(x, s) -> let s' = x + s in (s', s')

sumSquaresS :: SF a Int
sumSquaresS = proc _ -> do
  n <- countS -< ()
  s <- sumS -< n * n
  returnA -< s
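For reference, here is what that snippet should output, written as a plain list computation. This is only a model I use to sanity-check the stream version (sumSquaresRef is a made-up name): the counter emits 1, 2, 3, ..., each value is squared, and the squares are summed as they arrive.

```haskell
-- Hypothetical reference model: the first few outputs of sumSquaresS,
-- computed as running sums of the squares 1, 4, 9, ...
sumSquaresRef :: Int -> [Int]
sumSquaresRef ticks = take ticks (scanl1 (+) [n * n | n <- [1 ..]])
```

So the first four ticks should produce 1, 5, 14, and 30.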

Using GHC 9.2.4 and -O2, this code alone produces 795 lines of core! Re-arranging it a bit and removing annotations, we get

sumSquaresS :: forall a. SF a Int
sumSquaresS = sumSquaresS1 `cast` <Co:7>

sumSquaresS1 :: forall {a}. a -> Identity (Int, MSF Identity a Int)
sumSquaresS1
  = \ (@a_abtP) ->
      \ (w3_ibwG :: a_abtP) ->
        $w$c.
          $fApplicativeIdentity
          $fMonadIdentity_$c>>=
          (sumSquaresS5 `cast` <Co:7>)
          (sumSquaresS2 `cast` <Co:7>)
          w3_ibwG

sumSquaresS2
  :: forall {a}. a -> Identity (((), ()), MSF Identity a ((), ()))
sumSquaresS2
  = \ (@a_abtP) (w2_X3 :: a_abtP) ->
      $w$carr ($fApplicativeIdentity3 `cast` <Co:9>) sumSquaresS3 w2_X3

-- ... snip ...

sumSquaresS14
  :: (((), ()), ())
     -> Identity ((Int, ()), MSF Identity (((), ()), ()) (Int, ()))
sumSquaresS14
  = \ (w2_ibw9 :: (((), ()), ())) ->
      case w2_ibw9 of { (ww6_ibwn, ww7_ibwo) ->
      $w$cfirst
        $fApplicativeIdentity
        $fMonadIdentity_$c>>=
        (sumSquaresS15 `cast` <Co:7>)
        ww6_ibwn
        ww7_ibwo
      }
	  
-- ... snip ...

sumSquaresS33 :: Int -> Identity (Int, MSF Identity Int Int)
sumSquaresS33
  = \ (w2_X3 :: Int) ->
      $w$carr ($fApplicativeIdentity3 `cast` <Co:9>) id w2_X3

Yes, GHC needed 33 helper functions just to compile that three-line function. Most of them correspond to single calls to first, arr, and (.). GHC can’t inline them, so it’s powerless to do any further optimization. Ouch!

The fix

To fix this, we’ll use an initial encoding, rather than a final encoding. In other words, instead of defining MSFs in terms of how we run them (via step), we’ll define them in terms of how we construct them (via arrow combinators, liftK, and feedback), modulo some rules we know about which MSFs should be equivalent to each other. Importantly, none of our definitions will be recursive. The only recursion will be done by the user, once at the top level, in order to repeatedly call step.

To define Arrow (and Category), we need at least id, ., arr, and first, in addition to liftK and feedback. Naively, we could just make these the data constructors in a GADT:

data MSF m a b where
  Id :: MSF m a a
  Compose :: MSF m c b -> MSF m a c -> MSF m a b
  Arr :: (a -> b) -> MSF m a b
  First :: MSF m a b -> MSF m (a, c) (b, c)
  LiftK :: Kleisli m a b -> MSF m a b
  Feedback :: c -> MSF m (a, c) (b, c) -> MSF m a b

Of course, we can simplify this to fewer constructors if we impose some laws. I won’t prove them all here, but none of them are very surprising. To start off, liftK should be an arrow homomorphism, meaning it commutes with all the arrow (and category) combinators:

id = liftK id
liftK f . liftK g = liftK (f . g)
arr f = liftK (arr f)
first (liftK f) = liftK (first f)

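The arrow combinators on the left are all in MSF m, and the arrow combinators on the right are all in Kleisli m. From this, we learn that any MSF built from liftK and the arrow combinators alone collapses into a single liftK. The only thing that can get in the way is feedback, so we would like to move the arrow combinators past it.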

We can, in fact, move them past feedback:

feedback c f . g = feedback c (f . first g)
f . feedback c g = feedback c (first f . g)
first (feedback c f) = feedback c (swap23 . first f . swap23)
  where swap23 = arr $ \((x, y), z) -> ((x, z), y)
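I won't prove these, but the two composition laws are easy to spot-check on a small example. The sketch below uses the recursive encoding with m dropped (plain SF) to keep it short. The names after, arrSF, firstSF, runSF, sumStep, and plusTen are made up for the check, with after standing in for (.).

```haskell
newtype SF a b = SF { step :: a -> (b, SF a b) }

arrSF :: (a -> b) -> SF a b
arrSF f = SF $ \x -> (f x, arrSF f)

-- Composition, written f `after` g to avoid clashing with Prelude's (.).
after :: SF b c -> SF a b -> SF a c
f `after` g = SF $ \x ->
  let (y, g') = step g x
      (z, f') = step f y
  in (z, f' `after` g')

firstSF :: SF a b -> SF (a, c) (b, c)
firstSF f = SF $ \(x, c) ->
  let (y, f') = step f x
  in ((y, c), firstSF f')

feedback :: c -> SF (a, c) (b, c) -> SF a b
feedback c f = SF $ \x ->
  let ((y, c'), f') = step f (x, c)
  in (y, feedback c' f')

runSF :: SF a b -> [a] -> [b]
runSF _ [] = []
runSF sf (x : xs) = let (y, sf') = step sf x in y : runSF sf' xs

-- Sample arrows: a running-sum body for feedback, and a stateless map.
sumStep :: SF (Int, Int) (Int, Int)
sumStep = arrSF $ \(x, s) -> let s' = x + s in (s', s')

plusTen :: SF Int Int
plusTen = arrSF (+ 10)

-- feedback c f . g = feedback c (f . first g)
lhs1, rhs1 :: [Int]
lhs1 = runSF (feedback 0 sumStep `after` plusTen) [1, 2, 3]
rhs1 = runSF (feedback 0 (sumStep `after` firstSF plusTen)) [1, 2, 3]

-- f . feedback c g = feedback c (first f . g)
lhs2, rhs2 :: [Int]
lhs2 = runSF (plusTen `after` feedback 0 sumStep) [1, 2, 3]
rhs2 = runSF (feedback 0 (firstSF plusTen `after` sumStep)) [1, 2, 3]
```

Both sides of the first law give [11, 23, 36], and both sides of the second give [11, 13, 16]. A spot-check on one input is of course not a proof.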

That means every MSF can be written as some number of feedback calls, wrapped around a call to liftK, i.e.

feedback c1 (feedback c2 ... (feedback cn (liftK f)) ... )

To finish things off, notice that we can always flatten nested calls to feedback into a single call, by merging the states into pairs:

feedback c (feedback d f) = feedback (c, d) (assoc . f . unassoc)
  where assoc = arr $ \((x, y), z) -> (x, (y, z))
        unassoc = arr $ \(x, (y, z)) -> ((x, y), z)

This means every MSF can be written as either feedback c (liftK f), or just liftK f. That leads to the following definition:

data MSF m a b where
  FeedbackLiftK :: c -> Kleisli m (a, c) (b, c) -> MSF m a b
  LiftK :: Kleisli m a b -> MSF m a b
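Here is a sketch of how the combinators might be implemented on top of this type, following the laws above. It ignores strictness and INLINE pragmas entirely. The helper names (assoc, unassoc, swap23, runMSF, firstFour) and the direct definition of the stateful-after-stateful composition case are my own choices for illustration, not necessarily what a tuned implementation would do.

```haskell
{-# LANGUAGE GADTs #-}

import Prelude hiding (id, (.))
import Control.Arrow (Arrow (..), Kleisli (..))
import Control.Category (Category (..))
import Data.Functor.Identity (Identity (..))

data MSF m a b where
  FeedbackLiftK :: c -> Kleisli m (a, c) (b, c) -> MSF m a b
  LiftK :: Kleisli m a b -> MSF m a b

-- Pure reshuffling helpers that keep the state in the second slot.
assoc :: ((x, y), z) -> (x, (y, z))
assoc ((x, y), z) = (x, (y, z))

unassoc :: (x, (y, z)) -> ((x, y), z)
unassoc (x, (y, z)) = ((x, y), z)

swap23 :: ((x, y), z) -> ((x, z), y)
swap23 ((x, y), z) = ((x, z), y)

instance Monad m => Category (MSF m) where
  id = LiftK id
  LiftK f . LiftK g = LiftK (f . g)
  LiftK f . FeedbackLiftK d g = FeedbackLiftK d (first f . g)
  FeedbackLiftK c f . LiftK g = FeedbackLiftK c (f . first g)
  -- Two stateful halves: run g on its state, then f on its own, and pair up.
  FeedbackLiftK c f . FeedbackLiftK d g =
    FeedbackLiftK (c, d) $ Kleisli $ \(x, (s, t)) -> do
      (y, t') <- runKleisli g (x, t)
      (z, s') <- runKleisli f (y, s)
      pure (z, (s', t'))

instance Monad m => Arrow (MSF m) where
  arr f = LiftK (arr f)
  first (LiftK f) = LiftK (first f)
  first (FeedbackLiftK c f) =
    FeedbackLiftK c (arr swap23 . first f . arr swap23)

-- Wrapping a stateful MSF in more state merges both states into a pair.
feedback :: Monad m => c -> MSF m (a, c) (b, c) -> MSF m a b
feedback c (LiftK f) = FeedbackLiftK c f
feedback c (FeedbackLiftK d f) =
  FeedbackLiftK (c, d) (arr assoc . f . arr unassoc)

-- One tick. Note that none of the combinators above call themselves.
step :: Monad m => MSF m a b -> a -> m (b, MSF m a b)
step (LiftK f) x = do
  y <- runKleisli f x
  pure (y, LiftK f)
step (FeedbackLiftK c f) x = do
  (y, c') <- runKleisli f (x, c)
  pure (y, FeedbackLiftK c' f)

-- Hypothetical top-level driver: the only recursion in this sketch.
runMSF :: Monad m => MSF m a b -> [a] -> m [b]
runMSF _ [] = pure []
runMSF f (x : xs) = do
  (y, f') <- step f x
  (y :) <$> runMSF f' xs

countS :: Monad m => MSF m a Int
countS = feedback 0 $ arr $ \(_, c) -> let c' = c + 1 in (c', c')

sumS :: Monad m => MSF m Int Int
sumS = feedback 0 $ arr $ \(x, s) -> let s' = x + s in (s', s')

-- The earlier example, written with plain combinators instead of proc.
sumSquaresS :: Monad m => MSF m a Int
sumSquaresS = sumS . arr (\n -> n * n) . countS

firstFour :: [Int]
firstFour = runIdentity (runMSF sumSquaresS (replicate 4 ()))
```

With these definitions, firstFour is [1, 5, 14, 30], and sumSquaresS itself is a single FeedbackLiftK whose state is a pair of the two counters.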

We could go further and shave it down to just FeedbackLiftK, using () as the state if we don’t need feedback, but that actually leads to slower code and messier core, since we have to build up and break down a big tree of ()’s, one for each instance of liftK.

The nice thing about this definition is that in retrospect, it feels obvious. We’re just collecting together the states from all the feedback calls, and doing the other arrow operations in Kleisli. This is more or less what we would do if we had to translate our code manually. It’s theoretically nice and intuitively clear. And it will perform well, too, once we address this next point.

Well, strictly speaking…

There’s one detail we’ve been glossing over up to this point, and that’s strictness. One of the goals of arrowized FRP is that we can take a principled approach to avoiding space leaks. If we want to achieve that, we’d better sit down and figure out our strictness. Even if we have no space leaks, getting strictness right helps GHC generate more efficient code.

As a general principle, wherever we take state, we should default to being strict. Actually needing lazy state is pretty rare, and if the user wants it they can always wrap it in Solo or similar.

What does that mean in practice? Specifically,

The point of these rules is that they give a consistent specification of MSFs, while ensuring we can force the state to WHNF at nearly every opportunity. Note that this is a departure from the definition in the paper above and in dunai.
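To make "forcing the state to WHNF" a bit more concrete, here is a small sketch on a plain state-passing step function rather than on the MSF type itself. The names runStrict and Lazy are hypothetical. Lazy is a stand-in for Solo that I define locally so the example doesn't depend on a particular version of base.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Hypothetical lazy box standing in for Solo: evaluating the box to WHNF
-- exposes the constructor but leaves the contents unevaluated.
data Lazy a = Lazy a

-- Hypothetical driver for a pure state-passing step function. The bang
-- pattern forces the state to WHNF before each tick runs, so thunks cannot
-- pile up in the state from one tick to the next.
runStrict :: (a -> s -> (b, s)) -> s -> [a] -> [b]
runStrict _ _ [] = []
runStrict f !s (x : xs) =
  let (y, s') = f x s
  in y : runStrict f s' xs

-- Strict state: the running total is evaluated on every tick.
runningTotals :: [Int]
runningTotals = runStrict (\x s -> let s' = x + s in (s', s')) 0 [1, 2, 3]

-- Opting back into laziness: only the box is forced, never its contents,
-- so even an undefined payload is harmless here.
ignoresState :: [Int]
ignoresState =
  runStrict (\x _ -> (x, Lazy undefined)) (Lazy (0 :: Int)) [1, 2]
```

Without the Lazy wrapper, an undefined state would blow up as soon as the second tick is demanded, which is the behavior we want by default.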

Performance

With all these changes (and a few INLINE pragmas), the example above compiles to this:

sumSquaresS :: forall a. SF a Int
sumSquaresS
  = \ (@a_a1YZ) -> Feedback sumSquaresS2_r2aV sumSquaresS1_r2aU

sumSquaresS1_r2aU
  :: forall {a}. (a :*! (Int :!*! Int)) -> Int :*! (Int :!*! Int)
sumSquaresS1_r2aU
  = \ (@a_a1YZ) (x_i24z :: a_a1YZ :*! (Int :!*! Int)) ->
      case x_i24z of { :*! a1_i24r s1_i24s ->
      case s1_i24s of { :!*! s2_s2aK t_s2aL ->
      case t_s2aL of { I# ipv_s2aO ->
      case s2_s2aK of { I# ipv1_s2aR ->
      let {
        dt2_s29u :: Int#
        [LclId]
        dt2_s29u = +# 1# ipv_s2aO } in
      let {
        dt3_s29h :: Int#
        [LclId]
        dt3_s29h = +# (*# dt2_s29u dt2_s29u) ipv1_s2aR } in
      :*! (I# dt3_s29h) (:!*! (I# dt3_s29h) (I# dt2_s29u))
      }
      }
      }
      }

sumSquaresS2_r2aV :: Int :!*! Int
sumSquaresS2_r2aV = :!*! sumS2_r2aS sumS2_r2aS

sumS2_r2aS :: Int
sumS2_r2aS = I# 0#

This unpacks the state in a series of case expressions, does arithmetic with unboxed integers, and then packs the state back up. This is about as good as we can expect the core to get, short of unboxing the state entirely.

Here’s a microbenchmark where we step the above MSF 1,000,000 times, comparing the two approaches:

benchmarking step encoding
time                 767.4 ms   (679.6 ms .. 876.4 ms)
                     0.998 R²   (0.992 R² .. 1.000 R²)
mean                 856.6 ms   (816.4 ms .. 922.6 ms)
std dev              63.62 ms   (9.070 ms .. 83.67 ms)
variance introduced by outliers: 21% (moderately inflated)

benchmarking feedback/liftK encoding
time                 3.973 ms   (3.967 ms .. 3.981 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.990 ms   (3.984 ms .. 4.000 ms)
std dev              24.95 μs   (18.74 μs .. 34.96 μs)

On the one hand, this is a silly microbenchmark. On the other hand, holy crap it got 200 times faster!

Going further

In a future post, I’ll look at the case m = Either e in more detail. In this case, MSFs are a monad in the exception type e, which is extremely useful for sequencing computations together. However, with our definition above, (>>=) still needs to be implemented recursively. As it turns out, we’ll be saved by a free monad.