Javaで末尾再帰最適化をする。(?)

お題:
http://d.hatena.ne.jp/wasabiz/20110118/1295335821
Rubyで末尾再帰最適化をする。 - Homoiconic Days


Javaなどの言語では、通常、再帰を使ったプログラムは、呼び出しが深くなるといつかはStack Overflowで実行時エラーになってしまう。それに対して、Schemeなどの関数型言語では、自動で末尾再帰最適化というのをおこなって、Stack Overflowがおこらないようにしているものが多い。(末尾再帰最適化についてはhttp://practical-scheme.net/docs/cont-j.htmlの「末尾再帰と継続」の解説が分かりやすい。)


PythonRubyは自動では末尾再帰最適化はおこなっていないが、言語にあるしくみを利用して末尾再帰最適化を後付けすることが出来るらしい。内容をみるとCPS化してループに変形するっぽい処理だったので、Javaでもできないかやってみた。


これで末尾再帰最適化って言ってしまってもいいのか自信ないけど、とりあえずStack Overflowが起きないようにはできたみたい。(追記: やっぱり用語的に末尾最適化というのは違うらしい。cf: 「2011-01-24 - ブートストラッピングでコンパイラを作る日記」)

準備

AOPとProxyを使って実現してみることにした。*1

いろいろ制限がある。

  • finalのクラスおよびfinalのメソッドでは使用不可。
  • publicのインスタンスメソッドのみ可
  • 戻り値の型がインタフェースまたはjava.lang.Objectの場合のみ可


AOPのライブラリとしてはSeasar2を使う。今回使った環境は以下の通り。

  • J2SE 5.0
  • Seasar 2.4.20
    • 付属のlibの中身全部をとりあえず追加
    • ミニマムでは以下の4つのjar
      • s2-framework-2.4.20.jar
      • aopalliance-1.0.jar
      • javassist-3.4.ga.jar
      • commons-logging-1.1.jar
  • JUnit 3.8.2

ソースコード

例によって動作確認しやすい&1ファイルで出来るようにJUnitのUnitTestに書いているが(略
追記: 最適化をおこなうメソッドをメソッド名で指定するように修正 *2
追記: 戻り値の型がjava.lang.Objectの場合にも対応。サンプルも修正。

package sample.tco;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

import junit.framework.TestCase;

import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.seasar.framework.aop.Aspect;
import org.seasar.framework.aop.Pointcut;
import org.seasar.framework.aop.impl.AspectImpl;
import org.seasar.framework.aop.impl.PointcutImpl;
import org.seasar.framework.aop.proxy.AopProxy;


public class TailCallOptimizationSample extends TestCase {
    /**
     * メソッド呼び出しを末尾再帰最適化する為のインターセプタ
     */
    public static class TailCallOptimizationInterceptor implements MethodInterceptor {
        /**
         * MethodInvocationを遅延実行する為のインタフェース
         */
        private interface Continuation {
            Object __invoke() throws Throwable;
        }
        /**
         * MethodInvocationを遅延実行する為のContinuationを生成する。
         * @param invocation 遅延実行対象となるMethodInvocation
         */
        private static Object createContinuation(final MethodInvocation invocation) {
            Class<?> returnType = invocation.getMethod().getReturnType();
            if (returnType == Object.class) {
                // ObjectにならContinuationを代入可能
                return new Continuation() {
                    public Object __invoke() throws Throwable {
                        return invocation.proceed();
                    }
                };
            }
            // 戻り値の型に代入可能にするため、戻り値の型 & ContinuationであるProxyを作って戻す。
            return Proxy.newProxyInstance(
                Thread.currentThread().getContextClassLoader(), // 悩ましい
                new Class[] {returnType, Continuation.class},
                new InvocationHandler() {
                    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                        if (method.getDeclaringClass() == Continuation.class) {
                            return invocation.proceed();
                        }
                        throw new RuntimeException("Illegal invocation.");
                    }
                });
        }
        /**
         * 初回の呼び出し(トップレベル)かを判定するためのマーカー。
         * get()の結果がnullならばトップレベル
         */
        private ThreadLocal<Object> topLevelMarker = new ThreadLocal<Object>();

        /**
         * メソッドが呼ばれた際に、トップレベル以外は戻り値の型に偽装したContinuationを戻す。
         * トップレベルの場合、Continuation以外の値が取れるまでループして呼び出す。
         */
        public Object invoke(MethodInvocation invocation) throws Throwable {
            if (topLevelMarker.get() != null) {
                return createContinuation(invocation);
            }
            topLevelMarker.set(new Object());
            try {
                Object result = invocation.proceed();
                while (result instanceof Continuation) {
                    result = ((Continuation) result).__invoke();
                }
                return result;
            } finally {
                topLevelMarker.set(null);
            }
        }
    }
    /**
     * 元のクラスをエンハンスし、末尾再帰最適化されたクラスを戻す。
     * @param <T> targetClassの型
     * @param targetClass 元のクラス。
     * @param pointcutNames 最適化を適用するメソッドの名前の配列
     * @return エンハンスしたクラスを戻す
     */
    public static <T> Class<T> optimize(Class<T> targetClass, String... pointcutNames) {
        MethodInterceptor interceptor = new TailCallOptimizationInterceptor();
        Pointcut pointcut =  (pointcutNames.length > 0)
        		? new PointcutImpl(pointcutNames)
        		: new PointcutImpl(targetClass);
        Aspect aspect = new AspectImpl(interceptor, pointcut);

        AopProxy proxy = new AopProxy(targetClass, new Aspect[] {aspect});
        Class<T> enhanced = proxy.getEnhancedClass();
        return enhanced;
    }
    /**
     * 末尾再帰最適化されたクラスのインスタンスを生成して戻す便利メソッド。
     * デフォルトコンストラクタでインスタンスを生成できるクラスでのみ使用できる。
     * @param <T> targetClassの型
     * @param targetClass 元のクラス。
     * @param pointcutNames 最適化を適用するメソッドの名前の配列
     * @return 末尾最適化されたクラスのインスタンスを生成して戻す。
     */
    public static <T> T newOptimizedInstance(Class<T> targetClass, String... pointcutNames) {
        try {
            Class<T> optimizedClass = optimize(targetClass, pointcutNames);
            return optimizedClass.newInstance();
        } catch (InstantiationException e) {
            throw new IllegalArgumentException("Failed to call default constructor.", e);
        } catch (IllegalAccessException e) {
            throw new IllegalArgumentException("Failed to call default constructor.", e);
        }
    }
/*

末尾再帰で1からnまで足し算をするサンプル

上の続きで

 */
    /** 末尾再帰で1からnまで足し算をするサンプル */
    public static class SumSample  {
        public Object sum(int n, int acc) {
            if (n == 0) {
                return acc;
            } else {
                return sum(n - 1, acc + n);
            }
        }
    }
    /** 通常の呼び出し */
    public void testSumNormal()  {
        try {
            SumSample sample = new SumSample();
            int result = (Integer) sample.sum(10000, 0);
            assertEquals(50005000, result);
            fail("設定にもよるが10000回も再帰するとエラーになる");
        } catch (Throwable t) {
            t.printStackTrace();
        }
        
    }
    /** 最適化あり */
    public void testSumOptimized()  {
        SumSample sample = newOptimizedInstance(SumSample.class, "sum");
        int result = (Integer) sample.sum(10000, 0);
        assertEquals(50005000, result); // スタックオーバーフローにならない
    }
/*

相互再帰の末尾再帰最適化のサンプル

上の続きで

 */
    /** 偶数か奇数かを判定する */
    public static class OddEvenSample {
        public Object even(int n) {
            if (n == 0) {
                return true;
            } else {
                return odd(n - 1);
            }
        }
        public Object odd(int n) {
            if (n == 0) {
                return false;
            } else {
                return even(n - 1);
            }
        }
    }
    /** 最適化あり */
    public void testOddOptimized()  {
        OddEvenSample sample = newOptimizedInstance(OddEvenSample.class, "odd", "even");
        boolean result = (Boolean) sample.odd(100001);
        assertTrue("100001は奇数", result); // スタックオーバーフローにならない
    }
    public void testEvenOptimized()  {
        OddEvenSample sample = newOptimizedInstance(OddEvenSample.class, "odd", "even");
        boolean result = (Boolean) sample.even(100000);
        assertTrue("100000は偶数", result); // スタックオーバーフローにならない
    }
}

補足

今回はS2内部のAopProxyを直に呼んだけど、普通にdiconファイルに書いてで設定しても当然問題ない。

interceptorではトップレベルの呼び出し以外の場合、実際に呼び出しを行って結果を戻す代わりに、呼び出しを遅延実行できるオブジェクトを戻している。そのため、遅延実行できるオブジェクトがメソッドの元々の戻り値の型に代入できる必要がある。戻り値の型がjava.lang.Objectの場合には、結果の代わりに遅延実行用のオブジェクトをnewして戻すことができるが、戻り値の型がインタフェースの場合、遅延実行用のオブジェクトを戻り値の型を持ったProxyとして作成することで、元の型にも代入できるようにしている。


上記のサンプルはメソッドをObject型としている。インタフェースを使った例は例えば下のようになる。

    interface Wrapper<T> {
        T get();
    }
    static class WrapperImpl<T> implements Wrapper<T> {
        private T value;
        public WrapperImpl(T value) {
            this.value = value;
        }
        public T get() {
            return value;
        }
    }
    /** 末尾再帰で1からnまで足し算をするサンプル */
    public static class SumSample  {
        public Wrapper<Integer> sum(int n, int acc) {
            if (n == 0) {
                return new WrapperImpl<Integer>(acc);
            } else {
                return sum(n - 1, acc + n);
            }
        }
    }
    /** 最適化あり */
    public void testSumOptimized()  {
        SumSample sample = newOptimizedInstance(SumSample.class, "sum");
        int result = sample.sum(10000, 0).get();
        assertEquals(50005000, result); // スタックオーバーフローにならない
    }

補足2(2011/1/21追記)

折角なのでもう少しサンプル集めてみた。

相互末尾再帰で単語(スペース区切り)をカウントする

    public static class WordCounter {
        int countWord(Reader r) throws IOException {
            return (Integer) inSpace(r, 0);
        }
        public Object inWord(Reader r, int count) throws IOException {
            int c = r.read();
            if (c == -1) {
                return count;
            } else if (c == ' ') {
                return inSpace(r, count);
            } else {
                return inWord(r, count);
            }
        }
        public Object inSpace(Reader r, int count) throws IOException {
            int c = r.read();
            if (c == -1) {
                return count;
            } else if (c == ' ') {
                return inSpace(r, count);
            } else {
                return inWord(r, count + 1);
            }
        }
    }

    public void testWordCounter() throws IOException {
        WordCounter counter = newOptimizedInstance(WordCounter.class, "inWord", "inSpace");
        assertEquals(3, counter.countWord(new StringReader("死にそう オレ 死にそう")));
    }

階乗計算のサンプル

    public static class FactorialSample {
        BigInteger factorial(long n) {
            return (BigInteger) _factorial(BigInteger.valueOf(n), BigInteger.ONE);
        }

        public Object _factorial(BigInteger n, BigInteger acc) {
            if (n.equals(BigInteger.ZERO)) {
                return acc;
            } else {
                return _factorial(n.subtract(BigInteger.ONE), n.multiply(acc));
            }
        }
    }

    public void testFactorial() {
        FactorialSample sample = newOptimizedInstance(FactorialSample.class, "_factorial");
        BigInteger result = sample.factorial(100000);
        String str = result.toString();
        assertEquals("100000!の桁数は456574", 456574, str.length());
        assertEquals("100000!の先頭20桁は28242294079603478742",
            "28242294079603478742", str.substring(0, 20));
        for (int i = 0, c = str.length(); i < c; i += 80) {
            System.out.println(str.substring(i, Math.min(c, i + 80)));
        }
    }

フィボナッチ数計算のサンプル

    public static class FibonacciSample {
        BigInteger fibonacci(long n) {
            return (BigInteger) _fibonacci(BigInteger.valueOf(n), BigInteger.ZERO, BigInteger.ONE);
        }

        public Object _fibonacci(BigInteger n, BigInteger a, BigInteger b) {
            if (n.equals(BigInteger.ZERO)) {
                return a;
            } else {
                return _fibonacci(n.subtract(BigInteger.ONE), b, a.add(b));
            }
        }
    }

    public void testFibonacci() {
        FibonacciSample sample = newOptimizedInstance(FibonacciSample.class, "_fibonacci");
        BigInteger result = sample.fibonacci(100000);
        String str = result.toString();
        assertEquals("fibonacci(100000)の桁数は20899", 20899, str.length());
        assertEquals("fibonacci(100000)の先頭20桁は25974069347221724166",
            "25974069347221724166", str.substring(0, 20));
        for (int i = 0, c = str.length(); i < c; i += 80) {
            System.out.println(str.substring(i, Math.min(c, i + 80)));
        }
    }

*1:他の方法としては、例えばASMでサブクラスを動的に作るとかも考えられる。

*2:PointcutImpl#PointcutImpl(Class)の挙動がS2Containerのバージョンによって異なる為、省略時の挙動はS2Containerのバージョンに依存する。cf: http://www.seasar.org/source/browse/s2container?view=revision&revision=3243