快速Fourier变换,被称为20世纪最伟大的十大算法之一。所以很多软件都有对应的FFT,例如Python的 scipy.fftpack中就有关于FFT的包。所以个人写FFT就没有那么必要了。但是NFT的包一般都没多少,而且会写NFT必然就会写FFT了。大约在5年前就写过NFT的C++代码,现在一看依然不记得蝶式映射到底是咋想的,所以想独立思考出整个过程(失败)。

此博文特别划水,建议直接阅读 miskcoo博文:从多项式乘法到快速傅里叶变换

发现NFT一个很大的限制就是你只能在 NFT-friendly 的域(例如 $\mod 998244353 = 119 \cdot2^{23}$,原根为$3$ )或者模很小的数的环中处理。即环中所有运算放在整数环中都不会超过选择的大基底。

选择$n$个 NFT-friendly 的大基底$p_1,\cdots, p_n$ 使得$p_1\cdots p_n$大于ans的上界,然后再用 中国剩余定理 就可以把ans搞出来了。

基底的选择

我们考虑$\mod p$ 构成的域。即运算默认是 $\mod p$ 的(除了指数上的幂次数),因为原根定理,此形式必有原根$g$,即存在$\mod p$ 中所有元素都可以写成 $g^n$ 的形式(所以$g^{p-1}=1,g^{n}\neq 1, 0<n<p-1$)。而我们做NFT是需要找一个元素 $w$,使得 $w^{2^k} = 1$,因此我们需要找素数$p$,使得$p-1=c \cdot 2^k$,其中$c$ 是个小奇数。

查找基底的SageMath代码:

1
2
3
4
5
6
7
ans = []
for i in range(20,32): # 调节这个数值范围来找自己想要的 p
for j in range(1,25,2):
if(is_prime(j*2^i+1)):
ans.append(j*2^i+1)
for i in ans:
print("1 + ",factor(i-1),"\t=\t",i)

我们会发现有很多可供选择的例子,其中的娇楚($c$ 较小,$k$ 较大,$p^2 < 2^{63}$)

  • x_1 = 1 + 2^27 * 15 十分推荐!$x_2$刚好不超过INT_MAX,所以在乘积取模之前还多一次加法运算,就很方便!
  • x_2 = 1 + 2^27 * 17 是平方不超过LL(long long) 中最大的一个,但是不推荐,因为$2*x_1^2$超了LL
  • x_3 = 1 + 2^21 * 479 是网上常见的一个,但并不推荐。$c$ 太大了!
  • x_4 = 1 + 2^12 * 3 是最小不超过INT16,并且$c$特别小的一个!如果不用LL就很推荐
  • x_5 = 1 + 2^57 * 29 是不超过 INT64 中最推荐的一个!然后基础运算需要用GCC内建的 __int128

总之,$x_1$ 是最为推荐的,$x_2,x_3$ 很常见主要是因为国内第一篇比较完整的介绍NFT的是 大佬miskcoo,他当时给的常数是$x_2,x_3$,然后就人云亦云了!$x_5$ 很有意思,它敲好比$2^{62}$小一点,然后它又大于$1e9+9$,而$1e9+7,1e9+9$这两个孪生素数又经常的出现在ICPC/IO中!但是,貌似也没啥用,见NFT模板代码的注释

用SageMath自带的primtive_root函数分别求对应的原根 $g_1=31,g_2=g_3=g_5=3,g_4=11$。

所以,在LL的数据范围内,我们可以使用$x_1$,可以处理最长长度为$2^{27}$ 的NFT,最长为$2^{26} \sim 6 \times 10^7$ 项的NFT多项式乘法。

我们现在存在$w$,有 $w^N = 1,\; w^n \neq 1, 0<n<N $,其中 $N = 2^k, k<27$,有时我们用$w_N$表明$w$和$N$的关系。

离散 Fourier 变换 DFT

对长度为$N$ 的数列$a_0,\cdots a_{N-1}$ 做离散Fourier 变换得到数列$\hat{a}_0 \cdots \hat{a}_{N-1}$:

写成矩阵形式:

即上述矩阵为$A$,则 $a_{ij} = w^{ij}$, 即$b_{ij}= w^{-ij}$,则$AB = NI$,即$A^{-1} = \frac{1}{N}(w^{-ij})_{N \times N}$。即我们得到了Fourier逆变换公式:

快速Fourier数论变换NFT

记 $H = \frac{N}{2}$,

即长度为$N$的Fourier变换可以其奇数项和偶数项的长度为$\frac{N}{2}$的Fourier变换表出。于是递归的我们可以在$O(n\log n)$ 时间复杂度求出。

递归太消耗计算时间了。因此我们需要给出非递归的版本

快速NFT图:

下图出自 miskcoo从多项式乘法到快速傅里叶变换

bit-reverse-miskcoo

从这个图发现,最终的计算顺序,是每个数的位倒序。处理的细节miskcoo博客写的特别清楚了!

当我想要修改miskcoo的代码形式时,发现怎么修改都没他的好!

还有NFT可以用于求多项式的逆!也可见miskcoo的博文

NFT 和卷积的关系

设$a,b$ 是长度为$N$(不必2的幂次,之前这个限制只是为了快速计算)的数列,则

Proof

最后一个式子成立是因为若$i \neq j$,则 $\sum_{n=0} ^k w^{ni} a_i w^{(k-n)j}b_j=0$

所以我们有 $a \star b= \hat{\hat{a}\hat{b}}$,而多项式乘法只是卷积的一个例子。有些时候计算式一开始不是卷积形式,但是可以转换成卷积形式,再利用FFT或者NFT加速。

NFT模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
using BI = __int128;
void bitreverse(BI *x,int len){ // note that bitreverse(i)=j
for(int i=0,j=0;i!=len;++i){
if(i>j) swap(x[i],x[j]);
for(int l=len>>1;(j^=l)<l;l>>=1);
}
}
// the mod must NFT-friendly or (len+1)*mod^2 < FM
const BI FM = BI(29)<<57|1, gg=3;
void nft(BI *x,int len,bool isInverse=false){
g = powmod(gg,(FM-1)/len,FM);
if(isInverse){
g = powmod(g,FM-2,FM);
BI invlen = powmod(BI(len),FM-2,FM);
for(int i=0;i!=len;++i){
x[i]=x[i]*invlen%FM;
}
}
bitreverse(x,len);
for(int half=1,step=2;half!=len;half<<=1,step<<=1){
BI wn = powmod(g,len/step,FM),w=1;
for(int i=0;i<len;i+=step,w=1){
for(int j = i;j<i+half;++j){
BI t=(w*x[j+half])%FM;
x[j+half]=(FM-t+x[j])%FM;
x[j]=(x[j]+t)%FM;
w = w*wn%FM;
}
}
}
}
void square(BI *a, int n){
int len = 1<<(32-__builtin_clz(2*n+1));
nft(a,len);
for(int i=0;i!=len;++i){
a[i]=a[i]*a[i]%FM;
}
nft(a,len,1);
}
void mul(BI *a,BI *b,int na,int nb){
int len = 1<<(32-__builtin_clz(na+nb+1));
nft(a,len);nft(b,len);
for(int i=0;i!=len;++i){
a[i] = a[i]*b[i]%FM;
}
nft(b,len,1);nft(a,len,1);
}

NFT 模板更新(2020/7/9)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<bits/stdc++.h>
using namespace std;
using LL = long long;
constexpr LL M = 998244353,ROOT=3;
LL powmod(LL x,LL n){
LL r(1);
while(n){
if(n&1) r=r*x%M;
n>>=1; x=x*x%M;
}
return r;
}
void bitreverse(vector<LL> &a){
for(int i=0,j=0;i!=a.size();++i){
if(i>j) swap(a[i],a[j]);
for(int l=a.size()>>1;(j^=l)<l;l>>=1);
}
}
void nft(vector<LL> &a,bool isInverse=false){
LL g = powmod(ROOT,(M-1)/a.size());
if(isInverse){
g = powmod(g,M-2);
LL invLen = powmod(LL(a.size()),M-2);
for(auto &x:a) x=x*invLen%M;
}
bitreverse(a);
vector<LL> w(a.size(),1);
for(int i=1;i!=w.size();++i) w[i] = w[i-1]*g%M;
auto addMod = [](LL x,LL y){return (x+=y)>=M?x-=M:x;};
for(int step=2,half = 1;half!=a.size();step<<=1,half<<=1){
for(int i=0,wstep=a.size()/step;i!=a.size();i+=step){
for(int j=i;j!=i+half;++j){
LL t = (a[j+half]*w[wstep*(j-i)])%M;
a[j+half]=addMod(a[j],M-t);
a[j]=addMod(a[j],t);
}
}
}
}
vector<LL> mul(vector<LL> a,vector<LL> b){
int sz=1,tot = a.size()+b.size()-1;
while(sz<tot) sz*=2;
a.resize(sz);b.resize(sz);
nft(a);nft(b);
for(int i=0;i!=sz;++i) a[i] = a[i]*b[i]%M;
nft(a,1);
a.resize(tot);
return a;
}

关于多项式乘法,求逆,带余除法的理论基础见下图,取自cp-algorithm

mulInv

divRem