Javaで非決定計算

http://d.hatena.ne.jp/keyesberry/20110831/p1」によると、RubyにはAmbという非決定計算をおこなう為の拡張モジュールがあるらしい。
Javaでもやってみた。


Javaではカレント継続のキャプチャ機能が言語レベルでサポートされてないので、前に作った継続モナドクラスを使う。


参考にしたページ:

ソースコード

先に便利そうなメソッドをいくつか用意しておく

package sample.cont;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public final class Utils {
    private Utils() {
    }
    /**
     * collectionからitemsの除いたコレクション(実際にはArrayList)を新たに作成して戻す。
     * 要素が重複している場合全て削除されないこともある。
     */
    public static <T> Collection<T> without(Collection<T> collection, Object... items) {
        List<T> newCollection = new ArrayList<T>(collection);
        for (Object item : items) {
            newCollection.remove(item);
        }
        return newCollection;
    }
    /**
     * fromからtoまでの連続する整数を含むCollectionを戻す。
     */
    public static Collection<Integer> sequence(int from, int to) {
        Collection<Integer> collection = new ArrayList<Integer>();
        for (int i = from; i <= to; i++) {
            collection.add(i);
        }
        return collection;
    }
    /**
     * 恒等関数を作る
     */
    public static <T> Function<T, T> makeId() {
        return new Function<T, T>() {
            @Override
            public T apply(T value) {
                return value;
            }
        };
    }
}

ではAmbクラス。深さ優先探索をします。
ambオペレータにあたるメソッドはchoose()という名前で定義し、また、バックトラックは空リストを渡すのではなく明示的にfailure()を呼ぶようにした。

package sample.cont;

import static sample.cont.Utils.*;
import java.util.Collection;
import java.util.Stack;

/**
 * 非決定計算をおこなうためのクラス
 * @param <R> 最終的な値の型
 */
public class Amb<R> {
    // 遅延評価のためのインタフェース
    private interface LazyExpr<T> {
        T eval();
    }

    // 継続を捕捉しておくためのスタック
    private final Stack<LazyExpr<Cps<R, R>>> fail = new Stack<LazyExpr<Cps<R, R>>>();

    /**
     * バックトラックをおこなう。
     */
    public Cps<R, R> failure() {
        if (fail.isEmpty()) {
            // 空のときは最終的な値としてnullを戻すCpsを戻す
            return new Cps<R, R>(new Function<Function<R, R>, R>(){
                @Override
                public R apply(Function<R, R> value) {
                    System.err.println("no choice");
                    return null;
                }
            });
        }
        return fail.pop().eval();
    }

    /**
     * @return 非決定な値を持つCpsを戻す。
     * @param <T> itemsの要素の型
     * @param items 選択肢。空のCollectionは指定できない。
     */
    public <T> Cps<T, R> choose(final Collection<T> items) {
        if (items == null || items.isEmpty()) {
            throw new IllegalArgumentException("items should not be empty.");
        }
        return Cps.callCC(new Function<Function<T,Cps<T,R>>, Cps<T,R>>() {
            @Override
            public Cps<T, R> apply(final Function<T, Cps<T, R>> exit) {
                return Cps.callCC(new Function<Function<T,Cps<R, R>>, Cps<T, R>>() {
                    @Override
                    public Cps<T, R> apply(final Function<T, Cps<R, R>> cc) {
                        // 要素を一つ取り出す
                        T head = items.iterator().next();
                        // headを取り除いた残りの要素
                        final Collection<T> rest = without(items, head);
                        if (!rest.isEmpty()) {
                            // 残りの要素があれば残りの要素の処理をスタックに積む
                            fail.push(new LazyExpr<Cps<R, R>>() {
                                @Override
                                public Cps<R, R> eval() {
                                    return choose(rest).bind(cc);
                                }
                            });
                        }
                        return exit.apply(head);
                    }
                });
            }
        });
    }
}

使用例1: 積が12になる数の組をプリントする

1〜12の数二つの組み合わせで、積が12になるものをプリントするコードを考える。
まず最終結果の型にあわせてAmbインスタンスを作成する。今回は結果をList(例えば[3, 4]とか)で出すことにする。

    final Amb<List<Integer>> amb = new Amb<List<Integer>>();

一つ目の1〜12の数を生成して使うところをを書いてみる。sequence()は上で定義した便利メソッドを使用。

        Cps<List<Integer>, List<Integer>> cont =
            // xは1〜12を取り得る
            amb.choose(sequence(1, 12))
                .bind(new Function<Integer, Cps<List<Integer>, List<Integer>>>() {
                    @Override
                    public Cps<List<Integer>, List<Integer>> apply(final Integer x) {
                        // ここにxを使う処理を書く
                    }
                });

もう一つの1〜12の数を生成して使うところを書いてみる。(上の「// ここにxを使う処理を書く」の部分)

                        // yは1〜12を取り得る
                        return amb.choose(sequence(1, 12))
                            .bind(new Function<Integer, Cps<List<Integer>, List<Integer>>>() {
                                @Override
                                public Cps<List<Integer>, List<Integer>> apply(final Integer y) {
                                    // ここにxとyを使う処理を書く
                                }
                            });

インデントが……do記法とかコンピュテーション式とかfor内包表記とか無いので仕方ないのです。
以降、インデントは勘弁してもらって浅く書く。


判定する部分を書く。(上の「// ここにxとyを使う処理を書く」の部分)

                                        if (x * y == target) {
                                            return Cps.unit(Arrays.asList(x, y));
                                        } else {
                                            return amb.failure();
                                        }

条件にマッチしたら結果をCps.unit()で包んで戻す。マッチしなければfailure()を呼んでバックトラックする。


これだけだと、一連の計算を表すCpsを生成してcontに代入しただけなのでまだ計算自体は実行されていない。
実際の値を取り出す処理も書く必要がある。上記の分もまとめると次のようになる。

    @Test
    public void testProduct() {
        final int target = 12;
        final Amb<List<Integer>> amb = new Amb<List<Integer>>();

        Cps<List<Integer>, List<Integer>> cont =
            // x <- [1..12]
            amb.choose(sequence(1, 12))
                .bind(new Function<Integer, Cps<List<Integer>, List<Integer>>>() {
                    public Cps<List<Integer>, List<Integer>> apply(final Integer x) {

            // y <- [1..12]
            return amb.choose(sequence(1, 12))
                .bind(new Function<Integer, Cps<List<Integer>, List<Integer>>>() {
                    public Cps<List<Integer>, List<Integer>> apply(final Integer y) {

                // 計算部分
                if (x * y == target) {
                    return Cps.unit(Arrays.asList(x, y));
                } else {
                    return amb.failure();
                }
                         
            }}); // y 
            }}); // x

        Function<List<Integer>, List<Integer>> id = makeId();

        for (List<Integer> result;
                (result = cont.run(id)) != null;
                    cont = amb.failure()) {

            System.out.println("result = " + result);
            assertEquals(target, result.get(0) * result.get(1));
        };
    }

cont.run()にidを渡して結果を取得している。バックトラックを続けると最後にnullを返すCpsが得られるので、それを終了条件にしている。
結果の出力

result = [1, 12]
result = [2, 6]
result = [3, 4]
result = [4, 3]
result = [6, 2]
result = [12, 1]

これだけだとfor文を二個ネストしたのと変わらない気もするけど*1、探索を途中で中断したり、中断した所から再開したりできるので、便利なのです!(受け売り)

使用例2: 覆面算(SEND+MORE=MONEY)

同じ文字には同じ数字が入り、違う文字には違う数字が入る、というルールで、アルファベットに0〜9の数字を当てはめて式が成り立つようにするようです。

  SEND
+ MORE
------
 MONEY

結果を格納するResultクラスを作ってそれを最終的な答えの型とする。数字は一文字選ぶごとに候補から削除して重複しないようにする。あとsとmは先頭なので0も除去して選んでいる。

    // 数字の列から数に変換する。例: [1,2,3]→123
    private static int toInt (int... digits) {
        int total = 0;
        for (int digit : digits) {
            total = total * 10 + digit;
        }
        return total;
    }
    @Test
    public void testSendMoreMoney () {
        class Result {
            int send;
            int more;
            int money;
            @Override
            public String toString() {
                return "  " + send + "\n" +
                       "+ " + more + "\n" +
                       "------\n" +
                       " "  + money  + "\n";
            }
        }

        final Collection<Integer> digits = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

        final Amb<Result> amb = new Amb<Result>();
        Cps<Result, Result> cont =
            // s
            amb.choose(without(digits, 0))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer s) {
            // e
            return amb.choose(without(digits, s))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer e) {
            // n
            return amb.choose(without(digits, s, e))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer n) {
            // d
            return amb.choose(without(digits, s, e, n))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer d) {
            // m
            return amb.choose(without(digits, 0, s, e, n, d))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer m) {
            // o
            return amb.choose(without(digits, s, e, n, d, m))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer o) {
            // r
            return amb.choose(without(digits, s, e, n, d, m, o))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer r) {
            // y
            return amb.choose(without(digits, s, e, n, d, m, o, r))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer y) {

                // 計算部分
                if (toInt(s, e, n, d) + toInt(m, o, r, e) == toInt(m, o, n, e, y)) {
                    Result result = new Result();
                    result.send = toInt(s, e, n, d);
                    result.more = toInt(m, o, r, e);
                    result.money = toInt(m, o, n, e, y);
                    return Cps.unit(result);
                } else {
                    return amb.failure();
                }

            }}); // y
            }}); // r
            }}); // o
            }}); // m
            }}); // d
            }}); // n
            }}); // e
            }}); // s

        Function<Result, Result> id = makeId();
        Result result = cont.run(id);

        System.out.println("result:\n" + result);
        assertEquals(9567, result.send);
        assertEquals(1085, result.more);
        assertEquals(10652, result.money);
    }

実行すると……StackOverflowErrorになってしまう。正解でない組み合わせが続くと際限なくamb.failure()が呼ばれてしまうため。Javaは末尾再帰最適化とかしないので仕方が無い。
直接amb.failure()を呼ばず、不正解の組み合わせでも一旦処理を戻すようにする。

...(略)...
        final Result MISS = new Result(); // 不正解の場合のマーカー
        Cps<Result, Result> cont =
...(略)...
                if (toInt(s, e, n, d) + toInt(m, o, r, e) == toInt(m, o, n, e, y)) {
                    Result result = new Result();
                    result.send = toInt(s, e, n, d);
                    result.more = toInt(m, o, r, e);
                    result.money = toInt(m, o, n, e, y);
                    return Cps.unit(result);
                } else {
//                    return amb.failure();
                    return Cps.unit(MISS);
                }
...(略)...
            }}); // s

        Function<Result, Result> id = makeId();
        Result result;

        // MISS以外の結果が来るまでfailure()を呼び続ける
        while((result = cont.run(id)) == MISS) {
            cont = amb.failure();
        }

        if (result != null) {
...(略)...

不正解の組み合わせの場合に、マーカーとしてMISSを戻すようにし、MISSの場合に再度外側からfailure()を呼んでやることでStackOverflowErrorを避ける。
結果出力

result:
  9567
+ 1085
------
 10652

使用例3: ピタゴラス

合計がn以下のピタゴラス数を生成する。
これも直ぐにスタックあふれるので外れの時も空リストを戻している。

    public Cps<List<Integer>, List<Integer>> makePythag(final Amb<List<Integer>> amb, final int n) {
        // a <- [1..n/2]
        return amb.choose(sequence(1, n/2))
            .bind(new Function<Integer, Cps<List<Integer>, List<Integer>>>(){
                public Cps<List<Integer>, List<Integer>> apply(final Integer a) {
        // b <- [a..n/2]
        return amb.choose(sequence(a, n/2))
            .bind(new Function<Integer, Cps<List<Integer>, List<Integer>>>(){
                public Cps<List<Integer>, List<Integer>> apply(final Integer b) {
        // c <- [b..n/2]
        return amb.choose(sequence(b, n/2))
            .bind(new Function<Integer, Cps<List<Integer>, List<Integer>>>(){
                public Cps<List<Integer>, List<Integer>> apply(final Integer c) {
            // 計算部分
            if ( a + b + c <= n &&
                a * a +  b * b == c * c) {
                return Cps.unit(Arrays.asList(a, b, c));
            } else {
                return Cps.unit(Collections.<Integer>emptyList());
            }
        }}); // c
        }}); // b
        }}); // a
    }

    @Test
    public void testPythag() {
       Amb<List<Integer>> amb = new Amb<List<Integer>>();

        Cps<List<Integer>, List<Integer>> cont = makePythag(amb, 100);
        Function<List<Integer>, List<Integer>> id = makeId();
        for (List<Integer> result;
                (result = cont.run(id)) != null;
                    cont = amb.failure()) {
            if (!result.isEmpty()) {
                System.out.println("result = " + result);
                assertTrue(result.get(0) * result.get(0) + result.get(1) * result.get(1)
                            == result.get(2) * result.get(2));
            }
        };
    }

結果

result = [3, 4, 5]
result = [5, 12, 13]
result = [6, 8, 10]
result = [7, 24, 25]
result = [8, 15, 17]
result = [9, 12, 15]
result = [9, 40, 41]
result = [10, 24, 26]
result = [12, 16, 20]
result = [12, 35, 37]
result = [15, 20, 25]
result = [15, 36, 39]
result = [16, 30, 34]
result = [18, 24, 30]
result = [20, 21, 29]
result = [21, 28, 35]
result = [24, 32, 40]

使用例4: うそつきパズル

SICPに載ってるらしい。

5人の女子生徒が試験を受けた。彼女らの両親は結果に対し過度の関心を持っている、と彼女らは考えている。 そこで彼女らは自宅へ試験についての手紙を書くのに、誰もが1つの正しい情報と1つのうその情報を書こうと 約束した。以下は彼女らの手紙の関係する部分である。
  Betty: 「Kitty は試験が2番で私は3番でした。」
  Ethel: 「私がトップと聞いてうれしいでしょう。Joan が2ばんでした。」
  Joan: 「私は3番でした。可哀想な Ethel はビリでした。」
  Kitty: 「私は2番になりました。Mary は4番でしかありませんでした。」
  Mary: 「私は4番でした。トップの座は Betty がとりました。」
5人の女子生徒の本当の順番はどうなっているのか。

工夫も無く全組み合わせで試していく。

    @Test
    public void testLiars() {
        // 各人の順位を保持するためのクラス
        class Result {
            int betty;
            int ethel;
            int joan;
            int kitty;
            int mary;
            @Override
            public String toString() {
                return "{Betty=" + betty + "," +
                    " Ethel=" + ethel + "," +
                    " Joan="  + joan  + "," +
                    " Kitty=" + kitty + "," +
                    " Mary="  + mary  + "}";
            }
        }

        final Collection<Integer> orders = Arrays.asList(1, 2, 3, 4, 5);

        final Amb<Result> amb = new Amb<Result>();
        Cps<Result, Result> cont =
            // Betty
            amb.choose(orders)
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer betty) {
            // Ethel
            return amb.choose(without(orders, betty))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer ethel) {
            // Joan
            return amb.choose(without(orders, betty, ethel))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer joan) {
            // Kitty
            return amb.choose(without(orders, betty, ethel, joan))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer kitty) {
            // Mary
            return amb.choose(without(orders, betty, ethel, joan, kitty))
                .bind(new Function<Integer, Cps<Result, Result>>(){
                    public Cps<Result, Result> apply(final Integer mary) {

                if ((kitty == 2 ^ betty == 3) && // Betty:  「Kitty は試験が2番で私は3番でした。」 
                    (ethel == 1 ^ joan == 2) &&  // Ethel:  「私がトップと聞いてうれしいでしょう。Joan が2ばんでした。」 
                    (joan == 3 ^ ethel == 5) &&  // Joan:   「私は3番でした。可哀想な Ethel はビリでした。」
                    (kitty == 2 ^ mary == 4) &&  // Kitty:  「私は2番になりました。Mary は4番でしかありませんでした。」
                    (mary == 4 ^ betty == 1)) {  // Mary:   「私は4番でした。トップの座は Betty がとりました。」  
                    
                    Result r = new Result();
                    r.betty = betty;
                    r.ethel = ethel;
                    r.joan = joan;
                    r.kitty = kitty;
                    r.mary = mary;
                    return Cps.unit(r);
                } else {
                    return amb.failure();
                }

            }});
            }});
            }});
            }});
            }});

        Function<Result, Result> id = makeId();
        Result result = cont.run(id);

        if (result != null) {
            System.out.println("result = " + result);
            assertEquals(1, result.kitty);
            assertEquals(2, result.joan);
            assertEquals(3, result.betty);
            assertEquals(4, result.mary);
            assertEquals(5, result.ethel);
        }
    }

「(論理式1 ^ 論理式2)」はどちらか一方だけが成り立つときだけ真になる。いわゆる排他的論理和
出力結果:

result = {Betty=3, Ethel=5, Joan=2, Kitty=1, Mary=4}

使用例5: 選択肢の型が違う例

上の例は全て整数列から候補を選ぶようになっているけど、他の型を混ぜても大丈夫。
例えば、「果物の名前の文字数が2,4,6の場合のみ受け付ける」というのを考える。

    @Test
    public void testLength() {
        final Amb<String> amb = new Amb<String>();
        Cps<String, String> cont =
            // fruit names
            amb.choose(Arrays.asList("apple", "orange", "lemon", "lime", "banana"))
                .bind(new Function<String, Cps<String, String>>(){
                    public Cps<String, String> apply(final String name) {

            // lengths to accept
            return amb.choose(Arrays.asList(2, 4, 6))
                .bind(new Function<Integer, Cps<String, String>>(){
                    public Cps<String, String> apply(final Integer length) {

                // 計算部分
                if (name.length() == length) {
                    return Cps.unit(name);
                } else {
                    return amb.failure();
                }

            }});
            }});

        Function<String, String> id = makeId();

        for (String result; (result = cont.run(id)) != null; cont = amb.failure()) {
            System.out.println("accept: " + result);
            assertTrue(result.length() == 2 || result.length() == 4 || result.length() == 6);
        };
    }

名前の方は文字列の一覧から、長さの方は整数の一覧から選んでいる。

*1:do記法などと違って見た目にも単なる二重ループっぽくなっているので本当に残念