線形漸化式のn項目を$O(n (\log n)^2 )$で計算

目的

線形漸化式$\sum_{i=0}^{k+1} c_i a_{i+j} \equiv 0 \pmod{m}$($j$は任意の非負整数)で得られる数列$\{a_i\}$の第$n$項$a_n$を$O(n (\log n)^2 )$で求めます。

計算量

$O(n (\log n)^2 )$

使い方

long[] init
最初のk+1項です。$init[i]=a_i$としてください。
long[] coe
線形漸化式の係数です。$coe[i]=c_i$としてください。
long MODULO
目的の項の$m$に相当します。
long nthMod(long[] coe, long[] init, long n, long MODULO)
init,coe,MODULOで与えられる数列のn項目を計算します。$MODULO \leq 10^9$としてください。

ソースコード

import java.io.FileNotFoundException;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Scanner;

/**
 * Returns the n-th term, reduced modulo {@code MODULO}, of the sequence defined by
 * the linear recurrence {@code sum_i coe[i] * a[i + j] == 0 (mod MODULO)}.
 *
 * <p>The method computes {@code x^n} modulo the polynomial whose coefficients are
 * {@code coe} by binary exponentiation, then combines the coefficients of that
 * remainder with the initial terms.
 *
 * @param coe    recurrence coefficients, {@code coe[i] = c_i}; reduction divides by
 *               the last coefficient, so it presumably has to be invertible modulo
 *               {@code MODULO} (not checked here)
 * @param init   initial terms, {@code init[i] = a_i}; must contain at least as many
 *               entries as the remainder has coefficients (one fewer than
 *               {@code coe.length} in the general case). Values may be negative or
 *               larger than {@code MODULO}; they are reduced before use.
 * @param n      index of the requested term; for {@code n <= 0} the loop does not run
 *               and {@code init[0]} reduced modulo {@code MODULO} is returned
 * @param MODULO modulus; products of two residues must fit in a {@code long}
 *               (the accompanying notes ask for {@code MODULO <= 10^9})
 * @return {@code a_n} in the range {@code [0, MODULO)}
 */
long nthMod(long[] coe, long[] init, long n, long MODULO) {
	Poly ret = new Poly(MODULO, new long[] { 1 });
	// Starts as the monomial x and is squared once per bit of n.
	Poly x = new Poly(MODULO, new long[] { 0, 1 });
	Poly mod = new Poly(MODULO, coe);
	for (long n_ = n; n_ > 0; n_ >>= 1) {
		if (n_ % 2 == 1) {
			ret.mul(x);
			ret.mod(mod);
		}
		x.mul(x);
		x.mod(mod);
	}
	long ans = 0;
	for (int j = 0; j < ret.intLen; ++j) {
		// Reduce the initial term first: an unreduced value could overflow the
		// product below, and a negative one would make the result negative.
		long term = Math.floorMod(init[j], MODULO);
		ans = (ans + ret.val[j] * term) % MODULO;
	}
	return ans;
}

/**
 * Combines residues with Garner's mixed-radix form of the Chinese remainder
 * theorem and returns the combined value reduced modulo {@code MOD}.
 *
 * <p>Only the first {@code x.length} entries of {@code mod} are used, so
 * {@code mod} may be longer than {@code x}.
 *
 * @param x   residues; {@code x[i]} is taken modulo {@code mod[i]} (negative or
 *            oversized values are normalised first)
 * @param mod moduli; the inverses computed here only exist if they are pairwise
 *            coprime (not checked), and products of two residues must fit in a
 *            {@code long}
 * @param MOD modulus of the returned value
 * @return the value congruent to every {@code x[i]} modulo {@code mod[i]}, reduced
 *         modulo {@code MOD}; {@code 0} when {@code x} is empty
 */
long garner(long[] x, long[] mod, long MOD) {
	int n = x.length;
	if (n == 0) {
		// An empty system of congruences: nothing to combine.
		return 0;
	}

	// gamma[i] = (mod[0] * ... * mod[i-1])^-1 modulo mod[i]
	long[] gamma = new long[n];
	for (int i = 0; i < n; ++i) {
		long prd = 1;
		for (int j = 0; j < i; ++j) {
			prd = prd * mod[j] % mod[i];
		}
		gamma[i] = inv(prd, mod[i]);
	}

	// v[i] are the mixed-radix digits:
	// value = v[0] + v[1]*mod[0] + v[2]*mod[0]*mod[1] + ...
	long[] v = new long[n];
	v[0] = Math.floorMod(x[0], mod[0]);
	for (int i = 1; i < n; ++i) {
		// Horner evaluation of the digits found so far, modulo mod[i].
		long tmp = v[i - 1] % mod[i];
		for (int j = i - 2; j >= 0; --j) {
			tmp = (tmp * mod[j] % mod[i] + v[j]) % mod[i];
		}
		long diff = Math.floorMod(Math.floorMod(x[i], mod[i]) - tmp, mod[i]);
		v[i] = diff * gamma[i] % mod[i];
	}

	// Horner evaluation of the mixed-radix number, this time modulo MOD.
	long ret = 0;
	for (int i = n - 1; i >= 0; --i) {
		ret = (ret * mod[i] % MOD + v[i]) % MOD;
	}
	return ret;
}

/**
 * Computes the modular inverse of {@code a} modulo {@code mod} with the extended
 * Euclidean algorithm.
 *
 * @param a   value to invert; negative values are first normalised into
 *            {@code [0, mod)}
 * @param mod modulus, expected to be positive
 * @return the coefficient {@code p} with {@code p * a == gcd(a, mod) (mod mod)},
 *         shifted into the non-negative range; this is the inverse of {@code a}
 *         when {@code a} and {@code mod} are coprime. No error is raised when they
 *         are not coprime.
 */
public static long inv(long a, long mod) {
	if (a < 0 && mod > 0) {
		// Without this, a negative a ends the loop after one step and yields 0.
		a = Math.floorMod(a, mod);
	}
	long b = mod;
	// p and q track the multiples of the original a that the current a and b
	// are congruent to.
	long p = 1, q = 0;
	while (b > 0) {
		long c = a / b;
		long d = a;
		a = b;
		b = d % b;
		d = p;
		p = q;
		q = d - c * q;
	}
	return p < 0 ? p + mod : p;
}

/**
 * Computes {@code a^n} modulo {@code MODULO} by binary exponentiation.
 *
 * @param a      base; may be negative or larger than {@code MODULO}, it is reduced
 *               before any multiplication
 * @param n      exponent; for {@code n <= 0} the loop does not run and
 *               {@code 1 % MODULO} is returned
 * @param MODULO modulus, must be positive; products of two residues must fit in a
 *               {@code long}
 * @return the power in the range {@code [0, MODULO)}
 */
long pow(long a, long n, long MODULO) {
	long base = Math.floorMod(a, MODULO);
	// 1 % MODULO so that the result is 0 when MODULO is 1.
	long ret = 1 % MODULO;
	for (; n > 0; n >>= 1) {
		if ((n & 1) == 1) {
			ret = ret * base % MODULO;
		}
		base = base * base % MODULO;
	}
	return ret;
}

/**
 * Mutable dense polynomial with coefficients reduced modulo {@code MODULO}.
 * {@code val[i]} holds the coefficient of {@code x^i}. All arithmetic methods
 * modify the receiver in place.
 *
 * <p>Two multiplication back ends are used: a schoolbook product for short
 * operands and a number-theoretic transform for long ones. When no transform root
 * is configured ({@code root == -1}) the product is computed under three fixed
 * moduli and recombined with {@code garner}.
 *
 * <p>This class calls the {@code garner} and {@code inv} helpers declared outside
 * of it in the same file.
 */
class Poly {
	/** Moduli handed to the fixed-modulus transform by mulFFTwithAnyMOD, which reads only the first three. */
	long[] NTTMODULO = new long[] { 924844033, 962592769, 975175681, 950009857 };
	/** Transform root paired with the modulus at the same index of NTTMODULO. */
	long[] NTTROOT = new long[] { 5, 7, 17, 7 };

	/** Coefficient array; its length may exceed intLen. */
	long[] val;
	/** Coefficient modulus. */
	long MODULO = -1;
	/** Root used by fft; -1 means "none", which routes long products through mulFFTwithAnyMOD. */
	long root = -1;
	/**
	 * Number of leading entries of val in use. update() sets it to (index of the
	 * highest non-zero coefficient) + 1, i.e. 0 for the zero polynomial. The
	 * constructors set it to the array length without trimming.
	 */
	int intLen = 0;

	/**
	 * Creates a polynomial over an arbitrary modulus (no transform root).
	 *
	 * @param MODULO_ coefficient modulus
	 * @param vs      coefficients, lowest degree first; copied and reduced into [0, MODULO_)
	 */
	public Poly(long MODULO_, long[] vs) {
		val = Arrays.copyOf(vs, vs.length);
		MODULO = MODULO_;
		intLen = val.length;
		for (int i = 0; i < val.length; ++i) {
			val[i] %= MODULO_;
			if (val[i] < 0)
				val[i] += MODULO_;
		}
	}

	/**
	 * Creates a polynomial together with the root that fft should use.
	 *
	 * @param MODULO_ coefficient modulus
	 * @param root_   transform root, or -1 for none
	 * @param vs      coefficients, lowest degree first; copied and reduced into [0, MODULO_)
	 */
	public Poly(long MODULO_, long root_, long[] vs) {
		val = Arrays.copyOf(vs, vs.length);
		intLen = val.length;
		MODULO = MODULO_;
		root = root_;
		for (int i = 0; i < val.length; ++i) {
			val[i] %= MODULO_;
			if (val[i] < 0)
				val[i] += MODULO_;
		}
	}

	/** Adds {@code a} to this polynomial coefficient-wise and recomputes intLen. */
	void add(Poly a) {
		if (val.length < a.intLen) {
			val = Arrays.copyOf(val, a.val.length);
		}
		for (int i = 0; i < a.intLen; ++i) {
			val[i] += a.val[i];
			if (val[i] >= MODULO)
				val[i] -= MODULO;
			if (val[i] < 0)
				val[i] += MODULO;
		}
		intLen = 0;
		for (int i = 0; i < val.length; ++i) {
			if (i + 1 > intLen && val[i] != 0)
				intLen = i + 1;
		}
	}

	/** Subtracts {@code a} from this polynomial coefficient-wise and recomputes intLen. */
	void sub(Poly a) {
		if (a.intLen > val.length) {
			val = Arrays.copyOf(val, a.intLen);
		}
		for (int i = 0; i < a.intLen; ++i) {
			val[i] -= a.val[i];
			if (val[i] < 0)
				val[i] += MODULO;
			if (val[i] >= MODULO)
				val[i] -= MODULO;
		}
		intLen = 0;
		for (int i = 0; i < val.length; ++i) {
			if (val[i] != 0 && intLen < i + 1)
				intLen = i + 1;
		}
	}

	/**
	 * Multiplies by {@code a} for an arbitrary modulus: the product is computed
	 * under each of the first three NTTMODULO entries and every coefficient is then
	 * recombined with garner, reduced modulo MODULO.
	 */
	void mulFFTwithAnyMOD(Poly a) {
		Poly[] ps1 = new Poly[3];
		Poly[] ps2 = new Poly[3];
		int maxLen = 0;
		for (int i = 0; i < ps1.length; ++i) {
			ps1[i] = new Poly(NTTMODULO[i], NTTROOT[i], val);
			ps2[i] = new Poly(NTTMODULO[i], NTTROOT[i], a.val);
			ps1[i].mul(ps2[i]);
			maxLen = Math.max(maxLen, ps1[i].intLen);
		}

		val = new long[maxLen];
		for (int i = 0; i < maxLen; ++i) {
			val[i] = garner(new long[] { ps1[0].val[i], ps1[1].val[i], ps1[2].val[i] }, NTTMODULO, MODULO);
		}
		update();
	}

	/** Schoolbook product over the first intLen coefficients of both operands. */
	void mulNaive(Poly a) {
		if (intLen == 0 || a.intLen == 0) {
			val = new long[] { 0 };
			update();
			return;
		}
		long[] nv = new long[intLen + a.intLen - 1];
		for (int i = 0; i < intLen; ++i) {
			for (int j = 0; j < a.intLen; ++j) {
				nv[i + j] += val[i] % MODULO * a.val[j] % MODULO;
				nv[i + j] %= MODULO;
				if (nv[i + j] < 0)
					nv[i + j] += MODULO;
			}
		}
		val = nv;
		update();
	}

	/** Multiplies by {@code a}, choosing the transform for long operands and the schoolbook product otherwise. */
	void mul(Poly a) {
		if (intLen + a.intLen > 80) {
			mulFFT(a);
		} else  {
			mulNaive(a);
		}
	}

	/** Transform-based product; falls back to mulFFTwithAnyMOD when no root is configured. */
	void mulFFT(Poly a) {
		if (root == -1) {
			mulFFTwithAnyMOD(a);
			return;
		}

		if (a.intLen == 0 || intLen == 0) {
			intLen = 0;
			val = new long[] { 0 };
			return;
		}
		val = mul(val, a.val);
		update();
		return;
	}

	/** Recomputes intLen as (index of the highest non-zero coefficient) + 1. */
	void update() {
		intLen = 0;
		for (int i = 0; i < val.length; ++i) {
			if (val[i] != 0 && intLen < i + 1)
				intLen = i + 1;
		}
	}

	/**
	 * Replaces this polynomial with the quotient of the division by {@code b}.
	 *
	 * <p>The quotient is obtained from the reversed polynomials: the inverse power
	 * series of the reversed divisor is built with the Newton iteration
	 * {@code a <- 2a - a^2 * d}, multiplied into the reversed dividend, truncated
	 * and reversed back. (Original author's note on multiplication cost:
	 * Karatsuba n^1.58, FFT n lg n.)
	 *
	 * <p>The divisor is not modified; the iteration runs on a private monic copy
	 * and the result is rescaled by the inverse of the divisor's leading
	 * coefficient, which therefore presumably has to be invertible modulo MODULO.
	 *
	 * @param b divisor
	 * @throws ArithmeticException if {@code b.intLen} is 0
	 */
	void div(Poly b) {
		if (b.intLen == 0)
			throw new ArithmeticException();
		if (b.intLen > intLen) {
			val = new long[] { 0 };
			intLen = 0;
			return;
		}
		int n = intLen - 1;
		int m = b.intLen - 1;
		// Normalise and reverse a copy so the caller's divisor stays as passed in.
		long lead = b.val[m];
		Poly d = b.copy();
		d.monic();
		d.rev(m + 1);
		Poly two = new Poly(MODULO, root, new long[] { 2 });
		Poly a = new Poly(MODULO, root, new long[] { 1 });
		for (int t = 1; t < n - m + 1; t *= 2) {
			Poly tmp = a.copy();
			tmp.mul(a);
			tmp.mul(d);
			a.mul(two);
			a.sub(tmp);
			// Only the lowest n - m + 1 coefficients of the inverse are needed.
			if (a.intLen > n - m + 1) {
				a.intLen = n - m + 1;
				a.val = Arrays.copyOf(a.val, a.intLen);
			}
		}
		rev(n + 1);
		mul(a);
		rev(n - m + 1);
		if (intLen - 1 > n - m) {
			intLen = n - m + 1;
			val = Arrays.copyOf(val, intLen);
		}
		// Dividing by the monic copy yields lead * (true quotient); undo that factor.
		long scale = inv(lead, MODULO);
		for (int i = 0; i < intLen; ++i) {
			val[i] = val[i] * scale % MODULO;
		}
		return;
	}

	/** Replaces this polynomial with its remainder modulo {@code mod} and trims val to intLen. */
	void mod(Poly mod) {
		Poly tmp = copy();
		tmp.div(mod);
		tmp.mul(mod);
		sub(tmp);
		val = Arrays.copyOf(val, intLen);
	}

	/**
	 * Orders polynomials by intLen first, then by coefficients from the highest
	 * index downwards.
	 *
	 * @return a negative, zero or positive value as this polynomial is less than,
	 *         equal to or greater than {@code a}
	 */
	int compare(Poly a) {
		if (intLen != a.intLen) {
			return Integer.compare(intLen, a.intLen);
		}
		for (int i = intLen - 1; i >= 0; --i) {
			if (val[i] == a.val[i])
				continue;
			// Long.compare is exact; comparing via double loses precision above 2^53.
			return Long.compare(val[i], a.val[i]);
		}
		return 0;
	}

	/**
	 * Multiplies by {@code x^k}: coefficients move up by {@code k} positions, or
	 * down (dropping the lowest ones) when {@code k} is negative.
	 */
	void shift(int k) {
		if (intLen == 0)
			return;
		if (intLen + k <= 0) {
			val = new long[] { 0 };
			intLen = 0;
			return;
		}
		Poly u = copy();
		val = new long[intLen + k];
		if (k >= 0)
			System.arraycopy(u.val, 0, val, k, intLen);
		else 
			System.arraycopy(u.val, -k, val, 0, intLen + k);
		intLen += k;
	}

	/** Multiplies every coefficient by the inverse of the coefficient at index intLen - 1. */
	void monic() {
		if (intLen == 0)
			return;
		long v = inv(val[intLen - 1], MODULO);
		for (int i = 0; i < intLen; ++i) {
			val[i] *= v;
			val[i] %= MODULO;
		}
	}

	/**
	 * Reverses the order of the first {@code Len} coefficients in place (resizing
	 * val to {@code Len} first when intLen is smaller) and recomputes intLen.
	 */
	void rev(int Len) {
		if (intLen < Len)
			val = Arrays.copyOf(val, Len);
		int s = 0;
		int t = Len;
		while (t - s > 0) {
			long d = val[s];
			val[s] = val[t - 1];
			val[t - 1] = d;
			++s;
			--t;
		}
		intLen = 0;
		for (int i = val.length - 1; i >= 0; --i) {
			if (val[i] != 0) {
				intLen = i + 1;
				break;
			}
		}
	}

	/** Returns an independent copy with the same modulus, root, coefficients and intLen. */
	Poly copy() {
		Poly ret = new Poly(MODULO, root, new long[] { 0 });
		ret.val = Arrays.copyOf(val, val.length);
		ret.intLen = intLen;
		return ret;
	}

	/**
	 * Convolves two coefficient arrays with the transform: both are zero-padded to
	 * a power of two larger than the combined length, transformed, multiplied
	 * point-wise, transformed back and scaled by the inverse of the size.
	 */
	long[] mul(long[] a, long[] b) {
		int n = Integer.highestOneBit(a.length + b.length) << 1;
		a = Arrays.copyOf(a, n);
		b = Arrays.copyOf(b, n);
		a = fft(a, false);
		b = fft(b, false);
		long ninv = pow(n, MODULO - 2);
		for (int i = 0; i < n; ++i) {
			a[i] = a[i] * b[i] % MODULO;
		}
		a = fft(a, true);
		for (int i = 0; i < n; ++i)
			a[i] = a[i] * ninv % MODULO;
		return a;
	}

	/**
	 * In-place iterative transform modulo MODULO using {@code root}.
	 *
	 * @param a   data whose length is a power of two as produced by mul
	 * @param inv when true, each stage uses the (MODULO - 2)-th power of its twiddle
	 *            base; the division by the length is left to the caller
	 * @return the same array {@code a}
	 */
	long[] fft(long[] a, boolean inv) {
		int n = a.length;
		// Bit-reversal permutation: c tracks the bit-reversed counterpart of i.
		int c = 0;
		for (int i = 1; i < n; ++i) {
			for (int j = n >> 1; j > (c ^= j); j >>= 1)
				;
			if (c > i) {
				long d = a[c];
				a[c] = a[i];
				a[i] = d;
			}
		}

		// Butterfly passes over blocks of size 2 * i.
		for (int i = 1; i < n; i <<= 1) {
			long w = pow(root, (MODULO - 1) / (2 * i));
			if (inv)
				w = pow(w, MODULO - 2);
			for (int j = 0; j < n; j += 2 * i) {
				long wn = 1;
				for (int k = 0; k < i; ++k) {
					long u = a[k + j];
					long v = a[k + j + i] * wn % MODULO;
					a[k + j] = (u + v) % MODULO;
					a[k + j + i] = (u - v + MODULO) % MODULO;
					wn = wn * w % MODULO;
				}
			}
		}
		return a;
	}

	/** Computes {@code a^n} modulo this polynomial's MODULO by binary exponentiation. */
	long pow(long a, long n) {
		long ret = 1;
		for (; n > 0; n >>= 1, a = a * a % MODULO) {
			if (n % 2 == 1)
				ret = ret * a % MODULO;
		}
		return ret;
	}

	/**
	 * Consistency check: fails if a non-zero coefficient lies at or beyond intLen.
	 *
	 * @throws AssertionError if such a coefficient exists
	 */
	void check() {
		for (int i = val.length - 1; i >= 0; --i) {
			if (val[i] != 0 && i + 1 > intLen) {
				throw new AssertionError();
			}
		}
	}
}

/**
 * Returns the greatest common divisor of {@code a} and {@code b} using the
 * iterative Euclidean algorithm.
 *
 * <p>Signs are ignored, so the result is non-negative (the one exception is
 * {@code Long.MIN_VALUE}, whose absolute value does not fit in a {@code long}).
 * {@code gcd(0, 0)} is {@code 0}.
 *
 * @param a first value
 * @param b second value
 * @return the greatest common divisor of {@code |a|} and {@code |b|}
 */
long gcd(long a, long b) {
	// Work on magnitudes: with a negative argument the remainders never reach 0
	// in the ordering the recursive form relied on.
	a = Math.abs(a);
	b = Math.abs(b);
	while (a != 0) {
		long r = b % a;
		b = a;
		a = r;
	}
	return b;
}