読者です 読者をやめる 読者になる 読者になる

tomabouの日記

Haskellなどで勉強したことなどを書いていきます

modP整数型(途中経過)

競技プログラミングでは「答えが非常に大きくなるので1000000007で割った余りを答えなさい」という問題がよくあります。
毎回剰余を求めるコードを書くのは面倒なので、次のようなことができるクラスを作ってみました。

int main() {
	modP a ;
	a = 2;
	modP b;
	b = 5;
	cout << "a = "<<a << endl;
	cout << "b = "<<b << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
	a.setP(7);
	cout << "P = 7" << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
}

出力は

a = 2
b = 5
a+b = 7
a-b = 1000000004
a*b = 10
a/b = 800000006
a^5 = 32
P = 7
a+b = 0
a-b = 4
a*b = 3
a/b = 6
a^5 = 4

という感じです。
演算子オーバーロードを試してみたかったのです。
割り算は拡張ユークリッドの互除法で逆元を求めて計算するように、^演算子(本来はxor)は繰り返し二乗法で累乗を計算するように変えてみました。なお、^ は << や +・* よりも演算子の優先順位が低いので、(a^5) のように括弧で囲む必要があります。
以下がソースコードです(初心者なので、間違っているところを指摘していただけたら泣いて喜びます)。

#include<iostream>
using namespace std;

/// Integer modulo a modulus that is shared by every modP object
/// (1000000007 unless changed with setP).
///
/// Every arithmetic operator reduces its operands and stores its result in
/// [0, modulo), so a product of two operands stays below modulo^2 and cannot
/// overflow `long long` as long as modulo < 3037000500.
/// Division and negative powers use a modular inverse obtained from the
/// extended Euclidean algorithm; they are only meaningful when the divisor is
/// coprime to the modulus (always true for a prime modulus and a divisor that
/// is not a multiple of it).
class modP
{
public:
	// Stored residue. It is public, so a value written here directly may be
	// out of range; arithmetic operators and conversions reduce it before use.
	long long value;
	// Value 0; also restores the default modulus if it is currently 0.
	modP();
	// Changes the modulus for ALL modP objects (it is a static member);
	// stored values are not rewritten. Precondition: a >= 1.
	void setP(long long a) {
		modulo = a;
	}
	// Implicit conversions from integers, reduced into [0, modulo).
	modP(const long long &a) : value(normalize(a)) {}
	modP(const int &a) : value(normalize(a)) {}
	// Non-const member taking a non-const reference: for non-const operands it
	// is a better match than the friend operator+, so `a + b` is unambiguous.
	modP operator+(modP &bar) { modP x; x.value = addmod(value, bar.value); return x; }
	modP operator++(int) { modP x = *this; value = addmod(value, 1); return x; }
	modP operator--(int) { modP x = *this; value = submod(value, 1); return x; }
	modP& operator++() { value = addmod(value, 1); return *this; }
	modP& operator--() { value = submod(value, 1); return *this; }
	modP operator-() const { modP x = *this; x.value = submod(0, value); return x; }
	modP operator+() const { return *this; }
	// Canonical residue in [0, modulo).
	explicit operator long long() const noexcept { return normalize(value); }
	// Reduces before narrowing; exact only while modulo - 1 fits in int.
	explicit operator int() const noexcept { return static_cast<int>(normalize(value)); }
	modP& operator=(const int&a) { value = normalize(a); return *this; }
	modP& operator=(const long long&a) { value = normalize(a); return *this; }
	modP& operator=(const modP&bar) { value = bar.value; return *this; }
	modP& operator*=(const modP&bar) { value = mulmod(value, bar.value); return *this; }
	modP& operator*=(const int &bar) { value = mulmod(value, bar); return *this; }
	modP& operator*=(const long long &bar) { value = mulmod(value, bar); return *this; }
	// Multiplies by the modular inverse of bar.
	modP& operator/=(const long long &);
	modP& operator/=(const modP&bar) { return operator/=(static_cast<long long>(bar)); }
	modP& operator+=(const modP&bar) { value = addmod(value, bar.value); return *this; }
	modP& operator+=(const int &bar) { value = addmod(value, bar); return *this; }
	modP& operator+=(const long long &bar) { value = addmod(value, bar); return *this; }
	modP& operator-=(const modP&bar) { value = submod(value, bar.value); return *this; }
	modP& operator-=(const int &bar) { value = submod(value, bar); return *this; }
	modP& operator-=(const long long &bar) { value = submod(value, bar); return *this; }
	// Power, not xor: the ^ operators are repurposed as exponentiation.
	modP& operator^=(const long long &);
	friend modP operator*(const modP&, const modP&);
	friend modP operator/(const modP&, const modP&);
	friend modP operator+(const modP&, const modP&);
	friend modP operator-(const modP&, const modP&);
	friend modP operator*(const long long&, const modP&);
	friend modP operator/(const long long&, const modP&);
	friend modP operator+(const long long&, const modP&);
	friend modP operator-(const long long&, const modP&);
	friend modP operator*(const modP&, const long long&);
	friend modP operator/(const modP&, const long long&);
	friend modP operator+(const modP&, const long long&);
	friend modP operator-(const modP&, const long long&);
	friend modP operator^(const modP&, const long long&);
	friend std::ostream& operator<<(std::ostream& os, const modP& dt);
private:
	static long long modulo;
	long long extgcd(long long, long long, long long&, long long&);
	// Maps any v into [0, modulo); C++ % truncates toward zero, so a negative
	// remainder needs one extra + modulo.
	static long long normalize(long long v) noexcept {
		v %= modulo;
		return v < 0 ? v + modulo : v;
	}
	static long long addmod(long long a, long long b) noexcept { return normalize(normalize(a) + normalize(b)); }
	static long long submod(long long a, long long b) noexcept { return normalize(normalize(a) - normalize(b)); }
	// Both factors are < modulo after reduction, so the product fits in
	// long long while modulo < 3037000500 (about sqrt(2^63)).
	static long long mulmod(long long a, long long b) noexcept { return normalize(normalize(a) * normalize(b)); }
};
// Constant-initialized, so the modulus is valid even before the first modP
// is constructed (e.g. `modP x = 1;` followed by arithmetic).
long long modP::modulo = 1000000007;
modP operator*(const modP& t1, const modP& t2) { return modP(t1) *= t2; }
modP operator/(const modP& t1, const modP& t2) { return modP(t1) /= t2; }
modP operator+(const modP& t1, const modP& t2) { return modP(t1) += t2; }
modP operator-(const modP& t1, const modP& t2) { return modP(t1) -= t2; }
modP operator*(const modP& t1, const long long& t2) { return modP(t1) *= t2; }
modP operator/(const modP& t1, const long long& t2) { return modP(t1) /= t2; }
modP operator+(const modP& t1, const long long& t2) { return modP(t1) += t2; }
modP operator-(const modP& t1, const long long& t2) { return modP(t1) -= t2; }
modP operator*(const modP& t1, const int& t2) { return modP(t1) *= t2; }
modP operator/(const modP& t1, const int& t2) { return modP(t1) /= t2; }
modP operator+(const modP& t1, const int& t2) { return modP(t1) += t2; }
modP operator-(const modP& t1, const int& t2) { return modP(t1) -= t2; }
modP operator^(const modP& t1, const int& t2) { return modP(t1) ^= t2; }
modP operator^(const modP& t1, const long long& t2) { return modP(t1) ^= t2; }
modP operator*(const long long& t1, const modP& t2) { return modP(t1) *= t2; }
modP operator/(const long long& t1, const modP& t2) { return modP(t1) /= t2; }
modP operator+(const long long& t1, const modP& t2) { return modP(t1) += t2; }
modP operator-(const long long& t1, const modP& t2) { return modP(t1) -= t2; }
modP::modP() : value(0)
{
	// Only setP(0) can leave the modulus at 0; fall back to the default.
	if (modulo == 0) modulo = 1000000007;
}
// Extended Euclid: returns d = gcd(a, b) and sets x, y so that a*x + b*y == d.
long long modP::extgcd(long long a, long long b, long long &x, long long &y) {
	long long d = a;
	if (b != 0) {
		d = extgcd(b, a % b, y, x);
		y -= (a / b) * x;
	}
	else x = 1, y = 0;
	return d;
}
// Divides by bar by multiplying with bar^{-1} (mod modulo).
// extgcd gives modulo*x + b*y == gcd(modulo, b) for the reduced divisor b;
// when that gcd is 1, y is the inverse of b. No invertibility check is made:
// a divisor that is a multiple of the modulus yields 0.
modP& modP::operator/=(const long long &bar) {
	long long x = 0;
	long long y = 0;
	extgcd(modulo, normalize(bar), x, y);
	value = mulmod(value, y);
	return *this;
}
// Sets *this to (*this)^n by square-and-multiply (at most 64 rounds).
// A negative n raises the modular inverse to |n| (meaningful only when the
// value is invertible). x^0 is 1 % modulo.
modP& modP::operator^=(const long long &n) {
	long long base = normalize(value);
	// Work on the magnitude as unsigned so n == LLONG_MIN cannot overflow.
	unsigned long long e = static_cast<unsigned long long>(n);
	if (n < 0) {
		long long x = 0;
		long long y = 0;
		extgcd(modulo, base, x, y);
		base = normalize(y);
		e = 0ULL - e;
	}
	long long result = normalize(1);
	while (e != 0) {
		if (e & 1ULL) result = mulmod(result, base);
		base = mulmod(base, base);
		e >>= 1;
	}
	value = result;
	return *this;
}
// Prints the canonical residue in [0, modulo).
std::ostream& operator<<(std::ostream& os, const modP& dt)
{
	return os << static_cast<long long>(dt);
}
// Binomial coefficient C(a, b) modulo the current modP modulus.
// Returns 0 when b < 0 or b > a. The numerator a*(a-1)*...*(a-b+1) is divided
// by b! with a single modular inverse, which is only meaningful when b! is
// invertible, i.e. for a prime modulus larger than b (no Lucas' theorem here).
modP combination(int a, int b) {
	// Default-construct (instead of `modP ans = 1;`): modP's default
	// constructor assigns the default modulus when none is set, so no `%` by
	// the modulus can happen before it is initialized.
	modP ans;
	// C(a, b) == C(a, a - b): iterate over the smaller one. Compare in
	// long long so a - b cannot overflow int.
	const long long rest = static_cast<long long>(a) - b;
	if (rest < b) b = static_cast<int>(rest);
	if (b < 0) {
		ans = 0;
		return ans;
	}
	ans = 1;
	modP den;
	den = 1;
	// long long counters: `i <= a` with an int i would overflow for a == INT_MAX.
	for (long long i = static_cast<long long>(a) - b + 1; i <= a; ++i) ans *= i;
	for (long long i = 1; i <= b; ++i) den *= i;
	ans /= den;
	return ans;
}
int main() {
	modP a ;
	a = 2;
	modP b;
	b = 5;
	cout << "a = "<<a << endl;
	cout << "b = "<<b << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
	a.setP(7);
	cout << "P = 7" << endl;
	cout << "a+b = "<<a + b << endl;
	cout << "a-b = "<<a - b << endl;
	cout << "a*b = "<<a*b << endl;
	cout << "a/b = "<<a / b << endl;
	cout << "a^5 = "<<(a^5) << endl;
}