博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【知识总结】多项式全家桶(三点五)(拆系数解决任意模数多项式卷积)
阅读量:4982 次
发布时间:2019-06-12

本文共 5158 字,大约阅读时间需要 17 分钟。

上一篇:

5c88b1c69b022.gif

(请无视此图)

我最近学了一个常数小还不用背三个模数的做法:拆系数法

(以下默认多项式项数 \(N=10^5\) ,系数不超过 \(M=10^9\) 且为非负整数)

我们放弃「数论变换」「利用原根性质」之类的想法,来点简单粗暴的:用实数 FFT 把原始结果算出来,然后直接取模。

为什么过去我们没有这样做呢?因为卷积结果的系数最大可能达到 \(NM^2=10^{24}\) ,long double 的精度也不够(通常情况下 double 的精度约为 15 位十进制,long double 的精度约为 19 位十进制)。考虑「拆系数」来牺牲时间保证精度。

设相乘的两个多项式为 \(A(x)=\sum a_ix^i\)\(B(x)=\sum b_ix^i\) ,结果为 \(C(x)=\sum c_ix^i\) 。把\(A(x)\) 拆成两个多项式 \(A_0(x)=\sum {a_0}_ix^i\)\(A_1(x)=\sum {a_1}_ix^i\) ,其中 \({a_1}_i=\lfloor\frac{a_i}{S}\rfloor\)\({a_0}_i=a_i-S\cdot {a_1}_i\) ,即 \(a_i=S\cdot {a_1}_i+{a_0}_i\) 。对 \(B(x)\) 也作同样的操作。这样,就有

\[\begin{aligned} a_ib_j&=(S\cdot {a_1}_i+{a_0}_i)(S\cdot {b_1}_j+{b_0}_j)\\ &={a_1}_i{b_1}_jS^2+({a_1}_i{b_0}_j+{a_0}_i{b_1}_j)S+{a_0}_i{b_0}_j \end{aligned}\]

于是直接计算 \(C_1=A_1*B_1\)\(C_2=A_1*B_0+A_0*B_1\)\(C_3=A_0*B_0\) ,然后 \(c_i={c_1}_iS^2+{c_2}_iS+{c_3}_i\) ,算最后一步的时候对「任意模数」取模即可。

这样,大致估计一下 \(C_1\) 的最大系数是 \(N\cdot(\frac{M}{S})^2=\frac{NM^2}{S^2}\)\(C_2\) 最大系数是 \(2N\cdot\frac{M}{S}\cdot S=2NM\)\(C_3\) 最大系数是 \(NS^2\) 。当 \(S=\sqrt{M}\) 时,以上三项均为 \(NM\) 级别,即 \(10^{14}\) 左右,足以保证精度(跟瓜学的一般偷懒直接 \(S=32768\) )。

代码:

事实上由于只跟 7 个多项式有关,所以只需要进行 7 次 FFT 。我写的常数特别大了不要跟我学 ……

题目:

#include 
#include
#include
#include
#include
using namespace std;namespace zyt{ template
inline bool read(T &x) { char c; bool f = false; x = 0; do c = getchar(); while (c != EOF && c != '-' && !isdigit(c)); if (c == EOF) return false; if (c == '-') f = true, c = getchar(); do x = x * 10 + c - '0', c = getchar(); while (isdigit(c)); if (f) x = -x; return true; } template
inline void write(T x) { static char buf[20]; char *pos = buf; if (x < 0) putchar('-'), x = -x; do *pos++ = x % 10 + '0'; while (x /= 10); while (pos > buf) putchar(*--pos); } const int N = 1e5 + 10, S = 1 << 15, p = 1e9 + 7, B = 18; typedef long double ld; typedef long long ll; int power(int a, int b) { int ans = 1; while (b) { if (b & 1) ans = (ll)ans * a % p; a = (ll)a * a % p; b >>= 1; } return ans; } int get_inv(const int a) { return power(a, p - 2); } ll dtol(const ld x) { return ll(fabs(x) + 0.5) * (x < 0 ? -1 : 1); } namespace Polynomial { const int LEN = 1 << B; const ld PI = acos(-1.0L); struct cpx { ld x, y; cpx(const ld _x = 0, const ld _y = 0) : x(_x), y(_y) {} cpx conj() { return cpx(x, -y); } }omega[LEN], winv[LEN]; ll ctol(const cpx &a) { return dtol(a.x); } int rev[LEN]; cpx operator + (const cpx &a, const cpx &b) { return cpx(a.x + b.x, a.y + b.y); } cpx operator - (const cpx &a, const cpx &b) { return cpx(a.x - b.x, a.y - b.y); } cpx operator * (const cpx &a, const cpx &b) { return cpx(a.x * b.x - a.y * b.y, a.y * b.x + a.x * b.y); } void init(const int n, const int lg2) { cpx w = cpx(cos(2.0L * PI / n), sin(2.0L * PI / n)), wi = w.conj(); omega[0] = winv[0] = 1; for (int i = 1; i < n; i++) omega[i] = omega[i - 1] * w, winv[i] = winv[i - 1] * wi; for (int i = 0; i < n; i++) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1))); } void FFT(cpx *const a, const cpx * const w, const int n) { for (int i = 0; i < n; i++) if (i < rev[i]) swap(a[i], a[rev[i]]); for (int l = 1; l < n; l <<= 1) for (int i = 0; i < n; i += (l << 1)) for (int k = 0; k < l; k++) { cpx x = a[i + k], y = a[i + l + k] * w[n / (l << 1) * k]; a[i + k] = x + y, a[i + l + k] = x - y; } } void mul(const cpx *const a, const cpx *const b, cpx *const c, const int n) { static cpx x[LEN], y[LEN]; int m = 1, lg2 = 0; while (m < (n + n - 1)) m <<= 1, ++lg2; init(m, lg2); memcpy(x, a, sizeof(cpx[n])); memcpy(y, b, sizeof(cpx[n])); for (int i = n; i < m; i++) x[i] = y[i] = 0; FFT(x, omega, m), FFT(y, omega, m); for (int i = 0; i < m; i++) x[i] = x[i] * y[i]; FFT(x, winv, m); for (int i = 0; i < n; i++) c[i] = cpx(x[i].x / m, 0.0); } void MTT(const int *const a, const int *const b, int *const ans, const int n) { const int S = 1 << 15; static cpx a0[LEN], a1[LEN], b0[LEN], b1[LEN], c1[LEN], c2[LEN], c3[LEN], c4[LEN]; for (int i = 0; i < n; i++) { a0[i] = cpx(a[i] % S, 0), a1[i] = cpx(a[i] / S, 0); b0[i] = cpx(b[i] % S, 0), b1[i] = cpx(b[i] / S, 0); } mul(a0, b0, c1, n), mul(a0, b1, c2, n), mul(a1, b0, c3, n), mul(a1, b1, c4, n); for (int i = 0; i < n; i++) { int x1 = ctol(c1[i]) % p, x2 = ctol(c2[i]) % p, x3 = ctol(c3[i]) % p, x4 = ctol(c4[i]) % p; ans[i] = (x1 + ll(x2 + x3) * S % p + ll(x4) * S % p * S % p) % p; } } void _inv(const int *const a, int *b, const int n) { if (n == 1) return void(b[0] = get_inv(a[0])); static int tmp[LEN]; _inv(a, b, (n + 1) >> 1); memset(b + ((n + 1) >> 1), 0, sizeof(int[n - ((n + 1) >> 1)])); MTT(a, b, tmp, n), MTT(tmp, b, tmp, n); for (int i = 0; i < n; i++) b[i] = (2LL * b[i] % p - tmp[i] + p) % p; } void inv(const int *const a, int *b, const int n) { static int tmp[LEN]; memcpy(tmp, a, sizeof(int[n])); _inv(tmp, b, n); } } int A[N << 1]; int work() { using namespace Polynomial; int n; read(n); for (int i = 0; i < n; i++) read(A[i]); inv(A, A, n); for (int i = 0; i < n; i++) write(A[i]), putchar(' '); return 0; }}int main(){#ifdef BlueSpirit freopen("4239.in", "r", stdin);#endif return zyt::work();}

转载于:https://www.cnblogs.com/zyt1253679098/p/11160807.html

你可能感兴趣的文章
08-语言入门-08-5个数求最值
查看>>
mysql常见知识点总结
查看>>
网站添加百度影音的方法
查看>>
Comparsion in JavaScript
查看>>
【转】ubuntu磁盘状态查看(转)--脱离鼠标操作
查看>>
hdu 1237 简单计算器 栈
查看>>
当我们在说微服务治理的时候究竟在说什么
查看>>
CAS(Compare And Swap)
查看>>
JAVA中String类以及常量池和常用方法
查看>>
java
查看>>
Oracle 数据库导入、导出
查看>>
批量修改 表结构
查看>>
MySQL的btree索引和hash索引的区别
查看>>
抽象类和接口有什么区别
查看>>
wc2018
查看>>
[转载] 杜拉拉升职记——01 忠诚源于满足
查看>>
那些mv*框架如何选择
查看>>
git工作流程
查看>>
Excel坐标自动在AutoCad绘图_3
查看>>
hacknet
查看>>