其他分享
首页 > 其他分享> > FFT&NTT

FFT&NTT

作者:互联网

 已知多项式$A(x)=\sum _{i=0}^{N} a_ix^i,B(x)=\sum _{i=0}^{M} b_ix^i$求$A(x)*B(x)$.

  显然看出可以枚举两个多项式的系数,依次算出,时间$O(nm)$.

  太慢了!!怎么办?利用一个奇妙的东西:FFT

    对于一个多项式$A(x)=\sum _{i=0}^{N} a_ix^i$,可以取$N$个不同的$x$值,求得$N$个多项式值。将其作为点,即$(x_i,A(x_i))$

  FFT的大致思路就是

  1.   将多项式化为点值形式。
  2.         将点值相乘。即算出每一个$C(x)=A(x)*B(x)$
  3.        将新的点值转化回多项式形式。

 前置芝士:向量与复数

  向量,即有方向的量,在平面直角坐标系上可以用$(a,b)$表示。

  图形上即为由原点指向点$(a,b)$的有向线段。

  向量的模长为$\sqrt{a^2+b^2}$

  向量的幅角为向量逆时针旋转至与x轴正半轴重合时旋转的角度。

  向量的加减法满足平行四边形法则,即$\overrightarrow{m}(x_1,y_1)\pm \overrightarrow{n}(x_2,y_2) = \overrightarrow{p}(x_1 \pm x_2,y_1 \pm y_2)$

  

  定义虚数单位$i$ 满足$i^2=-1$,复数域$I$,形如$a+bi,(a,b\in \mathbb{R})$的数叫做复数。

  复数$a+bi$可以在坐标系中表示为$(a,b)$的向量。

  同时复数的加减法满足向量的加减法,模长与幅角的定义也与向量相同。

  若复数$z$的模长为$|z|$,幅角为$\theta$,根据坐标系则有

$z=|z|cos \theta +i|z|sin \theta$

  复数的乘法:

$(x_1+y_1 i)*(x_2+y_2 i)=x_1 x_2+x_2 y_1 i+x_1 y_2 i - y_1 y_2 $

$=(x_1 x_2 - y_1 y_2)+(x_1 y_2+x_2 y_1)i$

  并且两个复数相乘遵循一个规律:模长相乘,幅角相加。

   在坐标系中做一个单位圆,将单位圆等分成$n$份的$n$个向量所对应的复数称为$n$次单位根

  幅角最小的记为$\omega _n$,而幅角是$\omega _n$的$k$倍的单位根为$\omega _n^k$.

     

  8次单位根↑

  由于我们只需要$2^n$次单位根,所以以下单位根均为$2$的幂次单位根。

  单位根的性质:

  1.$\omega _n^{kn} =1 (k\in \mathbb{Z}) , \omega _n^k * \omega _n^j = \omega _n^{k+j}$ 

  根据复数乘法,很明显。

  2.$\omega _n^k =cos 2\pi \frac{k}{n} +isin 2\pi \frac{k}{n}$ 

  复数的三角表示法。

  3.$\omega _{tn}^{tk} =\omega _n^k$

  因为$cos 2\pi \frac{tk}{tn}+isin 2\pi \frac{tk}{tn}=cos 2\pi \frac{k}{n}+isin 2\pi \frac{k}{n}$.

  4.$\omega _{n}^{\frac {n}{2}}=-1$

 $\omega _{n}^{\frac {n}{2}}=cos 2\pi \frac{1}{2}+isin 2\pi \frac{1}{2}$

$=cos \pi +isin \pi$

$=-1$

快速傅里叶变换O(nlogn)

   设一个多项式$A(x)$的系数为$(a_0,a_1,a_2...a_{n-1})$.

   首先在$A(x)$后面补系数为$0$的项,直到$n$为$2$的幂数,方便接下来运算。

   我们可以将所有的$\omega _n^k k \in [0,n-1]$代入求得$n$个点值,并可以做出优化。

$A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}$

$=(a_0+a_2x^2+...+a_{n-2}x^{n-2})+x(a_1+a_3x^2+...+a_{n-1}x^{n-1})$

      令$A_1(x)=a_0+a_2x^2+...+a_{n-2}x^{n-2},A_2(x)=a_1+a_3x^2+...+a_{n-1}x^{n-1}$.

$A(x)=A_1(x^2)+xA_2(x^2)$

      将$x=\omega _n^k (0 \leq k<\frac{n}{2})$代入上式

$A(\omega _n^k)=A_1(\omega _n^{2k})+\omega _n^kA_2(\omega _n^{2k})$

$=A_1(\omega _{\frac{n}{2}}^{k})+\omega _n^kA_2(\omega _{\frac{n}{2}}^{k})$

    同理将$x=\omega _n^{k+\frac{n}{2}} (0 \leq k<\frac{n}{2})$代入。

$A(\omega _n^{k+\frac{n}{2}})=A_1(\omega _n^{2k+n})-\omega _n^kA_2(\omega _n^{2k+n})$

$=A_1(\omega _n^{2k})-\omega _n^kA_2(\omega _n^{2k})$

$=A_1(\omega _{\frac{n}{2}}^{k})-\omega _n^kA_2(\omega _{\frac{n}{2}}^{k})$

  之后我们发现只要求出$A_1(\omega _{\frac{n}{2}}^{k})和A_2(\omega _{\frac{n}{2}}^{k})$就可以算出两个点值。而他们可以递归去求,并且刚好由$n$次变为了$\frac{n}{2}$次,时间复杂度类似线段树$O(nlogn)$.

  然后求出两个多项式的所有点值之后将他们分别相乘,得出新多项式的$N+M+1$个点值,这一步是$O(n)$的。

快速傅里叶逆变换O(nlogn)

  接下来我们只需要把点值形式转化为多项式形式即可。

  设多项式$A(x)(a_0,a_1,a_2...a_{n-1})$的点值表示为$(y_0,y_1,y_2...y_{n-1})$

  多项式$D(x)=\sum _{i=0}^{n-1} y_ix^i$,$D(x)$在$(\omega _n^0,\omega _n^{-1},\omega _n^{-2}...\omega _n^{-(n-1)})$的点值表示为$(c_0,c_1,c_2...c_{n-1})$

  则有

$c_k=\sum _{i=0}^{n-1} y_i(\omega _n^{-k})^{i}$

$=\sum _{i=0}^{n-1} \sum _{j=0}^{n-1} a_j\omega _n^{ij} \omega _n^{-ik}$

$=\sum _{i=0}^{n-1} \sum _{j=0}^{n-1} a_j(\omega _n^{j-k})^i$

$=\sum _{j=0}^{n-1} \sum _{i=0}^{n-1} a_j(\omega _n^{j-k})^i$

$=\sum _{j=0}^{n-1} a_j\sum _{i=0}^{n-1} (\omega _n^{j-k})^i$

  令$T(x)=\sum _{i=0}^{n-1} x^i$,则有

$T(\omega _n^{t})=1+\omega _n^t+(\omega _n^t)^2+...+(\omega _n^t)^{n-1}$  A式

$A*\omega _n^{t}$得:

$\omega _n^{t}T(\omega _n^{t})=\omega _n^t+(\omega _n^t)^2+...+(\omega _n^t)^{n}$ B式 

$B-A$得

$(\omega _n^{t}-1)T(\omega _n^{t})=(\omega _n^t)^{n}-1$

$(\omega _n^{t}-1)T(\omega _n^{t})=(\omega _n^n)^{t}-1=1-1=0$

    所以当$\omega _n^{t}-1!=0$时$T(\omega _n^{t})=0$

    当$\omega _n^{t}-1=0$时可以得到$\omega _n^{t}=1,t=0$.

    则$T(\omega _n^{t})=T(1)=\sum _{i=0}^{n-1} 1=n$.

    有了这个结论后我们来看这个式子:

$c_k=\sum _{j=0}^{n-1} a_j\sum _{i=0}^{n-1} (\omega _n^{j-k})^i$

$=\sum _{j=0}^{n-1} a_jT(\omega _n^{j-k})$

   当且仅当$j=k$时有值,即

$c_k=a_jn$

$a_j=\frac{a_k}{n}$

    所以我们只需要求出多项式$D(x)$在$(\omega _n^0,\omega _n^{-1},\omega _n^{-2}...\omega _n^{-(n-1)})$的点值表示即可算出$a_i$.

 递归版:

const db pi=acos(-1);
class cplx{
public:
    db x,y;
    cplx(){x=y=0;}
    cplx(const db a,const db b){x=a,y=b;}
    friend cplx operator +(const cplx a,const cplx b){return cplx(a.x+b.x,a.y+b.y);}
    friend cplx operator -(const cplx a,const cplx b){return cplx(a.x-b.x,a.y-b.y);}
    friend cplx operator *(const cplx a,const cplx b){return cplx(a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y);}
}a[maxn],b[maxn];
int N,M,lim=1;
void fft(int lm,cplx *a,int op){
    if(lm==1) return;
    cplx a1[lm>>1],a2[lm>>1];
    for(int i=0;i<=lm;i+=2)
        a1[i>>1]=a[i],a2[i>>1]=a[i+1];
    fft(lm>>1,a1,op);
    fft(lm>>1,a2,op);
    cplx w1=cplx(cos(2*pi/lm),op*sin(2*pi/lm)),wk=cplx(1,0);
    for(int i=0;i<(lm>>1);i++,wk=wk*w1){
        cplx b=wk*a2[i];
        a[i]=a1[i]+b;
        a[i+(lm>>1)]=a1[i]-b;
    }
}
int MAIN(){
    cin>>N>>M;
    for(int i=0;i<=N;i++) scanf("%lf",&a[i].x);
    for(int i=0;i<=M;i++) scanf("%lf",&b[i].x);
    while(lim<=N+M) lim<<=1;
    fft(lim,a,1);
    fft(lim,b,1);
    for(int i=0;i<=lim;i++) a[i]=a[i]*b[i];
    fft(lim,a,-1);
    for(int i=0;i<=N+M;i++) prt(a[i].x/lim);
    return 0;
}

 但是我们发现,这种写法需要很多次复制数组,既耗内存也耗空间。

 迭代优化:

  我们写出$n=8$时的递归详细:

  我们发现一个神奇的性质:递归到最底层时实际的值为原下标的二进制翻转!!(具体证明见文末)

    于是我们没有必要再进行递归,只需要将数组调换至最底层的状态然后一层一层往回的迭代即可!

  

const db pi=acos(-1);
class cplx{
public:
    db x,y;
    cplx(){x=y=0;}
    cplx(const db a,const db b){x=a,y=b;}
    friend cplx operator +(const cplx a,const cplx b){return cplx(a.x+b.x,a.y+b.y);}
    friend cplx operator -(const cplx a,const cplx b){return cplx(a.x-b.x,a.y-b.y);}
    friend cplx operator *(const cplx a,const cplx b){return cplx(a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y);}
}a[maxn],b[maxn];
int N,M,lim=1,tr[maxn],l=0;
void fft(cplx *a,int op){
    for(int i=0;i<lim;i++) if(i<tr[i])swap(a[i],a[tr[i]]);
    for(int m=1;m<lim;m<<=1){
        cplx w1(cos(pi/m),op*sin(pi/m));
        int len=m<<1;
        for(int i=0;i<lim;i+=len){
            cplx wk(1,0);
            for(int k=0;k<m;k++,wk=wk*w1){
                cplx a1=a[i+k],a2=wk*a[i+m+k];
                a[i+k]=a1+a2;
                a[i+m+k]=a1-a2;
            }
        }
    }
}
int MAIN(){
    cin>>N>>M;
    for(int i=0;i<=N;i++) scanf("%lf",&a[i].x);
    for(int i=0;i<=M;i++) scanf("%lf",&b[i].x);
    while(lim<=N+M) lim<<=1,++l;
    for(int i=1;i<lim;i++){
        tr[i]=(tr[i>>1]>>1)|((i&1)?(1<<(l-1)):0);
    }
    fft(a,1);
    fft(b,1);
    for(int i=0;i<=lim;i++) a[i]=a[i]*b[i];
    fft(a,-1);
    for(int i=0;i<=N+M;i++) prt(a[i].x/lim);
    return 0;
}

总时间复杂度为O(nlogn).

  我们发现,FFT中因为要用到三角函数以及浮点数的运算,精度得不到保障,并且复数的常数较大,我们可以进行优化:

  引入概念:

  设m是正整数,a是整数,若a模m的阶等于φ(m),则称a为模m的一个原根。(其中φ(m)表示m的欧拉函数)

  先不用管原根的定义,扔出一个结论(设$g$为$P$的原根):

$\omega _n \equiv g^{\frac{P-1}{n}} (mod P)$

   原根满足这样的性质:

$g^i != g^j (mod P,i!=j)$

  并且根据费马小定理:

$\omega _n^n \equiv g^{P-1} =1 (mod P)$

  所以我们知道原根的性质与单位根类似,可以用$g^{\frac {P-1}{n}}$来代替$\omega _n$.

  如何求质数$P$的原根?

  首先需要知道满足$a^n \equiv 1 (mod P)$的最小$n$值一定满足$n|P-1$.

  质因数分解$P-1=\prod p_i^{a_i}$

  那么如果有$m|P-1,n|P-1,m|n,a^m \equiv 1(mod P)$,则有$a^n \equiv 1(mod P)$

  所以要验证一个数$t$是不是原根,要枚举每一个$p_i$,均满足$t^{\frac{P-1}{p_i}}!=1(mod P)$成立,则$t$是原根。

  $P$一般取$998244353$,他的原根是$3$.

const int maxn=(1<<21)+5,mod=998244353,g=3;
int qp(int x,int y){
    long long ans=1;
    while(y){
        if(y&1) ans=ans*x%mod;
        x=((long long)x*x)%mod;
        y>>=1;
    }
    return (int)ans;
}
const int ginv=qp(g,mod-2);
int a[maxn],b[maxn];
int N,M,lim=1,tr[maxn],l=0;
void ntt(int *a,int op){
    for(int i=0;i<lim;i++) if(i<tr[i])swap(a[i],a[tr[i]]);
    for(int m=1;m<lim;m<<=1){
        int len=m<<1;
        int g1=qp(op==1?g:ginv,(mod-1)/len);
        for(int i=0;i<lim;i+=len){
            int gk=1;
            for(int k=0;k<m;k++,gk=(long long)gk*g1%mod){
                int a1=a[i+k],a2=(long long)gk*a[i+m+k]%mod;
                a[i+k]=(a1+a2)%mod;
                a[i+m+k]=(a1-a2+mod)%mod;
            }
        }
    }
}
int MAIN(){
    cin>>N>>M;
    for(int i=0;i<=N;i++) scanf("%d",&a[i]);
    for(int i=0;i<=M;i++) scanf("%d",&b[i]);
    while(lim<=N+M) lim<<=1,++l;
    for(int i=1;i<lim;i++){
        tr[i]=(tr[i>>1]>>1)|((i&1)?(1<<(l-1)):0);
    }
    ntt(a,1);
    ntt(b,1);
    for(int i=0;i<=lim;i++) a[i]=((long long)a[i]*b[i])%mod;
    ntt(a,-1);
    int ny=qp(lim,mod-2);
    for(int i=0;i<=N+M;i++) printf("%lld ",(long long)a[i]*ny%mod);
    return 0;
}

保证了无精度误差,并且跑的飞快,大概是FFT速度2倍。

关于二进制翻转的证明:

  显然,在前$i$层对应着原下标的前$i$位,向左即为$0$,向右即为$1$.

  而前$i$层对应实际系数下标的后$i$位,向左即为$0$,向右即为$1$,因为选出奇数代表选择$1$,偶数代表$0$

  所以对于任意一层,原下标的前$i$位均相等,实际系数下标的后$i$位均相等,且两者有着翻转关系。

  在最底层即为原下标是实际下标的翻转。

标签:const,int,cplx,sum,FFT,NTT,frac,omega
来源: https://www.cnblogs.com/lnxxqz/p/16311455.html