JavaでWriterモナドでFizzBuzz

All About Monads」を読んでたら、前に書いたFizzBuzzが実はWriterモナドであることに気がついたのでそれっぽく書き直してみた。


しかし自分のような素人が考え付くような単純なモナドは大体「標準的モナドのカタログ」の中のどれかになってるような。本当に良く出来たカタログですわー……。

Monoidの定義

The Writer monad」によると、Writerは「値,ログ」という組を持ち、ログの型はモノイドでなければならないらしい。
まずモノイドをエミュレートするためのインタフェースMonoidを定義する。

public interface Monoid<M extends Monoid<M>> {
    /*
     * terazzoからのお願い:
     *   o. Monoidの実装クラスに指定するパラメータMはその実装クラスにしてね。
     *   o. そのクラスにはmempty()ってstaticメソッドを実装して単位元を返してね。
     */
    // static M mempty();

    /** 単位元を返すstaticメソッドの名前 */
    String EMPTY_NAME = "mempty";

    /** thisとthatの間の二項演算を定義する。*/
    M mappend(M that);

    /** thisが単位元の時trueを戻す。*/
    boolean isEmpty();
}

二項演算子にあたるメソッドの名前をmappend()とする。インタフェースに「自分自身のクラスを返せ」というメソッドは定義できないので、再帰型総称型でその代わりにする。
単位元を返すメソッドの名前をmempty()とする。インタフェースでstaticメソッドを強制はできないので、お願いベースで定義してもらいリフレクションで取得することにする。


これを継承して、文字列ベースのログクラスを定義する。

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;

public class Log implements Monoid<Log> {
    /** ログの内容をあらわす文字列値 */
    private final String value;
    private Log(String value) {
        this.value = value;
    }

    /** 単位元を戻す。 */
    public static Log mempty() {
        return new Log(null);
    }

    /** valueを値として持つLogを戻す。 */
    public static Log of(String value) {
        if (value == null) {
            throw new IllegalArgumentException("value should not be null.");
        }
        return new Log(value);
    }

    /** thisが単位元かどうかを戻す。 */
    @Override
    public boolean isEmpty() {
        return value == null;
    }

    /** 二項演算を定義する。*/
    @Override
    public Log mappend(Log that) {
        // 片方が 単位元であれば、他方の結果を返す。
        // どちらも 単位元でないならば、両方から新たな値を作って戻す。
        // ここでは端的に文字列を連結して新たな値を作る。
        return this.isEmpty() ? that
             : that.isEmpty() ? this
             : /* otherwise */  Log.of(this.value + that.value);
    }

    @Override
    public String toString() {
        return value == null ? "*null*" : value;
    }
    @Override
    public int hashCode() {
        return HashCodeBuilder.reflectionHashCode(this);
    }
    @Override
    public boolean equals(Object other) {
        return EqualsBuilder.reflectionEquals(this, other);
    }
}

一応適当な値でテスツ。

import static org.junit.Assert.assertEquals;
import org.junit.Test;

public class LogTest {
    // 単位元の存在
    @Test
    public void testEmpty() {
        Log e = Log.mempty();
        Log a = Log.of("hoge");

        // e ・ a = a 
        assertEquals(e.mappend(a), a);
        // a ・ e = a 
        assertEquals(a.mappend(e), a);
    }

    // 結合律
    @Test
    public void testAssociativity() {
        Log a = Log.of("hoge");
        Log b = Log.of("fuga");
        Log c = Log.of("piyo");

        // (a ・ b) ・ c = a ・ (b ・ c)
        assertEquals(
            (a.mappend(b)).mappend(c),
            a.mappend(b.mappend(c)));
    }
}

Writerの定義

Writer モナドの定義を参考にしながらWriterクラスを実装する。FunctionはGoogle guavaの定義を拝借。

import java.lang.reflect.Method;
import com.google.common.base.Function;

public class Writer<T, L extends Monoid<L>> {
    public final T value;
    public final L log;

    /**
     *  「値,ログ」という組を持つWriterを生成する。
     */
    public Writer(T value, L log) {
        this.value = value;
        this.log = log;
    }

    /** Writerモナドのbind */
    public <S> Writer<S, L> bind(Function<T, Writer<S, L>>f) {
        Writer<S, L> ret = f.apply(this.value);
        return new Writer<S, L>(ret.value, this.log.mappend(ret.log));
    }

    /** Writerモナドのunit */
    public static <T, L extends Monoid<L>> Writer<T, L> unit(T value, L... dummy) {
        L emptyLog = emptyLog((Class<L>) dummy.getClass().getComponentType());
        return new Writer<T, L>(value, emptyLog);
    }
    /** logClassからリフレクションで単位元を取得する。 */
    private static <L> L emptyLog(Class<L> logClass) {
        try {
            Method factory = logClass.getMethod(Monoid.EMPTY_NAME);
            return logClass.cast(factory.invoke(null));
        } catch (Exception e) {
            throw new IllegalArgumentException(
                "The parameter class " + logClass.getName() + "" +
                " not provides static " + Monoid.EMPTY_NAME + "()", e);
        }
    }

    /** fのapply()にthisを渡した結果を戻す。 */
    public <R> R apply(Function<Writer<T, L>, R> f) {
        return f.apply(this);
    }
}

unit()でログのクラスを取得するためにjavaのトリックを使っている。取得したクラスからリフレクションでmempty()を呼び出して単位元を取得する。
型パラメータに実クラスを指定して以下のように使える。

    Writer<Integer, Log> ten = Writer.unit(10);

FizzBuzzの実装

Writerに関する処理のチェーンとして実装して、最後に結果出力用の値を取り出す。
特に再利用とかもしないんでスクリプト的に書くよ。

import static org.junit.Assert.assertEquals;
import org.junit.Test;
import com.google.common.base.Function;

public class FizzBuzz {
    // Fizz用やBuzz用の関数を生成するメソッド
    Function<Integer, Writer<Integer, Log>>
            makeFizzBuzzOperator(final int divider, final String message) {

        if (divider == 0) {
            throw new IllegalArgumentException("divider should not be zero.");
        }
        return new Function<Integer, Writer<Integer, Log>>() {
            @Override
            public Writer<Integer, Log> apply(Integer value) {
                if (value % divider == 0) {
                    return new Writer<Integer, Log>(value, Log.of(message));
                } else {
                    return new Writer<Integer, Log>(value, Log.mempty());
                }
            }   
        };
    }
    Function<Integer, Writer<Integer, Log>> fizz = makeFizzBuzzOperator(3, "Fizz");
    Function<Integer, Writer<Integer, Log>> buzz = makeFizzBuzzOperator(5, "Buzz");
    // 結果出力用の関数。logがemptyならvalueを、さもなければlogの内容を出力。
    Function<Writer<Integer, Log>, Object> value =
        new Function<Writer<Integer, Log>, Object>() {
            @Override
            public Object apply(Writer<Integer, Log> writer) {
                return writer.log.isEmpty() ? writer.value : writer.log.toString();
            }
        };

    Object fizzBuzz(int n) {
        //「n fizz buzz value」に近い感じで、通常より直感的に表現!
        return Writer.<Integer, Log>unit(n).bind(fizz).bind(buzz).apply(value);
    }
    @Test
    public void testFizzBuzz() {
        assertEquals(1, fizzBuzz(1));
        assertEquals(2, fizzBuzz(2));
        assertEquals("Fizz", fizzBuzz(3));
        assertEquals(4, fizzBuzz(4));
        assertEquals("Buzz", fizzBuzz(5));
        assertEquals(7, fizzBuzz(7));
        assertEquals("FizzBuzz", fizzBuzz(15));

        // 表示してみる
        for (int n = 1; n <= 100; n++) {
            System.out.printf("%s\n", fizzBuzz(n));
        }
    }
}