FFT 小记
阅读原文时间:2023年07月09日阅读:2

写在前面

\(Q:\) 为什么会心血来潮去学 FFT

\(A:\) 当本蒟蒻还在努力消化凸包时:。所以本蒟蒻也来看一下

等等 摸头警告 。思维已经废了

About FFT

FFT( \(Fast\ Fourier\ Transformation\) )

中文名:快速傅里叶变换

Fast Fast TLE

作用:在 \(O(n\log n)\) 内求出多项式卷积

多项式

一个形如 \(A(x)=a_0+a_1x+\cdot\cdot\cdot+a_{n-2}x^{n-2}+a_{n-1}x^{n-1}=\sum_{i=0}^{n-1}a_ix^i\) 的柿子,称其为多项式

  • 系数表示法

    将 \(n-1\) 次多项式看成一个 \(n\) 维向量 \(\vec a=(a_0,a_1,\cdot\cdot\cdot,a_{n-1})\) 即为多项式的系数表示

  • 点值表示法

    \(n-1\) 次多项式 \(A(x)\) 将 \(n\) 个不同的 x 代入可得到 n 个点,能唯一确定多项式 \(A(x)\)

  • 多项式乘法

    \(A(x)=\sum_{i=0}^{n-1}a_ix^i,B(x)=\sum_{i=0}^{n-1}b_ix^i\) 则 \(C(x)=A(x)*B(x)=\sum_{i=0}^{2n-2}\sum_{j+k=i}a_jb_kx_i\)

  • 卷积

    两个向量 \(\vec a=(a_0,a_1,\cdot\cdot\cdot,a_{n-1}),\vec b=(b_0,b_1,\cdot\cdot\cdot,b_{n-1})\)

    有卷积 \(\vec a \otimes \vec b=c(c_0,c_1,\cdot\cdot\cdot,c_{2n-2})\) ,其中 \(c_k=\sum_{i+j=k}a_ib_j\)


系数表示法时计算是 \(O(n^2)\) 的

但是对于两个点值表达式的多项式,可以 \(O(n)\) 的计算出多项式乘积

便得出了 FFT 的三个步骤

  1. 系数表示法转为点值表示法, \(DFT,O(n\log n)\)
  2. 点值表示法相乘, \(O(n)\)
  3. 点值表示法转为系数表示法, \(IDFT,O(n\log n)\)

复数 Complex

复数由实部和虚部构成,可用二元组 \((a,b)\) 表示复数 \(a+bi,i^2=-1\) 可将其理解为一个点或向量

  • 加法

    \((a_1,b_1)+(a_2,b_2)=(a_1+a_2,b_1+b_2)\)

  • 减法

    \((a_1,b_1)-(a_2,b_2)=(a_1-a_2,b_1-b_2)\)

  • 乘法

    \((a+bi)(c+di)=(ac-bd)+(bc+ad)i=(ac-bd,bc+ad)\)

  • 除法

    • 共轭复数:实部相同,虚部互为相反数的两个复数, \(z(a,b)=a+bi\) 的共轭复数为 \(\bar z(a,-b)=(a-bi)\)

      有趣的性质:对于一个 \(z=(a,b)\) , \(|z|=|\bar z|,z\cdot\bar z=a^2+b^2\)

    \(\dfrac{a+bi}{c+di}=\dfrac{(a+bi)(c-di)}{(c+di)(c-di)}=\dfrac{ac+bd}{c^2+d^2}+\dfrac{bc-ad}{c^2+d^2}i\)

  • 单位根

    将单位圆(以原点为圆心,半径为 1)n 等分,得到 n 个模长为 1 的复数

    将点从 0 开始标号,设第 0 个点为 \(\omega_n^0\) 以此类推

    以 (1,0) 为起点,则 \(\omega_n^i=(\cos(\dfrac{i}{n}2\pi),\sin(\dfrac{i}{n}2\pi))\) (用弧度制)

    把这些复数称为 n 次单位根

    性质:\(\omega_n^n=1,\omega_n^k=\omega_n^{2k},\omega_{n}^{k+\frac{n}{2}}=-\omega_{n}^k\)

DFT

将系数表达式转化为点值表达式,就 DFT 来说,它分治地来求当 \(x=\omega_n^k\) 的时候 \(F(x)\) 的值

使 \(n=2^m\) ,不够位就补,不会对答案有影响

对于多项式 \(A(x)=\sum_{i=0}^{n-1}a_ix^i\) ,将其按下标奇偶性分类

\(A(x)=(a_0+a_2x^2+\cdot\cdot\cdot+a_{n-2}x^{n-2})+(a_1x+a_3x^3+\cdot\cdot\cdot+a_{n-1}x^{n-1})\)

现在设 \(A_1(x)=(a_0+a_2x+\cdot\cdot\cdot+a_{n-2}x^{\frac{n-2}{2}}),A_2(x)=(a_1+a_3x+\cdot\cdot\cdot+a_{n-1}x{\frac{n-2}{2}})\)

则 \(A(x)=A_1(x^2)+xA_2(x^2)\)

对于 \(k<\dfrac{n}{2}\) 有 \(A(\omega_n^k)=A_1(\omega_n^{2k})+\omega_n^kA_2(\omega_n^{2k})\\ =A_1(\omega_{\frac{n}{2}}^k)+\omega_n^kA_2(\omega_{\frac{n}{2}}^k)\)

对于 \(k+\dfrac{n}{2}\) 有 \(A(\omega_n^{k+\frac{n}{2}})=A_1(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}A_2(\omega_n^{2k+n})\\\ =A_1(\omega_{\frac{n}{2}}^k*\omega_n^n)+\omega_n^k*\omega_n^{\frac{n}{2}}A_2(\omega_{\frac{n}{2}}^k*\omega_n^n)\)

因为 \(\omega_n^n=(1,0),\omega_n^{\frac{n}{2}}=(-1,0)\) 所以 \(A(\omega_n^{k+\frac{n}{2}})=A_1(w_{\frac{n}{2}}^k)-w_n^kA_2(w_{\frac{n}{2}}^k)\)

于是问题被分解成了更小的子问题,递归求解即可

时间复杂度 \(O(n\log n)\)

IDFT

将点值表达式转回为系数表达式

以下内容摘自 OI-Wiki

已知 \(y_i=F(\omega_n^i),i\in[0,n)\) 求 \(\{a_0,a_1,\cdot\cdot\cdot,a_{n-1}\}\) 。构造式是:\(A(x)=\sum_{i=0}^{n-1}y_ix^i\)

相当于把 \(\{y_0,y_1,\cdot\cdot\cdot,y_{n-1}\}\) 当作多项式 \(A\) 的系数表示法,设 \(b_i=\omega_n^{-i}\)

则多项式 \(A\) 在 \(x=b_0,b_1,\cdots,b_{n-1}\) 处的点值表示法为 \(\{A(b_0),A(b_1),\cdots\,A(b_{n-1})\}\)

对 \(A(x)\) 的定义式做一下变换,可以将 \(A(b_k)\) 表示为

\[\begin{aligned}
A(b_k) &= \sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{-ik}
=\sum_{i=0}^{n-1}\omega_n^{-ik}\sum_{j=0}^{n-1}a_j(\omega_n^i)^j \\
&= \sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{i(j-k)}
=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i
\end{aligned}
\]

记 \(S(\omega_n^a)=\sum_{i=0}^{n-1}(\omega_n^a)^i\)。当 \(a=0\) 时,\(S(\omega_n^a)=n\);当 \(a\ne0\) 时,错位相减

\[\begin{aligned}
S(\omega_n^a) &= \sum_{i=0}^{n-1}(\omega_n^a)^i \\
\omega_n^a S(\omega_n^a) &= \sum_{i=1}^n (\omega_n^a)^i\\
S(\omega_n^a) &= \dfrac{(\omega_n^a)^n-(\omega_n^a)^0}{\omega_n^a-1}=0
\end{aligned}
\]

也就是说 \(S(\omega_n^a)=\left\{\begin{aligned}n,a=0\\0,a\ne0\end{aligned}\right.\),那么代回原式\(A(b_k)=\sum_{j=0}^{n-1}a_jS(\omega_n^{j-k})=a_k\cdot n\)

也就是说给定 $b_i=\omega_n^{-i},则 $ \(A\) 的点值表示法为

\(\{(b_0,A(b_0)),(b_1,A(b_1)),\cdots,(b_{n-1},A(b_{n-1}))\}\)

\(=\{(b_0,a_0\cdot n),(b_1,a_1\cdot n,\cdots,(b_{n-1},a_{n-1}\cdot n\}\)

综上所述,我们取单位根为其倒数,对 \(\{y_0,y_1,\cdots,y_{n-1}\}\) 跑一遍 FFT ,然后除以 n 即可得到 \(F(x)\) 的系数表示法。

迭代实现

然而,如果直接用递归来打,本蒟蒻会因为自带的 \(O(\infin)\) 大常数被卡死。

可以研究以下序列变换过程

0 1 2 3 4 5 6 7

0 2 4 6 | 1 3 5 7

0 4 | 2 6 | 1 5 | 3 7

这时候发现:最终序列为原顺序二进制下的反转,直接模拟从下往上合并

可以用递推来预处理

\(\lfloor\dfrac{x}{2}\rfloor\) 的翻转值是已知的,而这个值右移一位就是 x 除了二进制个位的翻转值

若个位是 0 ,反转后最高位就是 0,否则最高位就是 1

Code

Luogu P3803

#include<bits/stdc++.h>
using namespace std;
const int N=2100000;
const double PI=acos(-1.0);
struct Comp {
    double x,y;
    inline Comp(double p=0.0,double q=0.0):x(p),y(q) { }
    inline Comp operator +(Comp o) { return Comp(x+o.x,y+o.y); }
    inline Comp operator -(Comp o) { return Comp(x-o.x,y-o.y); }
    inline Comp operator *(Comp o) { return Comp(x*o.x-y*o.y,x*o.y+y*o.x); }
}x[N],y[N];
int r[N];
inline void FFT(Comp *a,int n,int on) {
    register int i,j,k;
    register Comp wn,w,x,y;
    for(i=1;i<=n;i++)
        if(i<r[i])swap(a[i],a[r[i]]);
    for(i=1;i<n;i<<=1) {
        wn=Comp(cos(PI/i),on*sin(PI/i));
        for(j=0;j<n;j+=(i<<1)) {
            w=Comp(1,0);
            for(k=0;k<i;k++,w=w*wn) {
                x=a[j+k],y=w*a[i+j+k];
                a[j+k]=x+y,a[i+j+k]=x-y;
            }
        }
    }
    if(on==-1)
    for(i=0;i<=n;i++)
    a[i].x=floor(a[i].x/n+0.5);
}
int n,m,lim,L;
int main() {
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&x[i].x);
    for(int i=0;i<=m;i++)scanf("%lf",&y[i].x);
    for(lim=1,L=-1;lim<=n+m;lim<<=1,++L);
    for(int i=0;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)<<L);
    FFT(x,lim,1),FFT(y,lim,1);
    for(int i=0;i<lim;i++)x[i]=x[i]*y[i];
    FFT(x,lim,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)x[i].x);
}

三步变两步优化

原方法中求了三次 FFT ,由于本蒟蒻常数较大,三次不是特别快

我们可以把 \(B(x)\) 放到 \(A(x)\) 的虚部上去,求出 \(A(x)^2\),然后把 \(A(x)\) 的虚部取出来除 2 就是答案了

正确性的证明:\((a+bi)^2=(a^2-b^2)+(2abi)\)

这样的话效率是原来的 \(\dfrac{2}{3}\)

#include<bits/stdc++.h>
using namespace std;
const int N=2100000;
const double PI=acos(-1.0);
struct Comp {
    double x,y;
    inline Comp(double p=0.0,double q=0.0):x(p),y(q) { }
    inline Comp operator +(Comp o) { return Comp(x+o.x,y+o.y); }
    inline Comp operator -(Comp o) { return Comp(x-o.x,y-o.y); }
    inline Comp operator *(Comp o) { return Comp(x*o.x-y*o.y,x*o.y+y*o.x); }
}x[N];
int r[N];
inline void FFT(Comp *a,int n,int on) {
    register int i,j,k;
    register Comp wn,w,x,y;
    for(i=1;i<=n;i++)
        if(i<r[i])swap(a[i],a[r[i]]);
    for(i=1;i<n;i<<=1) {
        wn=Comp(cos(PI/i),on*sin(PI/i));
        for(j=0;j<n;j+=(i<<1)) {
            w=Comp(1,0);
            for(k=0;k<i;k++,w=w*wn) {
                x=a[j+k],y=w*a[i+j+k];
                a[j+k]=x+y,a[i+j+k]=x-y;
            }
        }
    }
    if(on==-1)
    for(i=0;i<=n;i++)
    a[i].x=floor(a[i].x/n+0.5),
    a[i].y=floor(a[i].y/n/2+0.5);
}
int n,m,lim,L;
int main() {
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&x[i].x);
    for(int i=0;i<=m;i++)scanf("%lf",&x[i].y);
    for(lim=1,L=-1;lim<=n+m;lim<<=1,++L);
    for(int i=0;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)<<L);
    FFT(x,lim,1);
    for(int i=0;i<lim;i++)x[i]=x[i]*x[i];
    FFT(x,lim,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)x[i].y);
}