続・Javaで継続モナド

少し使い方が分かったら書いてみる。
継続モナドCpsクラスは前回のものベースにする。


この辺がとても参考になった。ほぼ写経……*1

breakとcontinue

for文の中でbreakを使うとループから抜ける。ということでbreakは大域脱出なのでcallCCで書ける。
例として、Integerのリストを順番に見て行きつつプリントし、負の数が来たらループを抜けるコードを書いてみる。

    // とりあえずrunするためのダミー関数
    Function<Void, Void> ignore = new Function<Void, Void>() {
        public Void apply(Void value) {
            return null;
        }
    };

    // ループの外側でcallCC()
    Cps.callCC(new Function<Function<Void, Cps<Void, Void>>, Cps<Void, Void>>(){
        public Cps<Void, Void> apply(final Function<Void, Cps<Void, Void>> break_) {
            // break_に大域脱出用の関数が渡ってくる

            // 要素ごと処理用の関数をbindしていくことでループを表現
            // 初期値はダミー用のCps
            Cps<Void, Void> current = Cps.unit((Void) null);
            for (final Integer n : Arrays.asList(1, 2, 3, -1, 3, 0)) {
                current = current.bind(
                    new Function<Void, Cps<Void, Void>>(){
                        public Cps<Void, Void> apply(Void value) {
                            if (n < 0) {
                                // 負の数が来たらbreak_を呼んで結果を戻す
                                return break_.apply(null);
                            } else {
                                // それ以外なら数値をプリントして続行
                                System.out.println(n);
                                return Cps.unit((Void) null);
                            }
                        }
                    }
                );
            }
            return current;
        }
    }).run(ignore);

実行結果

1
2
3

大域脱出だけなら前回も書いたので、次は、continueを実装してみる。
合成が簡単になるようにcombine()(Haskellでいう>>)を定義しておく。

public class Cps<T, R> {
...
    // 今の値は無視して引数のcを戻すCpsを戻す。
    public <S> Cps<S, R> combine(final Cps<S, R> c) {
        return bind(new Function<T, Cps<S, R>>(){
            public Cps<S, R> apply(T dummy) {
                return c;
            }});
    }

continueは「ループの次の項目に進む処理」なので、ループの内側でcallCC()を呼べばよい。
折角なのでクラスにして中身を変えられるようにし、戻り値を返せるようにしてみる。

// リストを渡すと中身を順に処理するクラス。(break & continue付き)
// <T>はリストの中身の型、<R>脱出時の戻り値の型
public abstract class ForeachHandler<T, R> {
    // 処理の中身(具体クラスで実装)
    protected abstract Cps<R, R> handle(
            T value, Function<R, Cps<R, R>> break_, Function<R, Cps<R, R>> continue_);

    // リストを渡すと中身を順に処理する
    public R foreach(final List<T> list) {

        // runに渡して戻り値を取り出す為の関数を用意
        Function<R, R> id = new Function<R, R>() {
            public R apply(R value) {
                return value;
            }
        };

        // まず外側のcallCC。break用。
        return Cps.callCC(new Function<Function<R, Cps<R, R>>, Cps<R, R>>(){
            public Cps<R, R> apply(final Function<R, Cps<R, R>> break_) {
                Cps<R, R> current = Cps.unit((R) null);
                for (final T item : list) {
                    current = current.combine(
                        // 内側のcallCC。continue用。
                        Cps.callCC(new Function<Function<R, Cps<R, R>>, Cps<R, R>>(){
                            public Cps<R, R> apply(Function<R, Cps<R, R>> continue_) {
                                // 要素とbreak用関数とcontinue用関数を渡して処理呼び出し
                                return handle(item, break_, continue_);
                            }
                        })
                    );
                }
                return current;
            }
        }).run(id);
    }
}

使う側。Integerのリストを順番に見て行きつつプリントし、0が来たらスキップ、負の数が来たらループを抜けるサンプル。

    String message = new ForeachHandler<Integer, String>(){
        protected Cps<String, String> handle(
                Integer value,
                Function<String, Cps<String, String>> break_,
                Function<String, Cps<String, String>> continue_) {

            if (value < 0) {
                return break_.apply("Error. Negative number:" + value);
            } else if (value == 0) {
                return continue_.apply("last item has been skipped.");
            }
            System.out.println("n = " + value);
            return Cps.unit("finished.");
        }

    }.foreach(Arrays.asList(1, 0, 3, 2, 3, 0, -1, 1, 2));
    System.out.println("message = " + message);
n = 1
n = 3
n = 2
n = 3
message = Error. Negative number:-1

外側はともかく、内側のcallCC()で作った関数をつかって、内側の処理が続行できるのが少し不思議な気もする。
内側のcallCC()の戻り値を外側につないで、全体がCpsのチェーンのようになる構造にするのがポイントになっているようだ。

Generatorの実装

PythonなどにはGeneratorといって外部イテレータを簡単に作れる仕組みがあるようだ。
以下はWikipediaに載ってる、呼ぶ度に数値をインクリメントして返すGeneratorの例:

def countfrom(n):
    while True:
        yield n
        n += 1

countfrom()自体は数値ではなくgeneratorオブジェクトを返し、そのgeneratorオブジェクトに対してnext()を呼ぶことで実際の処理をおこなう。処理内でyieldを呼ぶとそこで処理中断してyieldに渡した値を戻し、その後再度next()を呼ぶことでyieldで中断したところの続きが実行される。


継続モナドを使ってそれを真似してみる。
abstractなクラスGeneratorクラスを定義し、実クラスでメソッドprocess()で個々の処理を記述したCpsを戻す。
Generatorにはyield()というメソッドが定義されているので、中断したい箇所で値と再開時に使用する値を渡す。
以下は呼ばれる度に数を1ずつ増やしていく例。但し最初はループは難しいので用意しているのは三回分だけ。

        final int count = 10;
        Generator<Integer, Integer> countup = new Generator<Integer, Integer>() {
            @Override
            protected Cps<Integer, Integer> process() {
                return Cps.<Integer, Integer>unit(count)
                    .bind(new Function<Integer, Cps<Integer, Integer>>() {
                        public Cps<Integer, Integer> apply(Integer n) {
                            // nを中断時に戻し、再開時にn+1を使用する。
                            return yield(n, unit(n + 1));
                        }
                    })
                    .bind(new Function<Integer, Cps<Integer, Integer>>() {
                        public Cps<Integer, Integer> apply(Integer n) {
                            return yield(n, unit(n + 1));
                        }
                    })
                    .bind(new Function<Integer, Cps<Integer, Integer>>() {
                        public Cps<Integer, Integer> apply(Integer n) {
                            return yield(n, unit(n + 1));
                        }
                    });
            }
        };
        System.out.println(countup.next());
        System.out.println(countup.next());
        System.out.println(countup.next());

出力

10
11
12


ではGeneratorのソースコード

import java.util.LinkedList;
import java.util.Queue;

// @param <T> 処理中に持ちまわす値の型
// @param <R> 中断時に戻す値の型
public abstract class Generator<T, R> {
    // 処理本体は実クラスで実装
    protected abstract Cps<T, R> process();

    // 中断用の関数を保持するフィールド
    private Function<R, Cps<T, R>> suspend;

    // next()は初回実行時と再開時で処理の内容が変わる。
    public final R next() {
        return (suspend == null) ? start() : resume();
    }

    // next()初回用の処理
    // Cps.callCC()を呼んで中断用の関数作成した後、process()で処理を取得し、runする。
    private R start() {
        // runに渡して戻り値を取り出す為の関数を作っておく
        Function<R, R> id = new Function<R, R>() {
            public R apply(R value) {
                return value;
            }
        };

        return Cps.callCC(new Function<Function<R, Cps<T, R>>, Cps<R, R>>(){
            public Cps<R, R> apply(final Function<R, Cps<T, R>> suspend) {
                // 中断用関数をフィールドに保存。
                Generator.this.suspend = suspend;
                // 処理本体を記述したCpsを取得。
                return process().bind(new Function<T, Cps<R, R>>() {
                        public Cps<R, R> apply(T value) {
                            return Cps.unit((R) null); // 最終結果は捨ててnullを戻している。
                        }
                    });
            }
        }).run(id);
    }

    // 中断再開時の処理を保持するQueue。実際は高々長さ1。
    private final Queue<Cps<T, R>> taskQueue = new LinkedList<Cps<T, R>>();

    // 中断のためのyieldメソッドの定義
    // @param value 中断時に戻す値
    // @param next 中断再開時に使用するCps
    protected final Cps<T, R> yield(final R value, final Cps<T, R> next) {
        // 戻ってくるための継続を生成
        return Cps.callCC(new Function<Function<T,Cps<T,R>>, Cps<T,R>>() {
            public Cps<T, R> apply(final Function<T, Cps<T, R>> f) {
                // 再開時に実行する処理を保存
                taskQueue.offer(next.bind(f));
                // 中断時用の関数に値を渡して一旦中断
                return suspend.apply(value);
            }
        });
    }
    // next()二回目以降の処理
    private R resume() {
        if (taskQueue.isEmpty()) {
            throw new IllegalStateException("Not suspended or no more task.");
        }
        // run()に渡して処理を再開するためのダミー関数
        Function<T, R> dummy = new Function<T, R>() {
            public R apply(T value) {
                return null; // 呼ばれないはず
            }
        };
        // 処理を再開
        return taskQueue.poll().run(dummy);
    }

    // 以下便利メソッド

    // Cps.unit()を呼んで戻す
    protected final Cps<T, R> unit(T value) {
        return Cps.unit(value);
    }
    // 継続関数に関係なく値valueを戻すCpsを戻す。
    // 最終値の生成に使用
    protected final Cps<T, R> end(final R value) {
        return new Cps<T,R>(new Function<Function<T, R>, R>(){
            public R apply(Function<T, R> k) {
                return value;
            }});
    }
}

process()の戻り値のチェックとかした方が良いけど今回は省略。

Generatorの使用例

リストから外部イテレータを生成してみる。

    final List<Integer> numbers = Arrays.asList(1, 2, 3, -1, 3, 0);

    Generator<Void, Integer> iter = new Generator<Void, Integer>() {
        @Override
        protected Cps<Void, Integer> process() {
            Cps<Void, Integer> current = unit(null);
            for (final Integer n : numbers) {
                current = current.bind(new Function<Void, Cps<Void,Integer>>() {
                    public Cps<Void, Integer> apply(Void dummy) {
                       return yield(n, unit(null));
                    }
                });
            }
            return current.combine(end(Integer.MIN_VALUE));
        }
    };
    Integer value = null;
    while ((value = iter.next()) != Integer.MIN_VALUE) {
        System.out.println("value = " + value);
    }

終了条件はnullでも良いけどInteger.MIN_VALUEを使ってみた。
next()を呼ぶ前に値が取れるかを確認できれば良いのだけどこの実装ではムリ。


ファイルから一行ずつ読み込んで行番号を付ける処理の例。
while文のような処理を作るのは難しいので、yield時にbindし足すようにする。

    private void printWithLineNumber(final InputStream in) {
        Generator<Void, String> generator = new Generator<Void, String>() {
            @Override
            protected Cps<Void, String> process() {
                final BufferedReader reader =
                    new BufferedReader(new InputStreamReader(in));

                return
                    unit(null)
                    .bind(new Function<Void, Cps<Void, String>>() {
                         public Cps<Void, String> apply(Void dummy) {
                            try {
                                String line = reader.readLine();
                                if (line != null) {
                                    // yield時にbindし足す。thisはFunction自体。
                                    return yield(line, unit(null).bind(this));
                                }
                            } catch (IOException e) {
                                 e.printStackTrace();
                            }
                            return end(null);
                        }                        
                    })
                    .combine(end(null)); // 終了判定はnullでおこなう
            }
        };
        int c = 0;
        String line = null;
        while ((line = generator.next()) != null) {
            System.out.printf("%6d: %s\n", ++c, line);
        }
    }

リストのIteratorもBufferedReaderも元々外部イテレータみたいなものなのでインパクト薄い……。
良いサンプルを思いついたら足す。

Generatorの使用例(2) 無限ループ

Wikipediaの例を実装してみた。

    public Generator<Integer, Integer> countfrom(final int base) {
        return new Generator<Integer, Integer>() {
            @Override
            protected Cps<Integer, Integer> process() {
                return unit(base)
                .bind(new Function<Integer, Cps<Integer,Integer>>() {
                    public Cps<Integer, Integer> apply(Integer n) {
                       return yield(n, unit(n + 1).bind(this));
                    }
                });
            }
        };
    }
    @Test
    public void testIterator() {
        Integer n = null;
        Generator<Integer, Integer> iter = countfrom(10);

        while ((n = iter.next()) <= 20) {
            System.out.println(n);
        }
 
        assertEquals(Integer.valueOf(21), n);
    }

特に終了条件無く、呼び出した分だけ何回でも使える。


もう一つの例。必要に応じて素数をいくらでも作成する

    public Generator<Integer, Integer> primes() {
        final int n = 2;
        final List<Integer> p = new ArrayList<Integer>();
        return new Generator<Integer, Integer>() {
            @Override
            protected Cps<Integer, Integer> process() {
                return unit(n)
                .bind(new Function<Integer, Cps<Integer,Integer>>() {
                    public Cps<Integer, Integer> apply(Integer n) {
                        for(int f : p) {
                           if (n % f == 0) {
                               return unit(n + 1).bind(this);
                            }
                        }
                        p.add(n);
                        return yield(n, unit(n + 1).bind(this));
                     }
                });
            }
        };
    }
    @Test
    public void testPrimes() {
        Generator<Integer, Integer> f = primes();

        System.out.println(f.next());
        System.out.println(f.next());
        System.out.println(f.next());
        System.out.println(f.next());
        System.out.println(f.next());
        System.out.println(f.next());
        
        assertEquals(Integer.valueOf(17), f.next());
    }

実行例

2
3
5
7
11
13

Generator書き直し

  • 初期値を渡せるように変更(unit()したものがprocess()に引数で渡ってくる。)
  • コンストラクタで初回用の継続処理を作成するようにし、start()とresume()を統合
package sample.cont;

import java.util.LinkedList;
import java.util.Queue;

/**
 * Cpsクラスを使い、外部イテレータを生成するクラス。
 * @param <T> 処理内部の中間的な型
 * @param <R> 中断時に戻す値の型
 */
public abstract class Generator<T, R> {
    /** 中断用の関数 */
    private Function<R, Cps<T, R>> suspend;
    /** 中断時の継続処理を保持するQueue。高々長さ1。 */
    private final Queue<Cps<R, R>> taskQueue = new LinkedList<Cps<R, R>>();

    /**
     * 実クラスで処理本体を記述したCpsを生成して戻すこと。
     * 引数としてコンストラクタに渡した初期値をくるんだCps<T, R>を渡すので使っても良い。
     * 戻り値は、Cps<T, R>オブジェクトまたはnullを戻すこと。
     * nullを戻した場合、対応するnext()呼び出しの戻り値にはnullを戻す。
     * @param initial 初期値
     * @return 処理を結合したCps<T, R>オブジェクトまたはnull
     */
    protected abstract Cps<T, R> process(Cps<T, R> initial);

    /**
     * 初期値を渡しつつGeneratorを初期化
     * @param initial 初期値。unit()でCpsオブジェクト化し、process()に渡す。
     */
    public Generator(final T initial) {
        taskQueue.offer(Cps.callCC(new Function<Function<R, Cps<T, R>>, Cps<R, R>>(){
            public Cps<R, R> apply(final Function<R, Cps<T, R>> suspend) {
                // 脱出関数をフィールドに保存。
                Generator.this.suspend = suspend;
                // 処理本体を記述したCpsを生成する。
                Cps<T, R> result = process(unit(initial));
                if (result == null) {
                    return Cps.unit((R) null);
                }
                return 
                    result.bind(new Function<T, Cps<R, R>>() {
                        public Cps<R, R> apply(T value) {
                            return Cps.unit((R) null);
                        }
                    });
            }
        }));
    }
    
    /**
     * 中断されたタスクを呼び出し、継続処理を実行する。
     * @return 継続処理実行後の値。例えばyield()で中断した場合、その引数を戻す。
     */
    public final R next() {
        if (taskQueue.isEmpty()) {
            throw new IllegalStateException("Not started or no more task.");
        }
        // run()に渡して値を取得するためのおの関数
        Function<R, R> id = new Function<R, R>() {
            public R apply(R value) {
                return value;
            }
        };
        // 処理を再開
        return taskQueue.poll().run(id);
    }

    /**
     * 処理を中断し値を戻すためのメソッド。実クラスから呼び出す。
     * @param value 中断時に戻す値
     * @param next 中断再開時に使用するCps
     * @return 中断用のCpsを戻す。
     */
    protected final Cps<T, R> yield(final R value, final Cps<T, R> next) {
        if (suspend == null) {
            throw new IllegalStateException("Not started yet.");
        }
        // 戻ってくるための継続を生成しつつ、中断用の関数を呼び出す。
        return Cps.callCC(new Function<Function<T, Cps<R, R>>, Cps<T, R>>() {
            public Cps<T, R> apply(final Function<T, Cps<R, R>> f) {
                // 再開時に実行する処理を保存
                taskQueue.offer(next.bind(f));
                // 中断時用の関数に値を渡して一旦中断
                return suspend.apply(value);
            }
        });
    }
    
    /* 以下便利メソッド*/
    // 端的にCps.unit()を呼んで戻す
    protected final Cps<T, R> unit(T value) {
        return Cps.unit(value);
    }

    // 継続関数に関係なく値valueを戻すCpsを戻す。
    // 最終値の生成に使用
    protected final Cps<T, R> end(final R value) {
        return new Cps<T,R>(new Function<Function<T, R>, R>(){
            public R apply(Function<T, R> k) {
                return value;
            }});
    }
}

*1:と言いたいところだけどコンピュテーション式の展開規則とか良く分からなくて勘違いしてるかも。