Reacting quicker: optimizing monadic stream functions
Monadic stream functions are a really useful abstraction for writing interactive systems in a purely functional language. I’ve been playing around with them with the intent of programming robots in Haskell, but they’re also useful for simulations, video games, signal processing, and more. This post will be a short write-up of what I’ve discovered in trying various ways to encode MSF’s. By switching from a recursive, final encoding, to a non-recursive, initial encoding, we can drastically reduce code size and improve execution times. Most importantly, we can allow GHC’s optimizer to work across our combinators.
Background
Conceptually, a stream function SF a b (or signal function, depending
on your mood) takes a stream of a's, and produces a stream of b's.
More precisely, it alternates between first reading an a, then
producing a b, then reading an a, and so on, forever.
In Functional reactive programming, refactored, the authors define
stream functions using a final encoding. That is, they give a
function, step, that runs the stream function for one tick:
step :: SF a b -> a -> (b, SF a b)
This takes the first input a, produces the first output b, and
produces a continuation SF a b, to use on the next tick. They use
this as the definition of the SF data type:
newtype SF a b = SF { step :: a -> (b, SF a b) }
A monadic stream function, or MSF, is like an SF, except we allow
step to have side effects in some monad m:
newtype MSF m a b = MSF { step :: a -> m (b, MSF m a b) }
The monad parameter here is extremely useful for all sorts of things.
My favorite one is setting m to Either e, so that at each step, we
either produce an output and continuation as usual, or we "terminate
early" with an e. That's a topic for another post though, and for
today, we won't care what exactly m is.
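Just to give the flavor, here's a minimal sketch (the name and
behavior are mine, purely for illustration): a stream function that
passes its input through n times, then terminates early with an e.

-- Pass inputs through n times, then end the whole stream with an e.
takeThen :: e -> Int -> MSF (Either e) a a
takeThen e n = MSF $ \a ->
  if n <= 0
    then Left e                        -- "terminate early" with an e
    else Right (a, takeThen e (n - 1))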
Once we have that type definition, we can define some combinators for making MSF’s:
- An instance Arrow (MSF m), for combining computations together.
- A function for lifting from the Kleisli arrow a -> m b into an MSF.
  The paper above calls this arrM, but I'll call it liftK, and give it
  type Kleisli m a b -> MSF m a b. This makes some code below cleaner.
- A function feedback :: c -> MSF m (a, c) (b, c) -> MSF m a b, which
  adds state of type c, taking an initial value for the state and an
  MSF that updates it.
These are all defined in the dunai package, by the same authors.
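For reference, here's roughly how liftK and feedback come out in this
final encoding (a sketch; dunai's arrM takes a plain a -> m b rather
than a Kleisli arrow):

import Control.Arrow (Kleisli (..))

-- Lift a Kleisli arrow into an MSF by running it at every tick.
liftK :: Functor m => Kleisli m a b -> MSF m a b
liftK k = go
  where go = MSF $ \a -> (\b -> (b, go)) <$> runKleisli k a

-- Thread a state c through an MSF, starting from an initial value.
feedback :: Functor m => c -> MSF m (a, c) (b, c) -> MSF m a b
feedback c msf = MSF $ \a ->
  (\((b, c'), msf') -> (b, feedback c' msf')) <$> step msf (a, c)

Note that both of these definitions are recursive. That turns out to
matter a lot.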
The problem
Since our data definition is recursive, all of these combinators have recursive definitions as well. For example, composition is defined as
f . g = MSF $ \x -> do
  (y, g') <- step g x
  (z, f') <- step f y
  pure (z, f' . g')
While recursive definitions are elegant and simple, they are terrible
for performance when used in a DSL like this. This is because GHC
does not inline recursive definitions. Inlining is the main way GHC
discovers optimizations, so any MSF combinator stops the optimizer
dead in its tracks. This is pretty bad, since even trivial
combinators like id need to be defined recursively:
id = MSF $ \x -> pure (x, id)
To see exactly what goes wrong, consider the following code snippet that computes sums of squares:
countS :: SF a Int
countS = feedback 0 $ arr $ \(_, c) -> let c' = c + 1 in (c', c')

sumS :: SF Int Int
sumS = feedback 0 $ arr $ \(x, s) -> let s' = x + s in (s', s')

sumSquaresS :: SF a Int
sumSquaresS = proc _ -> do
  n <- countS -< ()
  s <- sumS -< n * n
  returnA -< s
Using GHC 9.2.4 and -O2, this code alone produces 795 lines of core!
Re-arranging it a bit and removing annotations, we get
sumSquaresS :: forall a. SF a Int
sumSquaresS = sumSquaresS1 `cast` <Co:7>

sumSquaresS1 :: forall {a}. a -> Identity (Int, MSF Identity a Int)
sumSquaresS1
  = \ (@a_abtP) (w3_ibwG :: a_abtP) ->
      $w$c.
        $fApplicativeIdentity
        ($fMonadIdentity_$c>>= `cast` <Co:7>)
        (sumSquaresS5 `cast` <Co:7>)
        sumSquaresS2
        w3_ibwG

sumSquaresS2 :: forall {a}. a -> Identity (((), ()), MSF Identity a ((), ()))
sumSquaresS2
  = \ (@a_abtP) (w2_X3 :: a_abtP) ->
      $w$carr ($fApplicativeIdentity3 `cast` <Co:9>) sumSquaresS3 w2_X3

-- ... snip ...

sumSquaresS14 :: (((), ()), ())
              -> Identity ((Int, ()), MSF Identity (((), ()), ()) (Int, ()))
sumSquaresS14
  = \ (w2_ibw9 :: (((), ()), ())) ->
      case w2_ibw9 of { (ww6_ibwn, ww7_ibwo) ->
      $w$cfirst
        $fApplicativeIdentity
        ($fMonadIdentity_$c>>= `cast` <Co:7>)
        sumSquaresS15
        ww6_ibwn
        ww7_ibwo
      }

-- ... snip ...

sumSquaresS33 :: Int -> Identity (Int, MSF Identity Int Int)
sumSquaresS33
  = \ (w2_X3 :: Int) ->
      $w$carr ($fApplicativeIdentity3 `cast` <Co:9>) id w2_X3
Yes, GHC needed 33 helper functions just to compile that three-line
function. Most of them correspond to single calls to first, arr, and
(.). GHC can't inline them, so it's powerless to do any further
optimization. Ouch!
The fix
To fix this, we'll use an initial encoding, rather than a final
encoding. In other words, instead of defining MSF's in terms of how
we run them (via step), we'll define them in terms of how we
construct them (via arrow combinators, liftK, and feedback),
modulo some rules we know about which MSF's should be equivalent to
each other. Importantly, none of our definitions will be recursive.
The only recursion will be done by the user, once at the top level,
in order to repeatedly call step, as sketched below.
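For instance, a top-level driver for an MSF that consumes and
produces units might look like this sketch (dunai calls its version
of this loop reactimate; step here is the one-tick function we'll
define for the new representation below):

-- The single recursive loop, written once by the user at the top level.
reactimate :: Monad m => MSF m () () -> m ()
reactimate msf = do
  ((), msf') <- step msf ()
  reactimate msf'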
To define Arrow (and Category), we need at least id, (.), arr, and
first, in addition to liftK and feedback. Naively, we could just make
these the data constructors in a GADT:
data MSF m a b where
Id :: MSF m a a
Compose :: MSF m c b -> MSF m a c -> MSF m a b
Arr :: (a -> b) -> MSF m a b
First :: MSF m a b -> MSF m (a, c) (b, c)
LiftK :: Kleisli m a b -> MSF m a b
Feedback :: c -> MSF m (a, c) (b, c) -> MSF m a b
Of course, we can simplify this to fewer constructors if we impose
some laws. I won’t prove them all here, but none of them are very
surprising. To start off, liftK
should be an arrow homomorphism,
meaning it commutes with all the arrow (and category) combinators:
id = liftK id
liftK f . liftK g = liftK (f . g)
arr f = liftK (arr f)
first (liftK f) = liftK (first f)
The arrow combinators on the left are all in MSF m, and the arrow
combinators on the right are all in Kleisli m. From this, we learn
two things:

- We can write id and arr in terms of liftK, so we don't need those
  two constructors, and
- we can move . and first past liftK, where they become . and first
  in Kleisli m. If we can move these two combinators past feedback as
  well, then we can remove them entirely.
We can, in fact, move them past feedback:

feedback c f . g = feedback c (f . first g)
f . feedback c g = feedback c (first f . g)
first (feedback c f) = feedback c (swap23 . first f . swap23)
  where swap23 = arr $ \(x, (y, z)) -> (x, (z, y))
That means every MSF can be written as some number of feedback calls,
wrapped around a call to liftK, i.e.

feedback c1 (feedback c2 (... (feedback cn (liftK f)) ...))
To finish things off, notice that we can always flatten nested calls
to feedback into a single call, by merging the states into pairs:

feedback c (feedback d f) = feedback (d, c) (assoc . f . unassoc)
  where assoc   = arr $ \((x, y), z) -> (x, (y, z))
        unassoc = arr $ \(x, (y, z)) -> ((x, y), z)
This means every MSF can be written as either feedback c (liftK f),
or just liftK f. That leads to the following definition:
data MSF m a b where
FeedbackLiftK :: c -> Kleisli m (a, c) (b, c) -> MSF m a b
LiftK :: Kleisli m a b -> MSF m a b
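With this encoding, step stops being a record field and becomes an
ordinary function, defined by pattern matching. Roughly (a sketch of
mine; strictness is addressed below):

-- Run one tick: apply the Kleisli function, then repack the new state.
step :: Functor m => MSF m a b -> a -> m (b, MSF m a b)
step msf@(LiftK k) a = (\b -> (b, msf)) <$> runKleisli k a
step (FeedbackLiftK c k) a =
  (\(b, c') -> (b, FeedbackLiftK c' k)) <$> runKleisli k (a, c)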
We could go further and shave it down to just FeedbackLiftK, using ()
as the state if we don't need feedback, but that actually leads to
slower code and messier core, since we have to build up and break
down a big tree of ()'s, one for each instance of liftK.
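To see the payoff, here's a sketch (my own rendering, following the
laws above) of how composition might look for this representation.
Every case is non-recursive, and all the real work happens in
Kleisli m:

import Prelude hiding (id, (.))
import Control.Arrow (Kleisli (..), first)
import Control.Category (Category (..))

instance Monad m => Category (MSF m) where
  id = LiftK id

  -- Each case collapses to a single constructor, pairing up states
  -- when both sides carry feedback.
  LiftK f           . LiftK g           = LiftK (f . g)
  FeedbackLiftK c f . LiftK g           = FeedbackLiftK c (f . first g)
  LiftK f           . FeedbackLiftK c g = FeedbackLiftK c (first f . g)
  FeedbackLiftK c f . FeedbackLiftK d g =
    FeedbackLiftK (d, c) $ Kleisli $ \(a, (d0, c0)) -> do
      (x, d1) <- runKleisli g (a, d0)
      (b, c1) <- runKleisli f (x, c0)
      pure (b, (d1, c1))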
The nice thing about this definition is that in retrospect, it feels
obvious. We’re just collecting together the states from all the
feedback calls, and doing the other arrow operations in Kleisli.
This is more or less what we would do if we had to translate our code
manually. It’s theoretically nice and intuitively clear. And it will
perform well, too, once we address this next point.
Well, strictly speaking…
There’s one detail we’ve been glossing over up to this point, and that’s strictness. One of the goals of arrowized FRP is that we can take a principled approach to avoiding space leaks. If we want to achieve that, we’d better sit down and figure out our strictness. Even if we have no space leaks, getting strictness right helps GHC generate more efficient code.
As a general principle, wherever we take state, we should default to
being strict. Actually needing lazy state is pretty rare, and if the
user wants it they can always wrap it in Solo or similar.
What does that mean in practice? Specifically,

- feedback is strict in the initial state, so feedback ⊥ f = ⊥,
- the . and first combinators are strict in both inputs, and
- step never returns (x, ⊥). In other words, either the second
  component is defined, or the whole thing is ⊥.
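Concretely, a smart constructor for feedback can implement the first
rule (and the flattening law from earlier) directly. Here's a sketch
of how that might look, with bang patterns forcing each piece of
state (imports as in the sketches above):

{-# LANGUAGE BangPatterns #-}

-- Strict in the initial state, and flattens nested feedback by
-- pairing the two states.
feedback :: Monad m => c -> MSF m (a, c) (b, c) -> MSF m a b
feedback !c (LiftK k) = FeedbackLiftK c k
feedback !c (FeedbackLiftK d k) =
  FeedbackLiftK (d, c) $ Kleisli $ \(a, (d0, c0)) -> do
    ((b, !c1), !d1) <- runKleisli k ((a, c0), d0)
    pure (b, (d1, c1))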
The point of these rules is that they give a consistent specification
of MSF’s, while ensuring we can force the state to WHNF at nearly
every opportunity. Note that this is a departure from the definition
in the paper above and in dunai.
Performance
With all these changes (and a few INLINE
pragmas), the example above
compiles to this:
sumSquaresS :: forall a. SF a Int
sumSquaresS = \ (@a_a1YZ) -> Feedback sumSquaresS2_r2aV sumSquaresS1_r2aU

sumSquaresS1_r2aU :: forall {a}. (a :*! (Int :!*! Int)) -> Int :*! (Int :!*! Int)
sumSquaresS1_r2aU
  = \ (@a_a1YZ) (x_i24z :: a_a1YZ :*! (Int :!*! Int)) ->
      case x_i24z of { :*! a1_i24r s1_i24s ->
      case s1_i24s of { :!*! s2_s2aK t_s2aL ->
      case t_s2aL of { I# ipv_s2aO ->
      case s2_s2aK of { I# ipv1_s2aR ->
      let {
        dt2_s29u :: Int#
        [LclId]
        dt2_s29u = +# 1# ipv_s2aO } in
      let {
        dt3_s29h :: Int#
        [LclId]
        dt3_s29h = +# (*# dt2_s29u dt2_s29u) ipv1_s2aR } in
      :*! (I# dt3_s29h) (:!*! (I# dt3_s29h) (I# dt2_s29u))
      }
      }
      }
      }

sumSquaresS2_r2aV :: Int :!*! Int
sumSquaresS2_r2aV = :!*! sumS2_r2aS sumS2_r2aS

sumS2_r2aS :: Int
sumS2_r2aS = I# 0#
This unpacks the state in a series of case expressions, does arithmetic with unboxed integers, and then packs the state back up. This is about as good as we can expect the core to get, short of unboxing the state entirely.
Here’s a microbenchmark where we step the above MSF 1,000,000 times, comparing the two approaches:
benchmarking step encoding
time 767.4 ms (679.6 ms .. 876.4 ms)
0.998 R² (0.992 R² .. 1.000 R²)
mean 856.6 ms (816.4 ms .. 922.6 ms)
std dev 63.62 ms (9.070 ms .. 83.67 ms)
variance introduced by outliers: 21% (moderately inflated)
benchmarking feedback/liftK encoding
time 3.973 ms (3.967 ms .. 3.981 ms)
1.000 R² (1.000 R² .. 1.000 R²)
mean 3.990 ms (3.984 ms .. 4.000 ms)
std dev 24.95 μs (18.74 μs .. 34.96 μs)
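For reference, the loop being measured can be as simple as this
sketch (an assumption of mine, taking SF a b to be MSF Identity a b
and step to be the one-tick function above):

import Data.Functor.Identity (Identity (..))

-- Step the stream function n times, forcing each output along the way.
runN :: Int -> SF () Int -> Int
runN n msf =
  let Identity (b, msf') = step msf ()
  in if n <= 1 then b else b `seq` runN (n - 1) msf'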
On the one hand, this is a silly microbenchmark. On the other hand, holy crap it got 200 times faster!
Going further
In a future post, I'll look at the case m = Either e in more detail.
In this case, MSF's are a monad in the exception type e, which is
extremely useful for sequencing computations together. However, with
our definition above, (>>=) still needs to be implemented
recursively. As it turns out, we'll be saved by a free monad.