"First Class" Types in Scala 3

I have recently been migrating some libraries and projects to Scala 3, and I think it would be very helpful, to me or to anyone interested, to learn some of the new functional programming features that Scala 3 brings us.

Source code 👉 https://github.com/jcouyang/meow

First Class Types

In Idris, types are first class, meaning that they can be computed and manipulated (and passed to functions) just like any other language construct. https://idris2.readthedocs.io/en/latest/tutorial/typesfuns.html#first-class-types

In Scala 3, types can be computed to a certain extent, but they cannot be passed to or returned from functions. It is pretty close, but not first class yet.

Let us compare the two side by side, using the examples from the Idris tutorial.

Type level function

In Idris, types can be the input and output of an ordinary function:

isSingleton : Bool -> Type
isSingleton True = Nat
isSingleton False = List Nat

Here True and False are ordinary Bool values, while Nat and List Nat are types. Yet those types sit exactly where values would sit in an ordinary function: they are the results that isSingleton returns.

In Scala, types cannot be passed into functions, but Scala 3 introduces a new feature called Match Types https://dotty.epfl.ch/docs/reference/new-types/match-types.html .

type IsSingleton[X <: Boolean] = X match {
  case true => Int
  case false => List[Int]
}

IsSingleton is clearly not a function in Scala, but it does much the same job as isSingleton in Idris: it maps a Boolean type to either Int or List[Int].
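To see the reduction in action, here is a small sketch (the value names are only illustrative). With a concrete singleton type as the argument, the compiler reduces IsSingleton to Int or List[Int], and `summon[... =:= ...]` confirms this at compile time:

```scala
// Same match type as above
type IsSingleton[X <: Boolean] = X match {
  case true  => Int
  case false => List[Int]
}

// With a concrete argument the match type reduces, so ordinary values fit
val one: IsSingleton[true] = 42
val many: IsSingleton[false] = List(1, 2, 3)

// Compile-time evidence that the reduction happened
val evTrue = summon[IsSingleton[true] =:= Int]
val evFalse = summon[IsSingleton[false] =:= List[Int]]
```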

Notice that true and false here are not values; they are singleton types, each with exactly one inhabitant:

val singletonBool: true = true
val singletonInt: 1 = 1
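Because a singleton type has exactly one inhabitant, the compiler rejects any other value. Going the other way, `valueOf` from `scala.Predef` recovers the value from the type. A small sketch (names are illustrative):

```scala
// Each singleton type admits exactly one value
val t: true = true
val three: 3 = 3
// val wrong: true = false   // does not compile: false is not of type true

// valueOf (scala.Predef) turns a singleton type back into its value
val fromType: true = valueOf[true]
val n: Int = valueOf[3]
```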

Next let us translate another Idris example into Scala 3:

sum : (single : Bool) -> isSingleton single -> Nat
sum True x = x
sum False [] = 0
sum False (x :: xs) = x + sum False xs

It is super cool that we can call the function isSingleton directly in the type signature. This is exactly where you can feel that, in Idris, types are first class, just like values.

It is a bit verbose in Scala 3:

def sum(single: Boolean, x: IsSingleton[single.type]): Int = (single, x) match {
  case (true, x: IsSingleton[true]) =>  x                // <- (verbose1)
  case (false, Nil) => 0
  case (false, ((x:Int)::(xs: IsSingleton[false]))) => { // <- (verbose2)
    sum(false, xs) + x
  }
}

By using the match type IsSingleton, we can compute the type of x, but we cannot say that types here are first class, since IsSingleton is not a function. Match types can do some computation, but they are not as powerful as functions.
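Here is a self-contained sketch of calling sum; it repeats the definitions from above with the line numbers dropped, and the value names are illustrative. Because the first argument is a literal, the compiler substitutes its singleton type into IsSingleton[single.type], so the first argument decides the type of the second. Depending on the compiler version, the typed patterns may also produce unchecked or exhaustivity warnings.

```scala
type IsSingleton[X <: Boolean] = X match {
  case true  => Int
  case false => List[Int]
}

def sum(single: Boolean, x: IsSingleton[single.type]): Int = (single, x) match {
  case (true, x: IsSingleton[true]) => x
  case (false, Nil) => 0
  case (false, (x: Int) :: (xs: IsSingleton[false])) => sum(false, xs) + x
}

// The literal first argument decides the type of the second one
val fromSingle = sum(true, 5)              // x: IsSingleton[true], i.e. Int
val fromList = sum(false, List(1, 2, 3))   // x: IsSingleton[false], i.e. List[Int]
// sum(true, List(1, 2))                   // rejected: IsSingleton[true] is Int
```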

Apart from the type-level function, a few other places are rather verbose in Scala 3.

Verbose 1:

We have to give the compiler a manual hint that x has type IsSingleton[true]; otherwise compilation fails:

[E]      Found:    (x : DependentTypes2.this.IsSingleton[(single : Boolean)])
[E]      Required: Int

Clearly the Scala compiler is not as good as Idris at proving that x is IsSingleton[true] once single is known to be true.

Verbose 2:

We have to give the compiler manual hints that x has type Int and xs has type IsSingleton[false].

Although a bit clumsy, it still works, and everything is computed correctly at the type level.

Typelevel Ops

I really hope you still remember Vector:

data Nat    = Z   | S Nat 
data Vect : Nat -> Type -> Type where
   Nil  : Vect Z a
   (::) : a -> Vect k a -> Vect (S k) a

Simply translate it to Scala 3:

import scala.compiletime.ops.int.S
enum Vector[Nat, +A] {
  case Nil extends Vector[0, Nothing]
  case Cons[N <: Int, AA](head: AA, tail: Vector[N, AA]) extends Vector[S[N], AA]
}

We don't really need to define Nat in Scala 3, since singleton Int types can play that role, and the type-level successor S from scala.compiletime.ops.int works just like S in Idris.
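A small sketch of what this buys us, assuming the enum above plus an import of its cases (the value names are illustrative). S applied to a literal reduces at compile time, so a two-element vector gets the type Vector[2, Int] and the compiler checks its length:

```scala
import scala.compiletime.ops.int.S

enum Vector[Nat, +A] {
  case Nil extends Vector[0, Nothing]
  case Cons[N <: Int, AA](head: AA, tail: Vector[N, AA]) extends Vector[S[N], AA]
}
import Vector.{Nil, Cons}

// S applied to a literal reduces at compile time: S[S[0]] is 2
val two: S[S[0]] = 2

// Two Cons cells on Nil give length S[S[0]], i.e. 2
val pair = Cons(1, Cons(2, Nil))
val checked: Vector[2, Int] = pair
// val wrong: Vector[3, Int] = pair   // does not compile: the length is 2, not 3
```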

The real challenge here is to compute the length at the type level, which is very easy in Idris since types can be passed into functions.

Since the length is a Nat, we can simply define a + function for Nat:

(+) : Nat -> Nat -> Nat
(+) Z     y = y
(+) (S k) y = S (k + y)
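For comparison, Idris' (+) on Nat can be mirrored quite closely in Scala 3 with a recursive match type over a hand-rolled Peano encoding. This is a sketch: Peano, Z, Succ and Plus are hypothetical names made up for this example, not library types.

```scala
// Hypothetical Peano naturals, mirroring Idris' Z and S
sealed trait Peano
final class Z extends Peano
final class Succ[N] extends Peano

// Hypothetical Plus: recurse on the first argument, like Idris' (+)
type Plus[A, B] = A match {
  case Z       => B
  case Succ[k] => Succ[Plus[k, B]]
}

type One = Succ[Z]
type Two = Succ[Succ[Z]]
type Three = Succ[Succ[Succ[Z]]]

// Compiles only if Plus[One, Two] reduces to Three
val onePlusTwo = summon[Plus[One, Two] =:= Three]
```

The difference from Idris is that Plus lives only at the type level; it cannot be applied to runtime values.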

Combining two Vects while computing the length at the type level is then just:

(++) : Vect n a -> Vect m a -> Vect (n + m) a
(++) Nil       ys = ys
(++) (x :: xs) ys = x :: xs ++ ys

We just did it again: calling a function (n + m) in a type signature.

Ideally, we should be able to translate it to Scala 3 like this:

import scala.compiletime.ops.int.{+}
import Vector.{Nil, Cons}
def combine[N <: Int, M <: Int, A](a: Vector[N, A], b: Vector[M, A]): Vector[N + M, A] =
  (a, b) match
    case (Nil, b) => b
    case (Cons(head, tail), b) => Cons(head, combine(tail, b))

The type-level operator + here comes from scala.compiletime.ops.int and computes with singleton Int types at compile time (https://dotty.epfl.ch/docs/reference/metaprogramming/inline.html#the-scalacompiletimeops-package). For instance:

import scala.compiletime.ops.int.*
val x: 1 + 2 * 3 = 7
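A few more checks in the same spirit, as a sketch that needs only the scala.compiletime.ops.int import (value names are illustrative). The type operators follow the usual term-level precedence, parentheses regroup as expected, and `=:=` can confirm that two computed types are the same:

```scala
import scala.compiletime.ops.int.*

// * binds tighter than +, just like at the term level
val a: 2 + 3 * 4 = 14
// Parentheses regroup the computation
val b: (2 + 3) * 4 = 20
val c: 10 - 4 = 6
// S is the successor: S[6] reduces to 7
val d: S[6] = 7

// =:= confirms that two computed types are the same type
val ev = summon[(2 + 3) =:= 5]
// val bad: 2 + 2 = 5   // does not compile: 2 + 2 reduces to 4
```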

But the real world is cruel…

[E]      Found:    (b : DependentTypes2.this.Vector[M, A])
[E]      Required: DependentTypes2.this.Vector[N + M, A]
[E]      L61:     case (Nil, b) => b
[E]      Found:    DependentTypes2.this.Vector[scala.compiletime.S[Int + M], Any]
[E]      Required: DependentTypes2.this.Vector[N + M, A]
[E]      L62:     case Cons(head, tail) => Cons(head, combine(tail, b))

Scala cannot infer that in the first case the return type should be Vector[0 + M, A]: matching a against Nil, whose type is Vector[0, Nothing], does not tell the compiler that N is 0.

Nor can the compiler prove that N - 1 + M + 1 equals N + M.
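The contrast is easy to see with summon, in a small sketch (names are illustrative). When every argument is a literal, both sides simply evaluate to the same number; with abstract N and M nothing reduces and no arithmetic laws are applied, which matches the errors above.

```scala
import scala.compiletime.ops.int.*

// Literals: S[2 - 1 + 3] evaluates to S[4], i.e. 5, and 2 + 3 evaluates to 5
val five: S[2 - 1 + 3] = 5
val concrete = summon[S[2 - 1 + 3] =:= (2 + 3)]

// Abstract N and M: the operators stay unevaluated and are not rewritten
// algebraically, so inside a method with type parameters N <: Int and M <: Int
// the following would be rejected:
//   summon[S[N - 1 + M] =:= (N + M)]
```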

I wouldn't suggest doing this in production, but we need to convince Scala that the type is correct, which here means casting with asInstanceOf:

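import scala.compiletime.ops.int.{+, -}
import Vector.{Nil, Cons}
def combine[N <: Int, M <: Int, A](a: Vector[N, A], b: Vector[M, A]): Vector[N + M, A] =
  (a, b) match
    // a is Nil, so N is 0; assert that Vector[M, A] is Vector[N + M, A]
    case (Nil, b) => b.asInstanceOf[Vector[N + M, A]]
    // tail has length N - 1; assert that S[N - 1 + M] is N + M
    case (Cons(head: A, tail: Vector[N - 1, A]), b) =>
      Cons[N - 1 + M, A](head, combine(tail, b)).asInstanceOf[Vector[N + M, A]]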
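A self-contained sketch that puts the pieces together; it repeats the enum and the cast-based combine, and the value names are illustrative. The casts mean the compiler no longer checks the lengths inside combine, but at the call site the result type still computes 1 + 2 = 3. Expect unchecked warnings for the type tests on A and Vector[N - 1, A].

```scala
import scala.compiletime.ops.int.{+, -, S}

enum Vector[Nat, +A] {
  case Nil extends Vector[0, Nothing]
  case Cons[N <: Int, AA](head: AA, tail: Vector[N, AA]) extends Vector[S[N], AA]
}
import Vector.{Nil, Cons}

def combine[N <: Int, M <: Int, A](a: Vector[N, A], b: Vector[M, A]): Vector[N + M, A] =
  (a, b) match
    case (Nil, b) => b.asInstanceOf[Vector[N + M, A]]
    case (Cons(head: A, tail: Vector[N - 1, A]), b) =>
      Cons[N - 1 + M, A](head, combine(tail, b)).asInstanceOf[Vector[N + M, A]]

val joined = combine(Cons(1, Nil), Cons(2, Cons(3, Nil)))
// The length is computed at compile time: 1 + 2 reduces to 3
val sized: Vector[3, Int] = joined
// val tooLong: Vector[4, Int] = joined   // does not compile
```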

Overall, I think it is great that Scala 3 provides Match Types and compiletime.ops, which make it possible to compute types just like values without Shapeless, although types are not yet first class.