tomabouの日記

勉強したことなどを書いていきます

Haskellで10を作るゲームを解く

はじめに

切符に書かれた四つの数字を使用して10を作る有名な遊びがあります*1
目的の数字を与えられた数字と四則演算を使用して作るプログラムを書いてみました。言語は最近ハマっているHaskellを使用してみました。

構想

式を定義
→与えられた数字から式を列挙
→式を計算
→目的の数字が出たら勝利
このような流れで書きます。工夫の余地がある点は式を列挙するパートで、できるだけ重複しないで列挙したいです。例えば(3+2)*4と4*(2+3)どちらも試すのはアホです。なので次のように正規形を定義して正規形のみ列挙することにします。(与えられた数字は以下狭義単調増加に並んでいるとします)

正規形::任意の演算子について左側にある数字の最小値<右側にある数字の最小値

例えば(1*3+2)*4は正規形ですが、(2+3*1)*4は正規形ではないです。任意の式を正規形にするためには3-2や3/2を2 [演算子] 3と書く必要があります。a - b = b $- a , a / b = b $/ a として新しい演算子を定めることで任意の式をくるくるひっくり返して正規形にすることができます。

実装

式の定義

data Equation = Number Integer | Operator String Equation Equation 

日本語で書けば式とは数字単体、もしくは一つの演算子と二つの式の組である。非常に明快です。二分木の変形ですね。

式の表示

instance Show Equation where 
    show (Number x) = show x
    show (Operator ('$':xs) a b) = "(" ++ show b ++ " " ++ xs ++" "++show a ++")"
    show (Operator xs a b) ="("++ (show a) ++" "++ xs++" " ++(show b)++")"

数字はそのまま表示し、演算子があったら括弧でくくってから左右の式の間にその演算子を書きます。$という記号がついた演算子は左右をひっくり返して表示します。

式の計算

calculate :: Equation -> Maybe Rational
calculate (Number x) = Just(x % 1)
calculate (Operator "+" a b) = (+) <$> calculate a <*> calculate b
calculate (Operator "*" a b) = (*) <$> calculate a <*> calculate b
calculate (Operator "-" a b) = (-) <$> calculate a <*> calculate b
calculate (Operator "$-" a b) = (-) <$> calculate b <*> calculate a
calculate (Operator "/" a b) 
    |calculate b /= Just 0 = (/) <$> calculate a <*> calculate b
    |otherwise = Nothing
calculate (Operator "$/" a b) 
    |calculate a /= Just 0 = (/) <$> calculate b <*> calculate a
    |otherwise = Nothing

Rationalは割り算を表す型です。数字だったらそのままにして演算子があったら左右にその演算子を適用し、$があったらひっくり返しています。割り算の失敗を表すためにMaybeモナドを用います。アプリカティブファンクターとして書いているところが個人的に綺麗で気に入っています。

式の列挙

makeEq :: [Integer] ->[Equation]
makeEq [] = []
makeEq [x] = [Number x]
makeEq (x:xs) = do
    let n = 1 + length xs
    i <-[0..(n-2)]
    ope <- ["+","*","-","$-","/","$/"]
    (ps,qs)<- combi i xs
    a <- makeEq (x:ps)
    b <- makeEq (qs)
    return (Operator ope a b)

リストモナドを用いて複数通りを列挙しています。リストの内包表記と本質的に同じものなのでリストモナドが良くわからなくても内包表記から類推すればなんとなくわかると思います。与えられた数字の一番最初は正規形の制約より必ず左側に渡します。任意に左側の数字の個数を選び、演算子を選び、先ほど決めた個数数字をチョイスして、左右に配置しています。チョイスする関数の実装は次のようなものです。

combi :: Int -> [a] -> [([a],[a])]
combi 0 xs = [([],xs)]
combi n (x:xs) 
    |length xs + 1 == n = [(x:xs,[])]
    |otherwise = [(x:ps,qs)|(ps,qs)<-combi (n-1) xs] 
        ++ [(ps,x:qs)|(ps,qs)<- combi n xs]

個数とリストを受け取って、その個数分取り出したものとその残りを全通り列挙しています。素直に再帰的に定義できていると思います。このとき順序を保っているのは重要です。常に使用する数字は最初の入力と同じ順序で並んでいるため先頭の数字を左に配置すれば正規形になるのです。

インターフェース

もう論理部分は完成しています。あとは入出力だけです。

main = forever $ do
    putStrLn "What Integer do you want to make?"
    i <- readLn ::IO Integer
    putStrLn "Input numbers like [1,2,3,4]"
    numbers <- readLn :: IO [Integer]
    let ans = headMaybe . (filter (\x -> calculate x == Just (i%1))) $ makeEq numbers
    putStrLn (show i ++" = " ++ showMaybe ans)
    putStrLn ""

実際に探索しているのはletの行だけです。headMaybeは長さゼロのリストに対してリストの先頭を取る関数を対応させたもの、showMaybeはMaybeモナドに文字列に変換する関数を対応させたものです。

headMaybe::[a] -> Maybe a
headMaybe [] = Nothing
headMaybe xs = Just (head xs) 

showMaybe ::(Show a)=> Maybe a -> String
showMaybe (Just a) = show a
showMaybe Nothing = "No Answer"

これで完成です!

テスト

難問の[1,1,5,8]を解かせてみましょう(答えを知らない人はネタバレに注意)

What Integer do you want to make?
10
Write numbers you should use like [1,2,3,4]
[1,1,5,8]
10 = (8 / (1 - (1 / 5)))

やりました!
速度としては遅延評価により答えが見つかればすぐに表示されるため、目的の数字が小さめですぐ答えが見つかるものは使用する数字が多くても大丈夫です。もしも作れない場合、使用する数字が6個を超えてくると全通り試してNoAnswerを表示するまでに非常に時間がかかってしまいます。組み合わせは爆発するものなので仕方ないですね。

感想、改善点

非常に直感的に書くことができたので非常に満足感がありました。計算式が本質的に二分木であることがここまで明示的に書けるとは想像以上に気持ち良いです。
改善点としてはまだ無駄な列挙をしている点です。入力に同じ数字がある場合は非常に無駄な探索をしてしまいます。また、結合則を無視しているため、(1+2)+3と1+(2+3)どちらも調べています。まあ速さを求めるならC++などで動的計画法を活用して書くべきである気がするのでそこまで気にしないでおいておきます。
コードの改善点がある場合ぜひ教えてください。
ソースとexeファイルを上げておきます。
findEquation - Google ドライブ

*1:[3,4,7,8] [1,1,5,8]などが難しいことで有名です

Haskellでフィボナッチ

去年の六月ごろに「すごいHaskell楽しく学ぼう」(以下すごいH)を買って読んでみようと思ったのですが、当時はほとんどプログラミングをしたことがなく、良くわからないまま終わりました。8か月ほど経って経験を積んだのでようやく本棚でほこりをかぶっていた本を読み始めてみました。
とりあえず再帰フィボナッチ数列を求めてみるコードを書いてみます

fibo1 ::(Num a)=> Int -> a
fibo1 0 = 1
fibo1 1 = 1
fibo1 n = fibo1 (n-1) + fibo1 (n-2)

いやぁ気持ちが良いですね
なんか漸化式をそのまま書くと動くので嬉しいです

このままだとn番目のフィボナッチ数を求めるオーダーが指数オーダーなので動的計画法っぽいことをして、O(n)で計算させてみます。

applyN :: Int ->(a->a) ->a ->a
applyN 0 _ a = a
applyN n f a = f $ applyN (n-1) f a

fibo2 :: Int -> Integer
fibo2 n = fst $ applyN n (\(a,b) -> (a+b, a)) (1,0)

ラムダ計算で書くとなんかそれっぽくてカッコいいですね
それにバグっていないという安心感がありますね
これが副作用がないということなんでしょうか
折角なので行列を繰り返し二乗法で計算してO(log n)で計算するやつをやってみました

data Mat a = Mat a a a a deriving (Show)

mmul:: (Num a)=>Mat a -> Mat a -> Mat a 
Mat x y z w `mmul` Mat i j k l 
    = Mat (x*i+y*k) (x*j+y*l) (z*i+w*k) (z*j+z*l)

getnum :: Mat a -> a
getnum (Mat x _ _ _) = x

fiboMat = Mat 1 1 1 0

appself ::(a-> a->a) -> a -> a
appself f x = f x x

foldn:: Int -> (a -> a-> a)-> a-> a
foldn 1 _ a = a
foldn n f a 
    |even n = appself f (foldn (n `div` 2) f a )
    |otherwise = f a $ appself f (foldn (div n 2) f a )

fibo3 :: Int -> Integer
fibo3 n = getnum (foldn n mmul fiboMat)

関数の名前を付けるセンスがないですねぇ。
いや非常にテンションが上がりますね。行列を定義するのが非常に自然です。
これめちゃくちゃ自然に書けますねぇ
appselfという関数が実は重要で、これを使わないと再帰展開が分岐してしまってO(n)になってしまいます(しまいました)
すごいH本を読み進めていきましょう

modP整数型(途中経過)

競技プログラミングには「答えが非常に大きくなるので1000000007で割った余りを答えなさい」という問題がよくあります
毎回剰余を求めるコードを書くのも面倒であるので次のようなことができるクラスを作ってみました

int main() {
	modP a ;
	a = 2;
	modP b;
	b = 5;
	cout << "a = "<<a << endl;
	cout << "b = "<<b << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
	a.setP(7);
	cout << "P = 7" << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
}

出力は

a = 2
b = 5
a+b = 7
a-b = 1000000004
a*b = 10
a/b = 800000006
a^5 = 32
P = 7
a+b = 0
a-b = 4
a*b = 3
a/b = 6
a^5 = 4

という感じです。
演算子オーバーロードを試してみたかったのです。
割り算を互除法で逆元を求めるように、xorを繰り返し二乗法で累乗を求めるように変えてみました。
以下ソースコード(初心者なので間違っているところを指摘していただけたら泣いて喜びます)

#include<iostream>
using namespace std;

class modP
{
public:
	long long value;
	modP();
	void setP(long long a) {
		modulo = a;
	}
	modP(const long long &a) { value = a; }
	modP(const int &a) { value = a; }
	modP operator+(modP &bar) { modP x; x.value = (this->value + bar.value) % modulo; return x; }
	modP operator++(int) { modP x = *this; this->value++; return x; }
	modP operator--(int){ modP x = *this; this->value--; return x; }
	modP& operator++() { this->value++; return *this; }
	modP& operator--() { this->value--; return *this; }
	modP operator-() const { modP x = *this; x.value = modulo - x.value; return x; }
	modP operator+() const { return *this; }
	explicit operator long long() const noexcept { return (this->value + modulo) % modulo; }
	explicit operator int() const noexcept { return(int)(this->value + modulo) % modulo; }
	modP& operator=(const int&a) { this->value = a%modulo; return *this; }
	modP& operator=(const long long&a) { this->value = a%modulo; return *this; }
	modP& operator=(const modP&bar) { this->value = bar.value; return *this; }
	//modP& operator=(T&&) noexcept;
	modP& operator*=(const modP&bar) { this->value = (this->value* bar.value) % modulo; return *this; }
	modP& operator*=(const int &bar) { this->value = (this->value* bar) % modulo; return *this; }
	modP& operator*=(const long long &bar) { this->value = (this->value* bar) % modulo; return *this; }
	modP& operator/=(const long long &);
	modP& operator/=(const modP&bar) { return operator/=((long long)bar); }
	modP& operator+=(const modP&bar) { this->value = (this->value+ bar.value) % modulo; return *this; }
	modP& operator+=(const int &bar) { this->value = (this->value+ bar) % modulo; return *this; }
	modP& operator+=(const long long &bar) { this->value = (this->value+ bar) % modulo; return *this; }
	modP& operator-=(const modP&bar) { this->value = (this->value- bar.value+modulo) % modulo; return *this; }
	modP& operator-=(const int &bar) { this->value = (this->value- bar+modulo) % modulo; return *this; }
	modP& operator-=(const long long &bar) { this->value = (this->value- bar+modulo) % modulo; return *this; }
	modP& operator^=(const long long &);
	friend modP operator*(const modP&, const modP&);
	friend modP operator/(const modP&, const modP&);
	friend modP operator+(const modP&, const modP&);
	friend modP operator-(const modP&, const modP&);
	friend modP operator*(const long long&, const modP&);
	friend modP operator/(const long long&, const modP&);
	friend modP operator+(const long long&, const modP&);
	friend modP operator-(const long long&, const modP&);
	friend modP operator*(const modP&, const long long&);
	friend modP operator/(const modP&, const long long&);
	friend modP operator+(const modP&, const long long&);
	friend modP operator-(const modP&, const long long&);
	friend modP operator^(const modP&, const long long&);
	friend ostream& operator<<(ostream& os, const modP& dt);
private:
	static long long modulo;
	long long extgcd(long long, long long, long long&, long long&);
};
long long modP::modulo;
modP operator*(const modP& t1, const modP& t2) { return modP(t1) *= t2; }
modP operator/(const modP& t1, const modP& t2) { return modP(t1) /= t2; }
modP operator+(const modP& t1, const modP& t2) { return modP(t1) += t2; }
modP operator-(const modP& t1, const modP& t2) { return modP(t1) -= t2; }
modP operator*(const modP& t1, const long long& t2) { return modP(t1) *= t2; }
modP operator/(const modP& t1, const long long& t2) { return modP(t1) /= t2; }
modP operator+(const modP& t1, const long long& t2) { return modP(t1) += t2; }
modP operator-(const modP& t1, const long long& t2) { return modP(t1) -= t2; }
modP operator*(const modP& t1, const int& t2) { return modP(t1) *= t2; }
modP operator/(const modP& t1, const int& t2) { return modP(t1) /= t2; }
modP operator+(const modP& t1, const int& t2) { return modP(t1) += t2; }
modP operator-(const modP& t1, const int& t2) { return modP(t1) -= t2; }
modP operator^(const modP& t1, const int& t2) { return modP(t1) ^= t2; }
modP operator^(const modP& t1, const long long& t2) { return modP(t1) ^= t2; }
modP operator*(const long long& t1, const modP& t2) { return modP(t1) *= t2; }
modP operator/(const long long& t1, const modP& t2) { return modP(t1) /= t2; }
modP operator+(const long long& t1, const modP& t2) { return modP(t1) += t2; }
modP operator-(const long long& t1, const modP& t2) { return modP(t1) -= t2; }
modP::modP()
{
	if (modulo == 0)modulo = 1000000007;
}
long long modP::extgcd(long long a, long long b, long long &x, long long &y) {
	long long d = a;
	if (b != 0) {
		d = extgcd(b, a%b, y, x);
		y -= (a / b)*x;
	}
	else x = 1, y = 0;
	return d;
}
modP& modP::operator/=(const long long &bar) {
	long long x;
	long long y;
	extgcd(modulo, bar, x, y);
	this->value = (this->value*y) % modulo;
	return *this;
}
modP& modP::operator^=(const long long &n) {
	long long a = this->value;
	long long x = n % (modulo - 1);
	long long ans = 1;
	while (x > 0) {
		if ((x & 1) == 1) {
			ans = (ans*a) % modulo;
		}
		a *= a;
		x=x >> 1;
	}
	this->value = ans;
	return *this;
}
ostream& operator<<(ostream& os, const modP& dt)
{
	os << (long long)dt;
	return os;
}
modP combination(int a, int b) {
	modP ans = 1;
	if (a - b < b)b = a - b;
	if (b < 0) {
		ans = 0;
		return ans;
	}
	for (int i = a - b + 1; i <= a; i++) {
		ans *= i;
	}
	for (int i = 1; i <= b; i++) {
		ans /= i;
	}
	return ans;
}
int main() {
	modP a ;
	a = 2;
	modP b;
	b = 5;
	cout << "a = "<<a << endl;
	cout << "b = "<<b << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
	a.setP(7);
	cout << "P = 7" << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
}