再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)
阅读原文时间:2021年07月16日阅读:3

再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)

目录

为了不使篇幅过长,预计将把基于论文的学习笔记分为三部分:

  1. DFT,IDFT,FFT的定义,实现与证明:快速傅里叶变换(FFT)学习笔记(其一)
  2. NTT的实现与证明:快速傅里叶变换(FFT)学习笔记(其二)
  3. 任意模数NTT与FFT的优化技巧

一些约定

  1. \([p(x)]=\begin{cases}1,p(x)为真 \\ 0,p(x)为假 \end{cases}\)
  2. 本文中序列的下标从0开始
  3. 若\(s\)是一个序列,\(|s|\)表示\(s\)的长度
  4. 若大写字母如\(F(x)\)表示一个多项式,那么对应的小写字母如\(f\)表示多项式的每一项系数,即\(F(x)=\sum_{i=0}^{n-1} f_ix^i\)

DFT卷积的本质

考虑在(其一)中提到的卷积的定义式。

\[c_{r}=\sum_{p, q}[(p+q) \bmod n=r] a_{p} b_{q} \tag{1.1}
\]

我们一般做FFT时忽略了式子中的\(\bmod\),其实它是在\(\bmod 2^q\)的意义下的循环卷积,只是因为\(|a|,|b|,|c|<2^q\),所以取不取模都没什么影响。

如果序列长度\(n\)是2的整数次幂,那么直接做就可以了。

如果序列长度\(n\)不是2的整数次幂考虑暴力的做法:先做一次普通FFT,再把\(c_{k+n}\)加到\(c_k\)上。但是这样在做多次FFT时就必须一次一次做,比如多项式快速幂。下面给出了一种在\(O(n \log n)\)的时间内实现任意长度循环卷积的算法:Bluestein’s Algorithm

Bluestein’s Algorithm

注:原论文的推导可能有误

考虑DFT的式子

\[\begin{aligned} a'_i&=\sum_{j=0}^{n-1} a_j \omega_n^{ij} \\&=\sum_{j=0}^{n-1} a_j \omega_n^{\frac{-(i-j)^2+i^2+j^2}{2}} \\&= \omega_n^{\frac{i^2}{2}} \sum_{j=0}^{n-1}a_j \omega_n^{\frac{j^2}{2}} \omega_n^{-\frac{(i-j)^2}{2}}\end{aligned}
\]

不妨设

\(x_j=a_j \omega_n^{\frac{j^2}{2}}=a_j(\cos\frac{j^2\pi}{n}+ \text{i}\sin{\frac{j^2\pi}{n}})\)

\(y_j=\omega_n^{-\frac{j^2}{2}}= \cos \frac{\pi j^2}{n}-\text{i}\sin \frac{\pi j^2}{n}\)

那么\(a_i'=\omega_n^{\frac{j^2}{2}}\sum_{j=0}^{n-1} x_j y_{i-j}\)

这已经很类似卷积的形式了,但是注意到\(j\)的上界是\(n-1\)而不是\(i\),\(j-i\)可能为负数。那么我们把\(y\)数组的长度扩大到\(2n\),定义:

\(y_j=\omega_n^{-\frac{(j-n)^2}{2}}= \cos \frac{\pi (j-n)^2}{n}-\text{i}\sin \frac{\pi (j-n)^2}{n}\).

这样\(j<n\)的时候就对应了\(j-i\)为负数的情形,\(j\geq n\)就对应了\(j-i\)为正的情形。然后对\(x\)和\(y\)用一般的FFT,最后的答案存储在\(i+n\)的位置上,也就是说真正的\(a'_i\)实际上对应了乘积结果的\((x \cdot y)_{i+n}\)

这样,我们就只做了3次FFT就求出了任意长度循环DFT。逆变换同理,只是换成共轭复数。注意到在上述的推导中我们没有用到单位根\(\omega\)的任何性质,因此这里的\(\omega\)可以换成任意复数\(z\),这样的变换称为Chirp Z-Transform,CZT.可见,CZT实际上是DFT的广义形式。

代码实现:

//com是手写复数类,省略
void fft(com *x,int *rev,int n,int type){
    //为节约篇幅,fft部分省略,x为系数序列,rev为反转数组,n为长度,type=1表示DFT,type=-1表示IDFT
}
void bluestein(com *a,int n,int type){
    //a为系数序列,n为长度,type=1表示DFT,type=-1表示IDFT
    static com x[maxn*4+5],y[maxn*4+5];
    static int rev[maxn*4+5];
    memset(x,0,sizeof(x));
    memset(y,0,sizeof(y));
    //FFT前的预处理
    int N=1,L=0;
    while(N<n*4){
        L++;
        N*=2;
    }
    for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    //x[i],y[i]的定义见上式
    for(int i=0;i<n;i++) x[i]=com(cos(pi*i*i/n),type*sin(pi*i*i/n))*a[i];
    for(int i=0;i<n*2;i++) y[i]=com(cos(pi*(i-n)*(i-n)/n),-type*sin(pi*(i-n)*(i-n)/n));
    fft(x,rev,N,1);
    fft(y,rev,N,1);
    for(int i=0;i<N;i++) x[i]*=y[i];
    fft(x,rev,N,-1);
    for(int i=0;i<n;i++){
        a[i]=x[i+n]*com(cos(pi*i*i/n),type*sin(pi*i*i/n));//记得乘上常数
        if(type==-1) a[i]/=n;//一定记得除以n,因为做一次Bluestein相当于一次FFT,IFFT最后要除n,这里也要除n
    }
}

例题

[POJ 2821]TN's Kindom III(任意长度循环卷积的Bluestein算法)

一般我们用FFT的时候,序列的所有元素都已知。但是,如果序列本身是根据卷积定义的,就无法直接套FFT

举一个最简单的例子\(f_i =\sum_{j=1}^i f_{i-j}g_j\).其中\(g\)给定,求\(f\). 由于我们卷积的时后后面的数基于前面的数,无法快速计算,时间复杂度退化到\(O(n^2)\). (虽然这个式子可以用(其四)中将会提到的多项式求逆解决,但是分治FFT更通用,可以处理很复杂的式子)

考虑分治: 设当前分治区间为\([l,r]\),假设我们求出了\([l,mid]\)的答案,那么可以求出这些点对\([mid+1,r]\)的影响。那么右半边的点\(x \in [mid+1,r]\)得到的贡献是\(\Delta_x=\sum_{i=l}^{mid} f_i g_{x-i}\).只需要把下标偏移一下(如\([l,mid]\)偏移成\([0,mid-l]\),就是一个卷积的形式,可以运用FFT或NTT计算,计算完之后,把答案累加到数组上.

伪代码如下:

poly f,g;//上述的f,g
procedure calc(L,mid,R){
    for i in [L,mid] : a[i-L] <- f[i]//下标偏移
    for i in [1,R-L] : b[i-1] <- g[i]
    a <- mul(a,b);//fft或ntt做多项式乘法
    for i in [mid+1,R] f[i] <- f[i]+a[i-l-1]//累加贡献
}
procedure solve(l,mid){
    if(l==r) return;
    mid <- (l+r)/2
    solve(l,mid);
    calc(l,mid,r);
    solve(mid+1,r)
}

时间复杂度分析:

\(T(n)=2T(\frac{n}{2})+n \log_2n\), 总复杂度\(\Theta(n \log^2n)\)

下面是基于NTT的模板代码(Luogu 4721)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 300000
#define G 3
#define invG 332748118
#define inv2 499122177
#define mod 998244353
using namespace std;
typedef long long ll;
inline ll fast_pow(ll x,ll k){
    ll ans=1;
    while(k){
        if(k&1) ans=ans*x%mod;
        x=x*x%mod;
        k>>=1;
    }
    return ans;
}
inline ll inv(ll x){
    return fast_pow(x,mod-2);
}

void NTT(ll *x,int n,int type){
    static int rev[maxn+5];
    int tn=1;
    int k=0;
    while(tn<n){
        tn*=2;
        k++;
    }
    for(int i=0;i<tn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
    for(int i=0;i<n;i++){
        if(i<rev[i]) swap(x[i],x[rev[i]]);
    }
    for(int len=1;len<n;len*=2){
        int sz=len*2;
        ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz);
        for(int l=0;l<n;l+=sz){
            int r=l+len-1;
            ll gnk=1;
            for(int i=l;i<=r;i++){
                ll tmp=x[i+len];
                x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
                x[i]=(x[i]+gnk*tmp%mod)%mod;
                gnk=gnk*gn1%mod;
            }
        }
    }
    if(type==-1){
        int invsz=inv(n);
        for(int i=0;i<n;i++) x[i]=x[i]*invsz%mod;
    }
}
void mul(ll *a,ll *b,ll *ans,int sz){
    NTT(a,sz,1);
    NTT(b,sz,1);
    for(int i=0;i<sz;i++) ans[i]=a[i]*b[i]%mod;
    NTT(ans,sz,-1);
} 

void cdq_divide(ll *f,ll *g,int l,int r){
    static ll tmpa[maxn+5],tmpb[maxn+5];
    if(l==r) return;
    int mid=(l+r)>>1;
    cdq_divide(f,g,l,mid);
    int tn=1,k=0;
    while(tn<r-l){
        k++;
        tn*=2;
    }
    for(int i=0;i<tn;i++) tmpa[i]=tmpb[i]=0;
    for(int i=l;i<=mid;i++) tmpa[i-l]=f[i];
    for(int i=1;i<=r-l;i++) tmpb[i-1]=g[i];
    mul(tmpa,tmpb,tmpa,tn);
    for(int i=mid+1;i<=r;i++) f[i]=(f[i]+tmpa[i-l-1])%mod;
    cdq_divide(f,g,mid+1,r);
}

int n;
ll f[maxn+5],g[maxn+5];
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;i++) scanf("%lld",&g[i]);
    f[0]=1;
    cdq_divide(f,g,0,n-1);
    for(int i=0;i<n;i++) printf("%lld ",f[i]);
}

容易发现,许多dp方程都有分治FFT的形式。对于此类dp方程,我们可以用分治FFT将转移复杂度由\(O(n^2)\)降到\(O(n \log^2 n)\)

例题

[Codeforces 553E]Kyoya and Train(期望DP+Floyd+分治FFT)

下面介绍一些优化FFT的常数的技巧。虽然这些技巧都只是对FFT的一些小优化,但是在某些题目中优化效果极其明显。

复杂算式中减少FFT次数

如果我们要计算一个复杂的多项式,如\(A(x)=B(x)C(x)+D(x)E(x)\)

最简单的方法是分别计算\(B(x)C(x)\)和\(D(x)E(x)\),这样需要做6次FFT. 但是如果先对\(B,C,D,E\)做DFT,然后直接用点值表达式计算\(a_i=b_ic_i+d_ie_i\),再把\(a\)IDFT回去。这样只需要做5次FFT,且多项式越复杂,这样的常数就越优秀。

例题

[BZOJ 3771] Triple(FFT+容斥原理+生成函数)

利用循环卷积

考虑对于两个长度为\(n\)的序列\(a,b\),计算它们的卷积\(c\)的第\(0.5n\)项到第\(1.5n\)项。传统的方法是补0扩充到\(2n\)的序列。但是因为FFT求得实际上是我们已经提到过的循环卷积,所以如果只补0到\(1.5n\)(上取整),对第\(0.5n\)项到第\(1.5n\)项无影响

在基于牛顿迭代的算法中,能起到较明显的优化作用。会在(其四)中详细介绍这些算法。

小范围暴力

由于FFT的常数较大。在数据范围较小的时候甚至不如\(O(n^2)\)的暴力卷积的优秀。因此在做多次FFT和分治FFT的时候,如果当前的序列长度较小,可以采用暴力算法。

例题

[BZOJ 3509] [CodeChef] COUNTARI (FFT+分块)

快速幂乘法次数的优化

这个东西实际上比较鸡肋。因为多项式快速幂可以通过多项式\(\ln\)和\(\exp\)优化到\(O(n \log n)\).但是为了应对考场上时间不够的情况,我们来考虑如何通过简单的实现来减少\(O(n \log^2n)\)的倍增快速幂的复杂度。

倍增法的思路是根据前面算过的乘积快速算出当前的乘积,如\(1 \to 2 \to 4 \to 8\).最坏情况下需要\(2 \log_2n+C\)次乘法。但这并不是下界。我们定义additional chain为一条链,最开始是1,后一个数减前一个数的差是链上这个是前面的某一个数。例如\(1 \to 2 \to 4 \to 6\).\(6-4=2\)在前面出现过,\(4-2=2\)在前面出现过。那么根据这条additional chain计算6次幂的时候,可以从1次幂出发,用1次幂乘1次幂得到2次幂,再乘2次幂得到4次幂,再乘2次幂得到6次幂。

很可惜,对于数\(k\)求出得到\(k\)的最短additional chain是NP-hard的。但是有很好的近似算法。近似算法基于BFS。每次我们对于队头的数\(x\),枚举它对应的additional chain中的数\(y\),如果\(x+y\)还没有访问过那么将其入队,并将\(x\)对应的链后面接上\(x+y\). 这个预处理是\(O(k)\)的,且对快速幂的常数优化很显著。

如果\(k\)很大,比如\(10^{10000}\),可以采用十进制快速幂。但是用Method of Four Russians(俗称四毛子算法),可以将乘法次数减少到\(\log_2n+O(\frac{\log n}{\log \log n})\).具体方法见2017年国家集训队论文《非常规大小分块算法初探》

FFT的强常数优化一般是通过减少FFT次数来实现的

在这一节中,我们记\(DFT(A(x))\)表示多项式\(A(x)\)(或序列)做DFT之后的结果,\(IDFT(A(x))\)同理

我们现在考虑最常见的一个模型:给出两个长度为\(n+1\)和\(m+1\)的多项式\(A(x),B(x)\),我们要计算他们的线性卷积。假设长度已经补齐为第一个大于\(n+m+1\)的2的整数幂\(L\)。

显然直接搞需要3次长度为\(L\)的FFT。毒瘤的Vladimir Smykalov在cf上最先给出了这个问题的优化算法。

DFT的合并

DFT的合并是指,对于两个序列\(a\),\(b\),我们只通过一次FFT就求出\(DFT(a),DFT(b)\)

不妨设:

\[P(x)=A(x)+\text{i}B(x) \tag{4.1}
\]

\[Q(x)=A(x)-\text{i}B(x) \tag{4.2}
\]

接下来我们开始推导公式。注意为了简洁,我们记\(X=\frac{2 \pi jk}{2L}\),\(\text{conj}(z)\)表示\(z\)的共轭复数

\[\begin{aligned}
DFT(p_k) &=A\left(\omega_{2 L}^{k}\right)+i B\left(\omega_{2 L}^{k}\right) \\
&=\sum_{j=0}^{2 L-1} a_{j} \omega_{2 L}^{j k}+i b_{j} \omega_{2 L}^{j k} \\
&=\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right)(\cos X+i \sin X)
\end{aligned}\]

\[\begin{aligned}
DFT(q_k) &=A\left(\omega_{2 L}^{k}\right)-i B\left(\omega_{2 L}^{k}\right) \\
&=\sum_{j=0}^{2 L-1} a_{j} \omega_{2 L}^{j k}-i b_{j} \omega_{2 L}^{j k} \\
&=\sum_{j=0}^{2 L-1}\left(a_{j}-i b_{j}\right)(\cos X+i \sin X) \\
&=\sum_{j=0}^{2 L-1}\left(a_{j} \cos X+b_{j} \sin X+i \sin X-b_{j} \cos X\right) \\&=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j} \cos X+b_{j} \sin X\right)-i\left(a_{j} \sin X-b_{j} \cos X\right)\right)\\
&=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j} \cos (-X)-b_{j} \sin (-X)\right)+i\left(a_{j} \sin (-X)+b_{j} \cos (-X)\right)\right)\\
&=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right)(\cos (-X)+i \sin (-X))\right)\\
&=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right) \omega_{2 i}^{-j k}\right)\\
&=\operatorname{conj}\left(\sum_{j=0}^{2 L-1}\left(a_{j}+i b_{j}\right) \omega_{2 L}^{(2 L-k) j}\right)\\
&=\operatorname{conj}\left(p'[2 L-k]\right)
\end{aligned}\]

也就是说,只要一次DFT算出\(DFT(p)\),就可以把序列反转再取共轭复数得到\(DFT(q)\).

由于DFT是线性变换,

\[DFT(a_k)=\frac{DFT(p_k)+DFT(q_k)}{2}=\frac{DFT(p_k)+\text{conj}(DFT(p_j))}{2}
\]

其中\(j\)为\(k\)翻转后的数,即\(j=\begin{cases}0,k=0 \\ L-k ,k>0 \end{cases}\)

又由\((4.1),(4.2)\)式

\[DFT(a_k)=\frac{DFT(p_k)+DFT(q_k)}{2} \tag{4.3}
\]

\[DFT(b_k)=-\text{i}\frac{DFT(p_k)-DFT(q_k)}{2} \tag{4.4}
\]

\[DFT(a_k)DFT(b_k)=\text{i}\frac{{DFT(p_k)}^2-{DFT(q_k)}^2}{4} \tag{4.5}
\]

这样我们就可以从\(q'\)推出\(a',b'\),也就是说一次DFT就能得到\(a'\)和\(b'\)了.

我们一共做了2次长度为\(L\)的FFT.

代码(UOJ#34):

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
    double real;
    double imag;
    com(){

    }
    com(double _real,double _imag){
        real=_real;
        imag=_imag;
    }
    com(double x){
        real=x;
        imag=0;
    }
    void operator = (const com x){
        this->real=x.real;
        this->imag=x.imag;
    }
    void operator = (const double x){
        this->real=x;
        this->imag=0;
    }
    friend com operator + (com p,com q){
        return com(p.real+q.real,p.imag+q.imag);
    }
    friend com operator + (com p,double q){
        return com(p.real+q,p.imag);
    }
    void operator += (com q){
        *this=*this+q;
    }
    void operator += (double q){
        *this=*this+q;
    }
    friend com operator - (com p,com q){
        return com(p.real-q.real,p.imag-q.imag);
    }
    friend com operator - (com p,double q){
        return com(p.real-q,p.imag);
    }
    void operator -= (com q){
        *this=*this-q;
    }
    void operator -= (double q){
        *this=*this-q;
    }
    friend com operator * (com p,com q){
        return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
    }
    friend com operator * (com p,double q){
        return com(p.real*q,p.imag*q);
    }
    void operator *= (com q){
        *this=(*this)*q;
    }
    void operator *= (double q){
        *this=(*this)*q;
    }
    friend com operator / (com p,double q){
        return com(p.real/q,p.imag/q);
    }
    void operator /= (double q){
        *this=(*this)/q;
    }
    com conj(){
        return com(real,-imag);
    }
    void print(){
        printf("%lf + %lf i ",real,imag);
    }
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
    for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
    for(int len=1;len<n;len*=2){
        int sz=len*2;
        for(int l=0;l<n;l+=sz){
            int r=l+len-1;
            for(int i=l;i<=r;i++){
                com tmp=x[i+len];
                x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
                x[i]=x[i]+tmp*w[n/sz*(i-l)];
            }
        }
    }
}
void mul(ll *a,ll *b,ll *c,int n){
    static com p[maxn+5],r[maxn+5];
    for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));//预处理单位根
    for(int i=0;i<n;i++) p[i]=com(a[i],b[i]);//p[i]=a[i]+ib[i]
    fft(p,n);
    for(int i=0;i<n;i++){
        int j=(i>0?(n-i):0);//0的位置需要特判一下
        com q=p[j];
        r[j]=(p[i]*p[i]-q.conj()*q.conj())*com(0,-0.25);//按照上面的式子
    }
    fft(r,n);//这里是用了第一篇中提到的反转技巧
    for(int i=0;i<n;i++) c[i]=r[i].real/n+0.5;
}

int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
    scanf("%d %d",&n,&m);
    for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
    int N=1,L=0;
    while(N<n+m+1){
        L++;
        N*=2;
    }
    for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    mul(a,b,c,N);
    for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}

IDFT的合并

IDFT的合并是指,对于两个序列\(a\),\(b\),我们只通过一次FFT就求出\(IDFT(a),IDFT(b)\)

IDFT的合并非常简单。

设\(r(x)=a(x)+\text{i}b(x)\)

由于IDFT是线性变换

\(IDFT(r(x))=IDFT(a(x))+\text{i}IDFT(b(x))\)

又因为\(a(x)\)和\(b(x)\)都是实数序列,那么\(IDFT(r(x))\)的实部就是\(IDFT(a(x))\),虚部就是\(IDFT(b(x))\)

形如\((A+B)(C+D)\)的卷积的优化

在这一节中我们讨论\((A(x)+B(x))(C(x)+D(x))\)形式的卷积的优化.

一般的做法是对\(A,B,C,D\)都做一次DFT,然后按照这个式子直接计算,最后再IDFT回来。需要5次FFT.

而根据上面的合并技巧,先把\(A(x),B(x)\)合并DFT,\(C(x),D(x)\)合并DFT得到点值表达式.

由于\((A(x)+B(x))(C(x)+D(x))=A(x)C(x)+A(x)D(x)+B(x)C(x)+B(x)D(x)\)

我们可以直接把点值表达式相乘得到这4个多项式。对于这4个多项式,分成2组合并做IDFT即可。

总共需要4次FFT.

大致代码如下:

void mul(ll *a,ll *b,ll *c,ll *d,ll *ans,int n){
    static com p[maxn+5],q[maxn+5];
    static com r[maxn+5],s[maxn+5];
    for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
    for(int i=0;i<n;i++){
        p[i]=com(a[i],b[i]);//打包A,B
        q[i]=com(c[i],d[i]);//打包C,D
    }
    fft(p,n);
    fft(q,n);
    for(int i=0;i<n;i++){
        int j=(i==0?0:n-i);
        //得到DFT(A),DFT(B),DFT(C),DFT(D)
        com da=(p[i]+p[j].conj())*0.5;
        com db=(p[i]-p[j].conj())*com(0,-0.5);
        com dc=(q[i]+q[j].conj())*0.5;
        com dd=(q[i]-q[j].conj())*com(0,-0.5);
        r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
        s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
    }
    fft(r,n);
    fft(s,n);
    for(int i=0;i<n;i++){
        ll ac,ad,bc,bd;
        ac=(ll)(r[i].real/n+0.5);
        ad=(ll)(r[i].imag/n+0.5);
        bc=(ll)(s[i].real/n+0.5);
        bd=(ll)(s[i].imag/n+0.5);
        ans[i]=ac+ad+bc+bd;
    }
}

卷积的终极优化

上述优化中我们只用到了DFT的思想。现在我们利用FFT的思想继续优化

同样拆分奇偶项,\(A(x)=A_0(x^2)+xA_1(x^2)\)

\[\begin{aligned}
A(x)B(x)&=(A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\\
&=A_0(x^2)B_0(x^2)+x(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))+x^2A_1(x^2)B_1(x^2)
\end{aligned} \tag{4.6}\]

我们只需要知道上式中\(x^0,x^1,x^2\)的系数

发现\(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2)\)是奇数项的系数,\(A_0(x^2)B_0(x^2)\)和\(A_1(x^2)B_1(x^2)\)是偶数项的系数,而偶数项的两个东西都可以看成一个关于\(x^2\)的多项式。

我们先优化DFT的过程,观察\((4.6)\)式的乘积形式\((A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\).

我们发现,这个形式和上一节的\((A+B)(C+D)\)很像,可以类似地优化。

令\(p_k={a_0}_k+\text{i}{a_1}_k,q_k={b_0}_k+\text{i}{b_1}_k\)

然后合并IDFT,再设两个辅助多项式

\[G(x)=DFT(A_0(x))\cdot DFT(B_0(x))+\omega_L^k DFT(A_1(x)) DFT(B_1(x))
\]

(注意我们把\(x^2\)换元成\(x\),做DFT的时候要乘上单位根)

\[F(x)=DFT(A_0(x))\cdot DFT(B_1(x))+ DFT(A_1(x)) DFT(B_0(x))
\]

那么我们只需要计算出\(IDFT(G(x))\)和\(IDFT(F(x))\)

设\(R(x)=G(x)+\mathrm{i} F(x)\)

那么因为IDFT是线性变换,\(IDFT(R(x))=IDFT(G(x))+\mathrm{i} IDFT(F(x))\)

(IDFT的线性性这里不做证明,容易发现两个点值表达式相加再IDFT回来,显然系数也会相加)

显然这两个多项式IDFT的结果是实数。故我们只要求出\(IDFT(R(x))\),每一项系数的实部就是偶数项系数\(G(x)\),虚部就是奇数项系数\(F(x)\)

我们再考虑把合并DFT弄进去,即式\((4.3)(4.4)(4.5)\)

接下来我们尝试用\(DFT(p_k),DFT(q_k)\)来表示\(R(x)=G(x)+\text{i}F(x)\),为了推导简洁,我们省略\(DFT\)不写

\[\begin{aligned}
g_k&=\frac {p_k+\text{conj}(p_j)}{2}\cdot \frac {q_k+\text{conj}(q_j)}{2}+\omega_L^k\cdot \frac {p_k-\text{conj}(p_j)}{-2i}\cdot \frac {q_k-\text{conj}(q_j)}{-2i}\\
&=\frac 1 4 [(p_k+\text{conj}(p_j))\cdot(q_k+\text{conj}(q_j))-\omega_L^k\cdot(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))]\\
\\
f_k&=\frac {p_k+\text{conj}(p_j)} 2 \cdot \frac{q_k-\text{conj}(q_j)}{-2}i+\frac {q_k+\text{conj}(q_j)} 2 \cdot \frac{p_k-\text{conj}(p_j)}{-2}i\\
&=\frac i{-4}[2\cdot p_k\cdot q_k-2\cdot \text{conj}(p_j)\cdot \text{conj}(q_j)]
\end{aligned}\]

那么

\[\begin{aligned}
g_k+\text{i} f_k&=\frac 1 4 [(p_k+\text{conj}(p_j))\cdot(q_k+\text{conj}(q_j))-w_L^k\cdot(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))-2\cdot p_k\cdot q_k+2 \text{conj}(p_j\cdot q_j)]\\
&=\frac 1 4 [-(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))+2\cdot (p_k\cdot q_k+\text{conj}(p_j\cdot q_j))\\
&-w_L^k\cdot(p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))+2\cdot p_k\cdot q_k-2\cdot \text{conj}(p_j\cdot q_j)]\\
&=q_k\cdot p_k-\frac 1 4[(1+w_L^k)\cdot (p_k-\text{conj}(p_j))\cdot(q_k-\text{conj}(q_j))]\\
\end{aligned}\]

和上一节的\((A+B)(C+D)\)不同,我们只用了3次长度为\(L/2\)的FFT,就求出了答案,这是由于FFT本身的性质。因为长度缩减了一半,我们不妨称它为\(1.5\)次FFT.

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
    double real;
    double imag;
    com(){

    }
    com(double _real,double _imag){
        real=_real;
        imag=_imag;
    }
    com(double x){
        real=x;
        imag=0;
    }
    void operator = (const com x){
        this->real=x.real;
        this->imag=x.imag;
    }
    void operator = (const double x){
        this->real=x;
        this->imag=0;
    }
    friend com operator + (com p,com q){
        return com(p.real+q.real,p.imag+q.imag);
    }
    friend com operator + (com p,double q){
        return com(p.real+q,p.imag);
    }
    void operator += (com q){
        *this=*this+q;
    }
    void operator += (double q){
        *this=*this+q;
    }
    friend com operator - (com p,com q){
        return com(p.real-q.real,p.imag-q.imag);
    }
    friend com operator - (com p,double q){
        return com(p.real-q,p.imag);
    }
    void operator -= (com q){
        *this=*this-q;
    }
    void operator -= (double q){
        *this=*this-q;
    }
    friend com operator * (com p,com q){
        return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
    }
    friend com operator * (com p,double q){
        return com(p.real*q,p.imag*q);
    }
    void operator *= (com q){
        *this=(*this)*q;
    }
    void operator *= (double q){
        *this=(*this)*q;
    }
    friend com operator / (com p,double q){
        return com(p.real/q,p.imag/q);
    }
    void operator /= (double q){
        *this=(*this)/q;
    }
    com conj(){
        return com(real,-imag);
    }
    void print(){
        printf("%lf + %lf i ",real,imag);
    }
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){

    for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
    for(int len=1;len<n;len*=2){
        int sz=len*2;
        for(int l=0;l<n;l+=sz){
            int r=l+len-1;
            for(int i=l;i<=r;i++){
                com tmp=x[i+len];
                x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
                x[i]=x[i]+tmp*w[n/sz*(i-l)];
            }
        }
    }
}
void mul(ll *a,ll *b,ll *c,int n){
    static com p[maxn+5],q[maxn+5],r[maxn+5];
    for(int i=0;i<n;i++){//合并做DFT
        if(i%2==1){
            p[i/2].imag=a[i];
            q[i/2].imag=b[i];
        }else{
            p[i/2].real=a[i];
            q[i/2].real=b[i];
        }
    }
    n/=2;
    for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
    fft(q,n);
    fft(p,n);
    for(int i=0;i<n;i++){
        int j=(i>0?(n-i):0);
        r[j]=p[i]*q[i]-(w[i]+1)*(p[i]-p[j].conj())*(q[i]-q[j].conj())*0.25;
    }
    fft(r,n);
    for(int i=0;i<n;i++){
        c[i*2]=r[i].real/n+0.5;
        c[i*2+1]=r[i].imag/n+0.5;
    }
}

int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
    scanf("%d %d",&n,&m);
    for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
    int N=1,L=0;
    while(N<=n+m+1){
        L++;
        N*=2;
    }
    for(int i=0;i<N/2;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-2));//注意这里的rev数组是对N/2做的,L要-1
    mul(a,b,c,N);
    for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}

三模数NTT

这是任意模数NTT的算法中最好理解的一种,它基于中国剩余定理。

定理5.1 若\(m_1,m_2 ,\dots m_n\)两两互质,则对于\(\forall a_1,a_2 \dots a_n\)同余方程组

\[\begin{cases} x \equiv a_1 (\bmod m_1) \\ x \equiv a_2 (\bmod m_2) \\ \dots \\ x \equiv a_n (\bmod m_n)\end{cases}
\]

有整数解解,且可以用如下方式构造解

这就是著名的中国剩余定理(Chinese Reminder Theorem,CRT)

证明:

对于\(k \neq i\),\(a_iM_iM_i^{-1} \bmod m_k=0\), 而根据逆元的定义,\(a_iM_iM_i^{-1} \bmod m_i =a_i\). 再代入到\(\sum_{i=1}^n a_iM_iM_i^{-1}\),原方程组成立。

回到任意模数NTT问题

模\(M\)意义下长度为\(n\)的序列做卷积,最大值可以到\(n^2M\).一般的题目中\(n \leq 10^5,M\leq 10^{9}\),那么结果会到\(10^{23}\)级别。用long double等存储会丢失精度。那么我们可以选三个乘起来大于\(10^{23}\)的NTT模数998244353,1004535809,469762049(选这三个模数的好处是他们的原根都是3,所以NTT部分写起来比较简洁)。然后分别在这三个模数的意义下做卷积。最后考虑把答案合并,我们只考虑某一位上的值\(ans\),容易写出:

\[\begin{cases} ans=a_1( \bmod m_1) (5.2)\\ans=a_2( \bmod m_2)(5.3)\\ans=a_3( \bmod m_3) (5.4)\end{cases}
\]

显然\(m_1,m_2,m_3\)互质,那么我们可以利用中国剩余定理直接合并。但是,直接合并把三个模数乘起来的时候会超出long long的范围。注意到两个模数相乘还是在long long范围内的,可以两两合并,具体方法如下,

记\(inv(a,m)\)表示\(a\)在模\(m\)下的逆元.根据CRT合并\((5.2)(5.3)\)有:

\[ans \equiv a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2)(\bmod m_1m_2) \tag{5.5}
\]

不妨设\(ans=km_1m_2+r\),根据\(5.4\)有

\(ans=km_1 m_2+r=q m_3+a_3 \tag{5.6}\),

在模 \(m_3\) 意义下有

\(km_1 m_2+r \equiv a_3 (\bmod m_3) \tag{5.7}\)

因此\(k=(a_3-r_2)inv(m_1m_2,m_3) (\bmod m_3)\),不妨设\(k=dm_3+e\),代入\(5.6\)得

\[ans=dm_1m_2m_3+em_1m_2+r
\]

由于\(m_1m_2m_3>ans\),所以\(d=0\),也就是说,\(ans=em_1m_2+r\),其中\(r=a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2),e=(a_3-r_2)inv(m_1m_2,m_3)\)

const ll mm=m1*m2;
inline ll inv(ll a,ll m);
ll mul(ll a,ll b,ll m);//要用按位乘防止溢出
ll CRT(ll a1,ll a2,ll a3){
    ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
    ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
    return ((e%C)*(mm%C)%C+r%C)%C;
}

完整代码(LuoguP4245 【模板】任意模数NTT)

#include<iostream>
#include<cstdio>
#include<cstring>
#define m1 998244353ll
#define m2 1004535809ll
#define m3 469762049ll
#define G 3
#define maxn 1048576
using namespace std;
typedef long long ll;
const ll mm=m1*m2;
ll C;
ll fast_pow(ll x,ll k,ll m){
    ll ans=1;
    while(k){
        if(k&1) ans=ans*x%m;
        x=x*x%m;
        k>>=1;
    }
    return ans;
}
inline ll inv(ll a,ll m){
    return fast_pow(a%m,m-2,m); //一定要取模m
} 

ll mul(ll a,ll b,ll m){
    ll ans=0;
    while(b){
        if(b&1) ans=(ans+a)%m;
        a=(a+a)%m;
        b>>=1;
    }
    return ans;
}
ll CRT(ll a1,ll a2,ll a3){
    //[Warning]You are not expected to understand this.
    ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
    ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
    return ((e%C)*(mm%C)%C+r%C)%C;
}

int n,m,N,L;
int rev[maxn+5];
void NTT(ll *x,int n,int type,ll mod){
    ll invG=inv(G,mod);
    for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
    for(int len=1;len<n;len*=2){
        int sz=len*2;
        ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz,mod);
        for(int l=0;l<n;l+=sz){
            int r=l+len-1;
            ll gnk=1;
            for(int i=l;i<=r;i++){
                ll tmp=x[i+len];
                x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
                x[i]=(x[i]+gnk*tmp%mod)%mod;
                gnk=gnk*gn1%mod;
            }
        }
    }
    if(type==-1){
        ll invn=inv(n,mod);
        for(int i=0;i<n;i++) x[i]=x[i]*invn%mod;
    }
}
void fmul(ll *a,ll *b,ll *ans,int n,ll mod){
    static ll ta[maxn+5],tb[maxn+5];
    for(int i=0;i<n;i++) ta[i]=a[i];
    for(int i=0;i<n;i++) tb[i]=b[i];
    NTT(ta,n,1,mod);
    if(a!=b) NTT(tb,n,1,mod);
    for(int i=0;i<n;i++) ans[i]=ta[i]*tb[i]%mod;
    NTT(ans,n,-1,mod);
}

ll a[maxn+5],b[maxn+5],c[3][maxn+5];
int main(){
    scanf("%d %d %lld",&n,&m,&C);
    for(int i=0;i<=n;i++){
        scanf("%lld",&a[i]);
        a[i]%=C;
    }
    for(int i=0;i<=m;i++){
        scanf("%lld",&b[i]);
        b[i]%=C;
    }
    N=1,L=0;
    while(N<n+m+1){
        N*=2;
        L++;
    }
    for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    fmul(a,b,c[0],N,m1);
    fmul(a,b,c[1],N,m2);
    fmul(a,b,c[2],N,m3);
    for(int i=0;i<n+m+1;i++){
        printf("%lld ",CRT(c[0][i],c[1][i],c[2][i]));
    }
}

容易发现,三模数NTT需要9次FFT,不是很优秀

拆系数FFT

我们之前讨论的优化都是针对FFT的,那不妨尝试用FFT解决任意模数NTT

最简单的想法是不取模,FFT完再取模。但是上文提到数值过大,long double会丢失精度。

int128是一个方法,但在OI比赛中不一定能使用。所以需要拆系数。

设\(M_0=[\sqrt{M}]\)

\[\begin{aligned} a_i=k[a_i]M_0+b[a_i]\\
b_i=k[b_i]M_0+b[b_i]\end{aligned}\]

相当于把模数换成\(M_0\),降低大小。

代入对应的多项式

\[\begin{aligned}A(x)=K_a(x)M_0+B_a(x)\\
B(x)=K_b(x)M_0+B_b(x)\\
A(x)B(x)=K_a(x)K_b(x)M_0^2+(K_a(x)B_b(x)+K_b(x)B_a(x))M_0+B_a(x)B_b(x) \end{aligned}\]

这不就是我们提到的\((A+B)(C+D)\)形的卷积吗?

由于\(k,b\)都不超过\(2^{15}\),于是就不容易被卡精度了。实际操作中我们不必取\(M_0=\sqrt{M}\),直接取\(M_0=2^{15}\)即可。这样取模运算可以换成位运算,进一步减小常数。

代码(LuoguP4245 【模板】任意模数NTT)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
    double real;
    double imag;
    com(){

    }
    com(double _real,double _imag){
        real=_real;
        imag=_imag;
    }
    com(double x){
        real=x;
        imag=0;
    }
    void operator = (const com x){
        this->real=x.real;
        this->imag=x.imag;
    }
    void operator = (const double x){
        this->real=x;
        this->imag=0;
    }
    friend com operator + (com p,com q){
        return com(p.real+q.real,p.imag+q.imag);
    }
    friend com operator + (com p,double q){
        return com(p.real+q,p.imag);
    }
    void operator += (com q){
        *this=*this+q;
    }
    void operator += (double q){
        *this=*this+q;
    }
    friend com operator - (com p,com q){
        return com(p.real-q.real,p.imag-q.imag);
    }
    friend com operator - (com p,double q){
        return com(p.real-q,p.imag);
    }
    void operator -= (com q){
        *this=*this-q;
    }
    void operator -= (double q){
        *this=*this-q;
    }
    friend com operator * (com p,com q){
        return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
    }
    friend com operator * (com p,double q){
        return com(p.real*q,p.imag*q);
    }
    void operator *= (com q){
        *this=(*this)*q;
    }
    void operator *= (double q){
        *this=(*this)*q;
    }
    friend com operator / (com p,double q){
        return com(p.real/q,p.imag/q);
    }
    void operator /= (double q){
        *this=(*this)/q;
    }
    com conj(){
        return com(real,-imag);
    }
    void print(){
        printf("(%lf,%lf)\n",real,imag);
    }
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
    for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
    for(int len=1;len<n;len*=2){
        int sz=len*2;
        for(int l=0;l<n;l+=sz){
            int r=l+len-1;
            for(int i=l;i<=r;i++){
                com tmp=x[i+len];
                x[i+len]=x[i]-tmp*w[n/sz*(i-l)];
                x[i]=x[i]+tmp*w[n/sz*(i-l)];
            }
        }
    }
}
ll mod;
void mul(ll *ina,ll *inb,ll *inc,int n){
    static ll a[maxn+5],b[maxn+5],c[maxn+5],d[maxn+5];
    static com p[maxn+5],q[maxn+5];
    static com r[maxn+5],s[maxn+5];
    for(int i=0;i<n;i++){
        ina[i]=(ina[i]+mod)%mod;
        inb[i]=(inb[i]+mod)%mod;
        a[i]=ina[i]>>15;
        b[i]=ina[i]&((1<<15)-1);
        c[i]=inb[i]>>15;
        d[i]=inb[i]&((1<<15)-1);
    }
    for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
    for(int i=0;i<n;i++){
        p[i]=com(a[i],b[i]);//打包A,B
        q[i]=com(c[i],d[i]);//打包C,D
    }
    fft(p,n);
    fft(q,n);
    for(int i=0;i<n;i++){
//        p[i].print();
        int j=(i==0?0:n-i);
        //得到DFT(A),DFT(B),DFT(C),DFT(D)
        com da=(p[i]+p[j].conj())*0.5;
        com db=(p[i]-p[j].conj())*com(0,-0.5);
        com dc=(q[i]+q[j].conj())*0.5;
        com dd=(q[i]-q[j].conj())*com(0,-0.5);
        r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
        s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
    }
    fft(r,n);
    fft(s,n);
    for(int i=0;i<n;i++){
        ll ac,ad,bc,bd;
        ac=(ll)(r[i].real/n+0.5)%mod;
        ad=(ll)(r[i].imag/n+0.5)%mod;
        bc=(ll)(s[i].real/n+0.5)%mod;
        bd=(ll)(s[i].imag/n+0.5)%mod;
        inc[i]=((ac<<30)+((ad+bc)<<15)+bd)%mod;
    }
}

int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
    scanf("%d %d %lld",&n,&m,&mod);
    for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
    int N=1,L=0;
    while(N<=n+m+1){
        L++;
        N*=2;
    }
    for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    mul(a,b,c,N);
    for(int i=0;i<n+m+1;i++) printf("%lld ",c[i]);
}