快速 Fourier 变换,被称为 20 世纪最伟大的十大算法之一。所以很多软件都有对应的 FFT,例如 Python 的 scipy.fftpack 中就有关于 FFT 的包。所以个人写 FFT 就没有那么必要了。但是 NFT 的包一般都没多少,而且会写 NFT 必然就会写 FFT 了。大约在 5 年前就写过 NFT 的 C++ 代码,现在一看依然不记得蝶式映射到底是咋想的,所以想独立思考出整个过程(失败)。

此博文特别划水,建议直接阅读 miskcoo 博文:从多项式乘法到快速傅里叶变换

发现 NFT 一个很大的限制就是你只能在 NFT-friendly 的域(例如 $\mod 998244353 = 119 \cdot2^{23}$,原根为 $3$ )或者模很小的数的环中处理。即环中所有运算放在整数环中都不会超过选择的大基底。

选择 $n$ 个 NFT-friendly 的大基底 $p_1,\cdots, p_n$ 使得 $p_1\cdots p_n$ 大于 ans 的上界,然后再用 中国剩余定理 就可以把 ans 搞出来了。

基底的选择

我们考虑 $\mod p$ 构成的域。即运算默认是 $\mod p$ 的(除了指数上的幂次数),因为原根定理,此形式必有原根 $g$,即存在 $\mod p$ 中所有元素都可以写成 $g^n$ 的形式(所以 $g^{p-1}=1,g^{n}\neq 1, 0<n<p-1$)。而我们做 NFT 是需要找一个元素 $w$,使得 $w^{2^k} = 1$,因此我们需要找素数 $p$,使得 $p-1=c \cdot 2^k$,其中 $c$ 是个小奇数。

查找基底的 SageMath 代码

1
2
3
4
5
6
7
ans = []
for i in range(20,32): # 调节这个数值范围来找自己想要的 p
for j in range(1,25,2):
if(is_prime(j*2^i+1)):
ans.append(j*2^i+1)
for i in ans:
print("1 + ",factor(i-1),"\t=\t",i)

我们会发现有很多可供选择的例子,其中的娇楚($c$ 较小,$k$ 较大,$p^2 < 2^{63}$)

  • x_1 = 1 + 2^27 * 15 十分推荐!$x_2$ 刚好不超过 INT_MAX,所以在乘积取模之前还多一次加法运算,就很方便!
  • x_2 = 1 + 2^27 * 17 是平方不超过 LL(long long) 中最大的一个,但是不推荐,因为 $2*x_1^2$ 超了 LL
  • x_3 = 1 + 2^21 * 479 是网上常见的一个,但并不推荐。$c$ 太大了!
  • x_4 = 1 + 2^12 * 3 是最小不超过INT16,并且 $c$ 特别小的一个!如果不用LL就很推荐
  • x_5 = 1 + 2^57 * 29 是不超过 INT64 中最推荐的一个!然后基础运算需要用 GCC 内建的 __int128

总之,$x_1$ 是最为推荐的,$x_2,x_3$ 很常见主要是因为国内第一篇比较完整的介绍 NFT 的是 大佬 miskcoo,他当时给的常数是 $x_2,x_3$,然后就人云亦云了!$x_5$ 很有意思,它敲好比 $2^{62}$ 小一点,然后它又大于 $1e9+9$,而 $1e9+7,1e9+9$ 这两个孪生素数又经常的出现在 ICPC/IO 中!但是,貌似也没啥用,见 NFT 模板代码的注释

用 SageMath 自带的 primtive_root 函数分别求对应的原根 $g_1=31,g_2=g_3=g_5=3,g_4=11$。

所以,在 LL 的数据范围内,我们可以使用 $x_1$,可以处理最长长度为 $2^{27}$ 的 NFT,最长为 $2^{26} \sim 6 \times 10^7$ 项的 NFT 多项式乘法。

我们现在存在 $w$,有 $w^N = 1,\; w^n \neq 1, 0< n < N$,其中 $N = 2^k, k<27$,有时我们用 $w_N$ 表明 $w$ 和 $N$ 的关系。

离散 Fourier 变换 DFT

对长度为 $N$ 的数列 $a_0,\cdots a_{N-1}$ 做离散 Fourier 变换得到数列 $\hat{a}_0 \cdots \hat{a}_{N-1}$:

写成矩阵形式:

即上述矩阵为 $A$,则 $a_{ij} = w^{ij}$, 即 $b_{ij}= w^{-ij}$,则 $AB = NI$,即 $A^{-1} = \frac{1}{N}(w^{-ij})_{N \times N}$。即我们得到了 Fourier 逆变换公式:

快速 Fourier 数论变换 NFT

记 $H = \frac{N}{2}$,

即长度为 $N$ 的 Fourier 变换可以其奇数项和偶数项的长度为 $\frac{N}{2}$ 的 Fourier 变换表出。于是递归的我们可以在 $O(n\log n)$ 时间复杂度求出。

递归太消耗计算时间了。因此我们需要给出非递归的版本

快速 NFT 图

下图出自 miskcoo 从多项式乘法到快速傅里叶变换

bit-reverse-miskcoo

从这个图发现,最终的计算顺序,是每个数的位倒序。处理的细节 miskcoo 博客写的特别清楚了!

当我想要修改 Miskcoo 的代码形式时,发现怎么修改都没他的好!

还有 NFT 可以用于求多项式的逆!也可见 Miskcoo 的博文

NFT 和卷积的关系

设 $a,b$ 是长度为 $N$(不必 2 的幂次,之前这个限制只是为了快速计算)的数列,则

Proof

最后一个式子成立是因为若 $i \neq j$,则 $\sum_{n=0} ^k w^{ni} a_i w^{(k-n)j}b_j=0$

所以我们有 $a \star b= \hat{\hat{a} \hat{b}}$,而多项式乘法只是卷积的一个例子。有些时候 计算式 一开始不是卷积形式,但是可以转换成卷积形式,再利用 FFT 或者 NFT 加速。

NFT 模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
using BI = __int128;
void bitreverse(BI *x,int len){ // note that bitreverse(i)=j
for(int i=0,j=0;i!=len;++i){
if(i>j) swap(x[i],x[j]);
for(int l=len>>1;(j^=l)<l;l>>=1);
}
}
// the mod must NFT-friendly or (len+1)*mod^2 < FM
const BI FM = BI(29)<<57|1, gg=3;
void nft(BI *x,int len,bool isInverse=false){
g = powmod(gg,(FM-1)/len,FM);
if(isInverse){
g = powmod(g,FM-2,FM);
BI invlen = powmod(BI(len),FM-2,FM);
for(int i=0;i!=len;++i){
x[i]=x[i]*invlen%FM;
}
}
bitreverse(x,len);
for(int half=1,step=2;half!=len;half<<=1,step<<=1){
BI wn = powmod(g,len/step,FM),w=1;
for(int i=0;i<len;i+=step,w=1){
for(int j = i;j<i+half;++j){
BI t=(w*x[j+half])%FM;
x[j+half]=(FM-t+x[j])%FM;
x[j]=(x[j]+t)%FM;
w = w*wn%FM;
}
}
}
}
void square(BI *a, int n){
int len = 1<<(32-__builtin_clz(2*n+1));
nft(a,len);
for(int i=0;i!=len;++i){
a[i]=a[i]*a[i]%FM;
}
nft(a,len,1);
}
void mul(BI *a,BI *b,int na,int nb){
int len = 1<<(32-__builtin_clz(na+nb+1));
nft(a,len);nft(b,len);
for(int i=0;i!=len;++i){
a[i] = a[i]*b[i]%FM;
}
nft(b,len,1);nft(a,len,1);
}

NFT 模板更新(2020/7/9)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<bits/stdc++.h>
using namespace std;
using LL = long long;
constexpr LL M = 998244353,ROOT=3;
LL powmod(LL x,LL n){
LL r(1);
while(n){
if(n&1) r=r*x%M;
n>>=1; x=x*x%M;
}
return r;
}
void bitreverse(vector<LL> &a){
for(int i=0,j=0;i!=a.size();++i){
if(i>j) swap(a[i],a[j]);
for(int l=a.size()>>1;(j^=l)<l;l>>=1);
}
}
void nft(vector<LL> &a,bool isInverse=false){
LL g = powmod(ROOT,(M-1)/a.size());
if(isInverse){
g = powmod(g,M-2);
LL invLen = powmod(LL(a.size()),M-2);
for(auto &x:a) x=x*invLen%M;
}
bitreverse(a);
vector<LL> w(a.size(),1);
for(int i=1;i!=w.size();++i) w[i] = w[i-1]*g%M;
auto addMod = [](LL x,LL y){return (x+=y)>=M?x-=M:x;};
for(int step=2,half = 1;half!=a.size();step<<=1,half<<=1){
for(int i=0,wstep=a.size()/step;i!=a.size();i+=step){
for(int j=i;j!=i+half;++j){
LL t = (a[j+half]*w[wstep*(j-i)])%M;
a[j+half]=addMod(a[j],M-t);
a[j]=addMod(a[j],t);
}
}
}
}
vector<LL> mul(vector<LL> a,vector<LL> b){
int sz=1,tot = a.size()+b.size()-1;
while(sz<tot) sz*=2;
a.resize(sz);b.resize(sz);
nft(a);nft(b);
for(int i=0;i!=sz;++i) a[i] = a[i]*b[i]%M;
nft(a,1);
a.resize(tot);
return a;
}

关于多项式乘法,求逆,带余除法的理论基础见下图,取自 cp-algorithm

mulInv

divRem