Before we start, I want to use an example to illustrate why monad transformers were introduced and what problem they solve.
Sometimes we have basic monads like `IO`, `Maybe` and `State`, each of which handles its own concern. But what if the value we keep in a `State` monad might not exist? Intuitively, we would write something like this:
```haskell
run :: State (Maybe String) Int
run = do
  s1 <- get
  case s1 of
    Nothing -> ...
    Just x -> do
      put $ Just "abcde"
      s2 <- get
      case s2 of
        Nothing -> ...
        Just x' -> ...
```
The problem is that every time we read the state, we have to pattern match on it. This is pure boilerplate, because the `Maybe` monad already knows how to continue a computation only when there is a value. What we want is a way to combine the powers of different monads, and that is today's topic: monad transformers.
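As a preview, here is the example above rewritten as a sketch with `MaybeT` from the transformers package, stacked over mtl's `State`. The original leaves the `Nothing` branches as `...`. This sketch assumes that a `Nothing` state should simply end the whole computation with `Nothing`. The names `getJust` and `run` are hypothetical.

```haskell
import Control.Monad.State (State, evalState, get, put)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))

-- Wrapping `get` in MaybeT turns a Nothing state into an early exit,
-- so no manual case analysis is needed.
getJust :: MaybeT (State (Maybe String)) String
getJust = MaybeT get

run :: MaybeT (State (Maybe String)) Int
run = do
  x <- getJust               -- stops here if the state is Nothing
  lift (put (Just "abcde"))  -- a plain State action, lifted into MaybeT
  x' <- getJust
  return (length x + length x')

main :: IO ()
main = do
  print (evalState (runMaybeT run) (Just "hi"))  -- prints Just 7
  print (evalState (runMaybeT run) Nothing)      -- prints Nothing
```

The pattern matching has disappeared: the `Maybe` behaviour now lives in the monad stack instead of in every `case`.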
A monad transformer is a type constructor `t` such that `t m` is a `Monad` whenever `m` is, and which is also an instance of the class `MonadTrans`. Let us first look at the two classes.
```haskell
class Applicative m => Monad m where
  return :: a -> m a
  (>>=)  :: m a -> (a -> m b) -> m b

class MonadTrans (t :: (* -> *) -> * -> *) where
  lift :: Monad m => m a -> t m a
```
The kind annotation says that `t` takes a type constructor of kind `* -> *` (the inner monad `m`) and a type of kind `*` (the result `a`), and gives back an ordinary type of kind `*`. The class has only one method, `lift`, which is very interesting. It takes a computation in the inner monad `m` and wraps it in the transformer `t`, so we get a `t m a`.
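To make the types of `lift` concrete, here is a small sketch using `StateT` from the transformers package with `IO` as the inner monad. Here `t = StateT Int` and `m = IO`, so `lift` has type `IO a -> StateT Int IO a`. The name `tick` is hypothetical.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT, execStateT, get, put)

-- Here lift :: IO a -> StateT Int IO a
tick :: StateT Int IO ()
tick = do
  n <- get
  lift (putStrLn ("count = " ++ show n))  -- an IO action embedded in StateT
  put (n + 1)

main :: IO ()
main = do
  final <- execStateT (tick >> tick) 0  -- prints "count = 0", then "count = 1"
  print final                           -- prints 2
```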
A monad transformer applied to a monad is again a monad, which means we can use do notation and still reach the inner monad:
```haskell
run :: Show r => ReaderT r IO ()
run = do
  config <- ask
  lift $ print config
```
This code runs in `ReaderT r IO`. `ask` reads the configuration held by the reader, and `lift` turns the `IO` action `print config` into a `ReaderT` action. Wonderful!
`ReaderT` is a very simple monad transformer. It carries a read-only configuration through the computation and lets you query it with `ask`.
```haskell
import Control.Monad.Trans (MonadTrans, lift)

-- A computation that reads an environment of type r and runs in m
newtype ReaderT r m a = ReaderT {runReaderT :: r -> m a}

instance Functor m => Functor (ReaderT r m) where
  fmap f rt = ReaderT $ fmap f . runReaderT rt

instance (Applicative m) => Applicative (ReaderT r m) where
  -- ignore the environment
  pure x = ReaderT $ const (pure x)
  -- feed the same environment to both sides
  f <*> rt = ReaderT $
    \r -> runReaderT f r <*> runReaderT rt r

instance (Monad m) => Monad (ReaderT r m) where
  return = pure
  -- run rt with r, then run the continuation with the same r
  rt >>= f = ReaderT $
    \r -> runReaderT rt r >>= (flip runReaderT r <$> f)

instance MonadTrans (ReaderT r) where
  -- an inner action simply ignores the environment
  lift m = ReaderT $ const m
```
We also need a way to get our `r` out:
```haskell
-- Hand the environment itself back as the result
ask :: (Monad m) => ReaderT r m r
ask = ReaderT return
```
Sometimes we just want a plain `Reader`, so let us make `Reader` a type alias for `ReaderT` applied to the `Identity` monad.
```haskell
import Control.Monad.Identity (Identity)

type Reader r = ReaderT r Identity
```
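For comparison, the transformers package (re-exported by mtl) defines `Reader` with exactly this alias and provides `runReader` to unwrap it. Below is a short sketch using mtl's `Control.Monad.Reader`. `greeting` is a hypothetical example. The second line unwraps the same value by hand to show what `runReader` does.

```haskell
import Control.Monad.Reader (Reader, ask, runReader, runReaderT)
import Data.Functor.Identity (runIdentity)

-- Reader String is ReaderT String Identity in transformers/mtl too
greeting :: Reader String String
greeting = do
  name <- ask
  return ("Hello, " ++ name)

main :: IO ()
main = do
  putStrLn (runReader greeting "Alice")                 -- prints Hello, Alice
  -- by hand: run the ReaderT, then strip the Identity wrapper
  putStrLn (runIdentity (runReaderT greeting "Alice"))  -- prints Hello, Alice
```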
Now that we have `ReaderT` defined, it is time to experiment with it a little.
```haskell
run :: ReaderT Int (State String) ()
run = do
  lift $ put "Hello"
  config <- ask
  case config of
    0 -> return ()
    n -> lift $ modify (++ replicate n 'a')
```
As you can see, we can read our config anywhere in the do block. The cost is that every `State` operation has to be wrapped in `lift`, which makes our `ReaderT` almost unusable. Is there a way to improve this?
The most obvious solution is to give the lifted `put` a name of its own, one per layer:
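```haskell
-- `lift put` would not typecheck (put is a function), so lift its result
putLifted s = lift (put s)
putLiftedLifted s = lift (putLifted s)
putLiftedLiftedLifted s = lift (putLiftedLifted s)
```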
Quite tedious, right?
Let us check what `put` really is.
```haskell
class Monad m => MonadState s m | m -> s where
  -- | Replace the state inside the monad.
  put :: s -> m ()
  ...
```
So `put` is actually a method of the class `MonadState`, and its result can live in any monad `m` that has a `MonadState s m` instance. To make `put` work directly inside our stack, we make `ReaderT r m` an instance of `MonadState` whenever `m` is one.
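```haskell
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

import Control.Monad.State (MonadState (..), State, modify)

-- Forward get/put to the inner monad through one ReaderT layer
instance MonadState s m => MonadState s (ReaderT r m) where
  get = lift get
  put s = lift $ put s

-- The same program as before, now without any lift
run' :: ReaderT Int (State String) ()
run' = do
  put "Hello"
  config <- ask
  case config of
    0 -> return ()
    n -> modify (++ replicate n 'a')
```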
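To check that the pieces fit together, here is a sketch of a single self-contained file that assembles this post's `ReaderT`, its `MonadState` instance and `run'`. It assumes the mtl package for `State`, `modify` and `execState`, and transformers' `Control.Monad.Trans.Class` for `MonadTrans`. The bind is written with an explicit lambda, which is equivalent to the `flip runReaderT r <$> f` form used above.

```haskell
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
import Control.Monad.State (MonadState (..), State, execState, modify)
import Control.Monad.Trans.Class (MonadTrans (..))

newtype ReaderT r m a = ReaderT {runReaderT :: r -> m a}

instance Functor m => Functor (ReaderT r m) where
  fmap f rt = ReaderT $ fmap f . runReaderT rt

instance Applicative m => Applicative (ReaderT r m) where
  pure x = ReaderT $ const (pure x)
  f <*> rt = ReaderT $ \r -> runReaderT f r <*> runReaderT rt r

instance Monad m => Monad (ReaderT r m) where
  -- same environment r for both steps
  rt >>= f = ReaderT $ \r -> runReaderT rt r >>= \a -> runReaderT (f a) r

instance MonadTrans (ReaderT r) where
  lift m = ReaderT $ const m

-- The instance from the text: get/put pass through one ReaderT layer
instance MonadState s m => MonadState s (ReaderT r m) where
  get = lift get
  put s = lift $ put s

ask :: Monad m => ReaderT r m r
ask = ReaderT return

run' :: ReaderT Int (State String) ()
run' = do
  put "Hello"
  config <- ask
  case config of
    0 -> return ()
    n -> modify (++ replicate n 'a')

main :: IO ()
main = do
  print (execState (runReaderT run' 0) "")  -- prints "Hello"
  print (execState (runReaderT run' 3) "")  -- prints "Helloaaa"
```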
Before we move on to writing a `MonadReader` class for our `ReaderT`, I want to briefly explain what `| m -> s` means in the class definition.
The `m -> s` part is called a functional dependency. It tells GHC that for any given `m` there is at most one `s` for which `MonadState s m` is a valid instance. GHC uses this in two ways. When checking instances, it rejects two instances that share the same `m` but differ in `s`. When inferring types, once it knows `m` it determines `s` from the matching instance instead of treating `s` as an independent unknown. This is a bit like handing GHC a lemma "`m` determines `s`": GHC checks the lemma against your instances and then uses it while type-checking your code. Without the dependency, `s` often cannot be inferred, and GHC simply reports an error saying the type is ambiguous.
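Here is a small sketch of the dependency doing its job, using mtl's `State`. Nothing in `describe` (a hypothetical name) says what type `s` has. Because `MonadState` declares `| m -> s`, GHC reads `s = Int` off the monad `State Int`. Without the dependency, `s` here would only be constrained by `Show`, and the type would be ambiguous.

```haskell
import Control.Monad.State (State, evalState, get)

-- s :: Int is inferred from State Int via the fundep m -> s
describe :: State Int String
describe = do
  s <- get
  return (show s)

main :: IO ()
main = putStrLn (evalState describe 42)  -- prints 42
```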
OK, enough theory; let us code. We need to move our `ask` into a `MonadReader` class, so that someone using `EitherT String (Reader String) ()` does not need to `lift` the `ask` either.
```haskell
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}

class Monad m => MonadReader r m | m -> r where
  -- define at least one of the two; each default uses the other
  {-# MINIMAL ask | reader #-}
  ask :: m r
  ask = reader id
  reader :: (r -> a) -> m a
  reader f = ask >>= \r -> return (f r)

-- Make `ReaderT` a `MonadReader`
instance Monad m => MonadReader r (ReaderT r m) where
  ask = ReaderT return
```
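For any monad transformer we want to stack on top of `ReaderT`, we can simply make it an instance of `MonadReader` whose `ask` lifts the inner `ask`. GHC then picks the right instance for us, so the lifting happens automatically and user code never has to write it.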