题目大意:
给一个含有 $n$ 个元素的集合,里面的元素为 $[1,n]$ 内的所有元素,请选出一个大小为 $m$ 的子集,使其乘积不含有完全平方因子,求方案数 $\text{mod }10^9+7$。
$\texttt{Data Range:}1\le n,m\le 500$
其实看到这种 $500$ 的数据范围很难想到状压 dp 的。
首先题目的要求等价与这个子集的乘积的所有质因子次数小于 $2$。首先排除那些一开始就含有完全平方因子的数。
考虑一个最 naive 的 dp:$f_{i,S}$ 表示当前子集已经选了 $i$ 个数,当前子集所有数乘积中含有的质因子集合为 $S$ 的方案数。
相信来做这题的都会转移我就懒得写了
但我们杯具地发现 $500$ 以内地质数高达 $95$ 个,时空爆炸好耶
考虑根号分治优化:
$\le \sqrt{n}$ 的质数个数不超过 $8$ 个,所以先不考虑那些含有 $>\sqrt{n}$ 的质因子的数直接 dp。
一个 $\le n$ 的数至多有一个 $>\sqrt{n}$ 的质因子,利用这个性质,我们可以把所有 $i$ 的倍数($i$ 是质数且 $i>\sqrt{n}$)放到第 $i$ 个vector
里面。
做完 $\le \sqrt{n}$ 的质因子的 dp 之后,我们还要再做第二次 dp:遍历每个vector
更新 $f$ 数组。
所有含有质因子 $i$ 的数中只能选出一个,所以我在做第二次 dp 时为了防止转移混乱实现时先把 $f$ 复制给了 $g$,做完 dp 后再memcpy
回来。
另外我开始 dp 时没有管 $1$,输出答案时直接统计的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
| #include <cstdio> #include <vector> #include <cstring> #include <cmath>
const int mod = 1e9 + 7; int Prime[105], f[505][1 << 8], g[505][1 << 8], Set[505], key[505], cnt, S; std::vector<int> vec[505]; bool mark[505], flag[505];
int main() { int n, m; scanf("%d%d", &n, &m); if (m > n) m = n; S = floor(sqrt(n)); for (int i = 2; i <= S; ++ i) if (!mark[i]) { Prime[++ cnt] = i, key[i] = cnt; for (int j = i * i; j <= S; j += i) mark[j] = true; } for (int i = 2; i <= n; ++ i) { int k = i; for (int j = 1; Prime[j] * Prime[j] <= k && j <= cnt; ++ j) { if (k % Prime[j] == 0) Set[i] |= 1 << j - 1, k /= Prime[j]; if (k % Prime[j] == 0) {flag[i] = true; break;} } if (k != 1 && !flag[i]) { if (k <= S) Set[i] |= 1 << key[k] - 1; else flag[i] = true, vec[k].push_back(i); } } f[0][0] = 1; for (int j = 2; j <= n; ++ j) if (!flag[j]) for (int i = m - 1; i >= 0; -- i) for (int S = 0; S < 1 << cnt; ++ S) if (!(S & Set[j])) f[i + 1][S | Set[j]] = (f[i + 1][S | Set[j]] + f[i][S]) % mod; for (int j = S + 1; j <= n; ++ j) { memcpy(g, f, sizeof f); for (int k : vec[j]) for (int i = m - 1; i >= 0; -- i) for (int S = 0; S < 1 << cnt; ++ S) if (!(S & Set[k])) g[i + 1][S | Set[k]] = (g[i + 1][S | Set[k]] + f[i][S]) % mod; memcpy(f, g, sizeof f); } int ans = 0; for (int i = 1; i <= m; ++ i) for (int S = 0; S < 1 << cnt; ++ S) if (i != m) ans = (ans + f[i][S] * 2 % mod) % mod; else ans = (ans + f[i][S]) % mod; printf("%d", (ans + 1) % mod); return 0; }
|