给定 $n$ 个数 $a_i$,求把 $n$ 个数分为三个集合,且前两个集合异或值相同的方案数。集合可区分而集合内的元素不可区分。
$1\le a_i,n\le 10^6$
题目链接
思路
首先可以 dp,$f_{i,j}$ 为考虑了前 $i$ 个数,且前两个集合异或值为 $j$ 的方案数:
$$
f_{i,j}=f_{i-1,j}+2f_{i-1,j\otimes a_i}
$$
貌似没法优化了?我们把它写成 FWT 的形式:
$$
\begin{aligned}
c_i&=1+2x^{a_i}\\
F_i&=F_{i-1}\otimes c_i
\end{aligned}
$$
中间那个 $\otimes$ 是异或 FWT。注意到我们 $c_i$ 的项数很少,我们把 $c_i$ 的 FWT 拆开。首先异或 FWT 的意义是:
$$
\operatorname{FWT}(a)_i=\sum_j a_j\cdot (-1)^{|i\&j|}
$$
我们发现由于 0 与任何数位与都是 0,所以 $\operatorname{FWT}(c_i)$ 首先每一项都有个 1,然后 2 的符号就是任意了。所以我们知道 $\operatorname{FWT}(c_i)$ 仅由 -1 和 3 组成。那么也就是说我们现在在算很多个这样数组的点积,于是考虑统计 -1 和 3 的个数。我们现在考虑 $i$ 这一位的 -1 和 3 的个数($\operatorname{FWT}(a_{?,i})$)。记这一位所有数和为 $sum_i$,我们注意到,$3cnt_{i,3}-cnt_{i,-1}=sum$,$cnt_{i,-1}=n-cnt_{i,3}$,所以只要知道 $sum_i$ 就可以算 $cnt_{i,3}$ 和 $cnt_{i,-1}$。于是有:
$$
\begin{aligned}
sum_i&=\sum_j\operatorname{FWT}(c_j)_i\\
&=\sum_j1+2\cdot (-1)^{|i\&a_j|}
\end{aligned}
$$
欸我们一看!右边那个东西可以统一用一遍 FWT 弄出来。没了。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
| #include <mivik.h>
MI cin;
const int nmax = 1048576; const int mod = 998244353;
int n, p3[nmax], c[nmax]; inline int pro(int x) { return x >= mod? x - mod: x; } inline int per(int x) { return x < 0? x + mod: x; } inline int n1(int v, int p) { return (p & 1)? per(-v): v; } inline int div2(int x) { return ((x & 1)? x + mod: x) >> 1; } inline int round_up(int x) { return 1 << (32 - __builtin_clz(x)); } template<bool rev> inline void fwt(int *v, int len) { for (int i = 1, q = 2; i < len; q = (i = q) << 1) for (int j = 0; j < len; j += q) for (int k = 0; k < i; ++k) { const int x(v[j | k]), y(v[i | j | k]); v[j | k] = pro(x + y); v[i | j | k] = per(x - y); if (rev) { v[j | k] = div2(v[j | k]); v[i | j | k] = div2(v[i | j | k]); } } } int main() { cin > n; int lim = 0; for (int i = p3[0] = 1; i <= n; ++i) { p3[i] = pro(p3[i - 1] + pro(p3[i - 1] << 1)); const int x(R); c[x] += 2; if (x > lim) lim = x; } lim = round_up(lim); fwt<0>(c, lim); for (int i = 0; i < lim; ++i) { const int sum(pro(n + c[i])); const int c3(pro(sum + n) >> 2), cn1(n - c3); c[i] = n1(p3[c3], cn1); } fwt<1>(c, lim); cout < per(c[0] - 1) < endl; }
|