末尾再帰

Leanの do 記法によって forwhile などの伝統的な繰り返し構文が使えるようになりますが、裏ではこれらの構文は再帰関数の呼び出しに変換されています。ほとんどのプログラミング言語では、再帰関数はループに対して重要な欠点を持っています:ループはスタック上の領域を消費しませんが、再帰関数は再帰呼び出しの数に比例してスタック領域を消費します。スタック領域は一般的に限られているため、再帰関数として自然に表現されるアルゴリズムを、明示的に可変で割り当てられたヒープ領域でのループに書き直す必要がしばしばあります。

関数型プログラミングではその逆が一般的です。可変状態を持つループとして自然に表現されるプログラムはスタック領域を消費するかもしれませんが、再帰関数に書き換えれば高速に実行できます。これは関数型プログラミング言語の重要な側面によるものです:すなわち 末尾呼び出しの除去 (tail-call elimination)です。末尾呼び出しとはある関数から別の関数への呼び出しの中でも、呼び出し時に新しいスタックフレームをプッシュするのではなく、現在のスタックフレームに置き換えることで通常のジャンプにコンパイルできるものを指します。

末尾呼び出しの除去は単なるオプショナルな最適化ではありません。その存在は効率的な関数型コードを書くための基礎的な部分です。この機能が有用であるためには、信頼できるものでなければなりません。プログラマは確実に末尾呼び出しを特定できなければならず、コンパイラによる末尾呼び出しの除去が信頼できなければなりません。

関数 NonTail.sumNat のリストの内容を加算します:

def NonTail.sum : List Nat → Nat
  | [] => 0
  | x :: xs => x + sum xs

この関数をリスト [1, 2, 3] に適用すると、次のような評価のステップの流れになります:

NonTail.sum [1, 2, 3]
===>
1 + (NonTail.sum [2, 3])
===>
1 + (2 + (NonTail.sum [3]))
===>
1 + (2 + (3 + (NonTail.sum [])))
===>
1 + (2 + (3 + 0))
===>
1 + (2 + 3)
===>
1 + 5
===>
6

この評価ステップにおいて、括弧は NonTail.sum の再帰呼び出しを示しています。言い換えると、3つの数値を足すには、このプログラムは最初にリストが空でないことをチェックしなければなりません。リストの先頭(1)とリストの後続の和を足すには、まずリストの後続の和を計算する必要があります:

1 + (NonTail.sum [2, 3])

しかし、リストの後続の和を計算するためには、プログラムはそれが空かどうかをチェックしなければなりません。そしてこれは空ではなく、後続のリストの先頭は 2 です。上記の結果のステップでは NonTail.sum [3] の結果が返ることを待ちます:

1 + (2 + (NonTail.sum [3]))

実行時の呼び出しのスタックの要点は、値 123 とそれらを再帰呼び出しの結果に加算する命令を追跡することです。再帰呼び出しが完了すると、制御が呼び出しを行ったスタックフレームに戻り、加算の各ステップが実行されます。リストの先頭とそれらの加算の命令を格納した領域は解放されません;これはリストの長さに比例した領域を占めます。

関数 Tail.sumNat のリストの内容を加算します:

def Tail.sumHelper (soFar : Nat) : List Nat → Nat
  | [] => soFar
  | x :: xs => sumHelper (x + soFar) xs

def Tail.sum (xs : List Nat) : Nat :=
  Tail.sumHelper 0 xs

これをリスト [1, 2, 3] に適用すると、次のような評価の流れになります:

Tail.sum [1, 2, 3]
===>
Tail.sumHelper 0 [1, 2, 3]
===>
Tail.sumHelper (0 + 1) [2, 3]
===>
Tail.sumHelper 1 [2, 3]
===>
Tail.sumHelper (1 + 2) [3]
===>
Tail.sumHelper 3 [3]
===>
Tail.sumHelper (3 + 3) []
===>
Tail.sumHelper 6 []
===>
6

内部の補助関数は自分自身を再帰的に呼び出しますが、最終的な結果を計算するために何も覚えておく必要はありません。Tail.sumHelper の中間呼び出しは再帰呼び出しの結果をそのまま返すだけであるため、Tail.sumHelper が基本ケースに到達すると制御を直接 Tail.sum に戻すことができます。つまり、Tail.sumHelper を再帰的に呼び出すたびに1つのスタックフレームを再利用することができます。末尾呼び出しの除去とはまさにこのスタックフレームの再利用のことであり、Tail.sumHelper末尾再帰関数 と呼ばれます。

Tail.sumHelper の最初の引数にはコールスタックで追跡すべき情報がすべて含まれています。すなわち、遭遇してきた数値をすべて加算した値です。各再帰呼び出しでは、コールスタックに新しい情報を追加するのではなく、この引数が新しい情報で更新されます。コールスタックの情報を置き換える soFar のような引数は アキュムレータ (accumulator)と呼ばれます。

この記事を書いている時点と筆者のコンピュータでは、216,856以上の要素を持つリストを渡すと、NonTail.sum はスタックオーバーフローでクラッシュします。一方で、Tail.sum は100,000,000個の要素を持つリストでもスタックオーバーフローを起こすことなく合計を計算することができます。Tail.sum の実行中に新しいスタックフレームをプッシュする必要がないため、現在のリストを保持する可変な変数を持つ while ループと完全に等価です。再帰呼び出しの度に、スタック上の関数の引数はリストの次のノードに置き換えられます。

末尾位置と非末尾位置

Tail.sumHelper が末尾再帰である理由は、再帰呼び出しが 末尾の位置 にあるからです。非形式的には、関数呼び出しが末尾の位置にあるのは、呼び出し元が戻り値を変更する必要がなく、そのまま返す場合です。より形式的には、末尾位置は式に対して明示的に定義することができます。

もし match 式が末尾の位置であれば、その各ブランチも末尾の位置となります。一度 match がブランチを選択すると、制御はすぐにそのブランチに進みます。同様に、if 式が末尾にある場合、if 式の両方のブランチも末尾の位置となります。最後に、let 式が末尾にある場合は、その本体も末尾となります。

これ以外の他のすべての位置は末尾の位置とはなりません。関数やコンストラクタの引数は末尾の位置ではありません。なぜなら、評価は引数の値に適用される関数やコンストラクタを追跡しなければならないからです。内部関数の本体も末尾の位置ではありません。というのも制御が渡されないかもしれないからです:すなわち、関数の本体は関数が呼び出されるまで評価されないからです。同様に、関数型の本体も末尾の位置ではありません。E in (x : α) → E を評価するためには、結果として得られる型が (x : α) → ... で囲まれていることを追跡する必要があります。

NonTail.sum では、再帰呼び出しは + の引数であるため末尾の位置ではありません。Tail.sumHelper ではパターンマッチ(これ自体が関数の本体である)の直下であるため、この再帰呼び出しは末尾の位置です。

この記事を書いている時点では、Leanは再帰関数内の直接の末尾呼び出しのみを除去します。つまり、ある関数 f の定義における f への末尾呼び出しは除去されますが、他の関数 g への末尾呼び出しは除去されません。スタックフレームを節約して他の関数への末尾呼び出しを除去することは可能ですが、Leanではまだ実装されていません。

リストの反転

関数 NonTail.reverse は各サブリストの先頭を結果の末尾に追加することでリストを反転させます:

def NonTail.reverse : List α → List α
  | [] => []
  | x :: xs => reverse xs ++ [x]

これを [1, 2, 3] に使用して反転させると、次のような評価の流れになります:

NonTail.reverse [1, 2, 3]
===>
(NonTail.reverse [2, 3]) ++ [1]
===>
((NonTail.reverse [3]) ++ [2]) ++ [1]
===>
(((NonTail.reverse []) ++ [3]) ++ [2]) ++ [1]
===>
(([] ++ [3]) ++ [2]) ++ [1]
===>
([3] ++ [2]) ++ [1]
===>
[3, 2] ++ [1]
===>
[3, 2, 1]

末尾再帰版では各ステップでのアキュムレータに対して · ++ [x] の代わりに x :: · を用います:

def Tail.reverseHelper (soFar : List α) : List α → List α
  | [] => soFar
  | x :: xs => reverseHelper (x :: soFar) xs

def Tail.reverse (xs : List α) : List α :=
  Tail.reverseHelper [] xs

これは NonTail.reverse で計算している間に各スタックフレームに保存されたコンテキストが基本ケースに至って初めて適用されるためです。コンテキストの各「記憶された」断片は、後入れ先出しの順に実行されます。一方で、アキュムレータ渡し版では、以下の簡約ステップの流れからわかるように、元の基本ケースではなくリストの最初の要素からアキュムレータを更新します:

Tail.reverse [1, 2, 3]
===>
Tail.reverseHelper [] [1, 2, 3]
===>
Tail.reverseHelper [1] [2, 3]
===>
Tail.reverseHelper [2, 1] [3]
===>
Tail.reverseHelper [3, 2, 1] []
===>
[3, 2, 1]

つまり、非末尾再帰版は基本ケースから開始し、対象のリストを右から左に走査して再帰の結果を更新します。これらのリストの要素は先入先出の順序でアキュムレータに影響を与えます。アキュムレータを使用する末尾再帰版はリストの先頭から開始し、リストを左から右へ走査し、アキュムレータの初期値を更新します。

加算は可換であるため、Tail.sum ではこの点を考慮する必要はありません。リストの結合は可換ではないため、逆方向に実行しても同じ効果を持つ演算を見つけるように注意しなければなりません。NonTail.reverse の再帰の結果の後に [x] を追加することは、反転されたリストの先頭に x を追加することに似ています。

複数の再帰呼び出し

BinTree.mirror の定義では、2つの再帰呼び出しがあります:

def BinTree.mirror : BinTree α → BinTree α
  | .leaf => .leaf
  | .branch l x r => .branch (mirror r) x (mirror l)

命令型言語が reversesum のような関数にwhileループを使うように、この種の走査には再帰関数を使うことが一般的です。この関数はアキュムレータを渡すスタイルを使って末尾再帰的に書き換えることは簡単にはできません。

通常、各再帰ステップに1回より多い再帰呼び出しが必要な場合、アキュムレータを渡すスタイルを使用することは難しいです。この難しさは、再帰関数をループと明示的なデータ構造を使用するように書き換える難しさと似ており、さらに関数が終了することをLeanに納得させる複雑さも備わっています。しかし、BinTree.mirror のように複数の再帰呼び出しは、それ自体が複数回再帰的に出現するコンストラクタを持つデータ構造を示すことが多いです。このような場合、構造体の深さは全体のサイズに対して対数になることが多く、スタックとヒープのトレードオフが小さくなります。これらの関数を末尾再帰にするための体系的なテクニックとして 継続渡しスタイル (continuation-passing style)を使用するなどの方法がありますが、この章の範囲外です。

演習問題

以下の非末尾再帰関数をそれぞれアキュムレータを渡すスタイルの末尾再帰関数に変換してください:

def NonTail.length : List α → Nat
  | [] => 0
  | _ :: xs => NonTail.length xs + 1 
def NonTail.factorial : Nat → Nat
  | 0 => 1
  | n + 1 => factorial n * (n + 1)

NonTail.filter の変換では、末尾再帰によって一定のスタック領域と入力リストの長さに線形な時間を要するプログラムになるはずです。オリジナルに対して一定のオーバーヘッドは許容されます:

def NonTail.filter (p : α → Bool) : List α → List α
  | [] => []
  | x :: xs =>
    if p x then
      x :: filter p xs
    else
      filter p xs