Monads

A non-mathematical introduction to Monads.

Read Functor first!

In functional codebases, code that talks to databases, performs HTTP requests, or publishes messages to queues typically returns something like IO[A].

Here is a real-world example, slightly modified to protect the guilty, making calls to remote services and queues.

def completeSubscription(tokenString: String, id: UserId): F[Unit] =
  for {
    now      <- Clock[F].instant              // F[Instant]
    token    <- createToken(tokenString, now) // F[Token]
    _        <- activate(now, token, id)      // F[Unit]
    _        <- entitleUser(id, token)        // F[Unit]
    _        <- sendSignupEmail(id, token)    // F[Unit]
  } yield ()

Remember that a for-comprehension in Scala is just syntactic sugar for map and flatMap.

Here is a simpler example:

def sum(x: F[Int], y: F[Int]): F[Int] =
  for {
    xv <- x
    yv <- y
  } yield xv + yv

which desugars to:

def sum(x: F[Int], y: F[Int]): F[Int] =
  x.flatMap(xv => y.map(yv => xv + yv))

Neither sum nor completeSubscription cares about F itself, the so-called effect; they care only about the values it contains. Each generator in the for-comprehension must return an F[A] that provides flatMap and map, and nothing more.
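
As a small illustration of that point, the flatMap/map shape from sum can be written directly against the standard library's Option and List. This is only a sketch: viaOption and viaList are names made up for this example, and it relies on the built-in methods rather than on any Monad instance. The expression is identical in both cases; only the effect differs.

```scala
// Same shape as sum, once per effect, using only the standard library.
val viaOption: Option[Int] =
  Option(1).flatMap(xv => Option(2).map(yv => xv + yv))

// For List, every xv is combined with every yv.
val viaList: List[Int] =
  List(1, 2).flatMap(xv => List(10, 20).map(yv => xv + yv))

// viaOption == Some(3)
// viaList   == List(11, 21, 12, 22)
```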

A Monad, which is also a Functor, is a value that provides map and flatMap for a particular effect, F[_]. Here is sum rewritten to work with any F[_] that has a Monad[F]:

def sum[F[_]](x: F[Int], y: F[Int])(using Monad[F]): F[Int] =
  for {
    xv <- x
    yv <- y
  } yield xv + yv

Writing a Monad

A Monad is a Functor that also provides flatMap for some F[_]:

trait Monad[F[_]] extends Functor[F] {
  extension [A, B](x: F[A])
    def flatMap(f: A => F[B]): F[B]
}

Here are example instances for Option and List:

given Monad[Option] with {
  extension [A, B](m: Option[A])
    override def map(f: A => B): Option[B] = m.map(f)
    override def flatMap(f: A => Option[B]): Option[B] = m.flatMap(f)
}

given Monad[List] with {
  extension [A, B](m: List[A])
    override def map(f: A => B): List[B] = m.map(f)
    override def flatMap(f: A => List[B]): List[B] = m.flatMap(f)
}

Remember that Scala’s List and Option already have flatMap methods of their own, and the instances above simply delegate to them. Don’t confuse those built-in methods with the Monad instances.
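
To see the difference in practice, here is a self-contained sketch that puts the pieces together and calls sum with both instances. The Functor trait is an assumption about what the Functor post defines (an abstract map as an extension method), and someSum, noneSum and listSum are names invented for the example. Inside sum, F is abstract, so map and flatMap can only come from the Monad[F] instance; inside the instances, m.map and m.flatMap are the built-in methods.

```scala
// Assumed shape of Functor from the Functor post; it is not shown in this article.
trait Functor[F[_]] {
  extension [A, B](x: F[A])
    def map(f: A => B): F[B]
}

trait Monad[F[_]] extends Functor[F] {
  extension [A, B](x: F[A])
    def flatMap(f: A => F[B]): F[B]
}

given Monad[Option] with {
  extension [A, B](m: Option[A])
    // m.map and m.flatMap resolve to Option's own built-in methods
    override def map(f: A => B): Option[B] = m.map(f)
    override def flatMap(f: A => Option[B]): Option[B] = m.flatMap(f)
}

given Monad[List] with {
  extension [A, B](m: List[A])
    // m.map and m.flatMap resolve to List's own built-in methods
    override def map(f: A => B): List[B] = m.map(f)
    override def flatMap(f: A => List[B]): List[B] = m.flatMap(f)
}

// F is abstract here, so the for-comprehension uses the Monad[F] extension methods.
def sum[F[_]](x: F[Int], y: F[Int])(using Monad[F]): F[Int] =
  for {
    xv <- x
    yv <- y
  } yield xv + yv

val someSum = sum(Option(1), Option(2))         // Some(3)
val noneSum = sum(Option(1), Option.empty[Int]) // None
val listSum = sum(List(1, 2), List(10, 20))     // List(11, 21, 12, 22)
```

Note that the arguments are built with Option(...) rather than Some(...), so that F is inferred as Option, for which an instance exists.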

Your own types can play too:

final case class Blub[A](v: A)

object Blub {
  given Monad[Blub] with {
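    extension [A, B](blub: Blub[A])
      // map is inherited from Functor and is assumed to be abstract there, so it is defined too
      override def map(f: A => B): Blub[B] = Blub(f(blub.v))
      override def flatMap(f: A => Blub[B]): Blub[B] = f(blub.v)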
  }
}

sum(Blub(1), Blub(2)) // Blub(3)

More

  1. A Monad is also a Functor
  2. Monad has three laws (pure, which lifts a plain value into F, is not part of the trait shown above):
    1. left identity: (Monad[F].pure(x).flatMap(f)) === f(x)
    2. right identity: (m.flatMap(Monad[F].pure(_))) === m
    3. associativity: (m.flatMap(f)).flatMap(g) === m.flatMap(x => f(x).flatMap(g))
  3. Learn about this in pictures (Haskell)
  4. What we talk about when we talk about monads.pdf
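
The three laws listed above can be spot-checked in code. The sketch below is an assumption-heavy illustration, not the trait from this post: MonadWithPure is a hypothetical minimal trait that adds pure, and leftIdentity, rightIdentity and associativity are helper names invented here. Passing for a few sample values does not prove the laws; it only shows what they say.

```scala
// Hypothetical minimal monad with pure; the Monad trait in this post omits pure.
trait MonadWithPure[F[_]] {
  def pure[A](a: A): F[A]
  extension [A, B](x: F[A])
    def flatMap(f: A => F[B]): F[B]
}

given MonadWithPure[Option] with {
  def pure[A](a: A): Option[A] = Some(a)
  extension [A, B](m: Option[A])
    // delegates to Option's built-in flatMap
    def flatMap(f: A => Option[B]): Option[B] = m.flatMap(f)
}

// pure(a).flatMap(f) === f(a)
def leftIdentity[F[_], A, B](a: A, f: A => F[B])(using M: MonadWithPure[F]): Boolean =
  M.pure(a).flatMap(f) == f(a)

// m.flatMap(pure) === m
def rightIdentity[F[_], A](m: F[A])(using M: MonadWithPure[F]): Boolean =
  m.flatMap(a => M.pure(a)) == m

// (m.flatMap(f)).flatMap(g) === m.flatMap(x => f(x).flatMap(g))
def associativity[F[_], A, B, C](m: F[A], f: A => F[B], g: B => F[C])(using MonadWithPure[F]): Boolean =
  m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g))

// Sample functions and values for the spot check.
val f: Int => Option[Int] = n => Some(n + 1)
val g: Int => Option[Int] = n => if (n > 0) Some(n * 2) else None

val lawsHoldForSamples: Boolean =
  leftIdentity(1, f) && rightIdentity(Option(41)) && associativity(Option(41), f, g)
```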