仅以此博文,感谢知乎好友Vivr0

中国剩余定理也称孙子定理,是中国古代求解一次同余方程组的方法。用现代的语言来说就是:

且正整数组$m_i$ 两两互素,则对任意整数组$a_i$,上述方程有解,解可以写成 $x \equiv a \mod m$

我们不要求 $m_i$ 两两互素也能求解,只是不一定有解,下面详细给出做法。

我们先考虑$n=2$的情形。即

我们可以把方程写成

我们设 $d = \gcd(m_1,m_2)$,则 $d| x-a_1$ 又 $d|m_2$,所以 $d|a_2-a_1$。

我们知道对任意正整数$a,b$, 存在整数$x,y$ 使得 $xa + yb = \gcd(a,b)$。

(最后Python 代码注释中有给出$x,y$ 的详细操作)

存在 $t_1,t_2$ 使得 $m_1t_1 + m_2t_2 = gcd(m_1,m_2) = d$,所以

即 $x \equiv a \mod m$,其中 $a= a_1 + \frac{a_2-a_1}{d} t_1m_1 = \frac{t_2m_2a_1+t_1m_1a_2}{d}$,$m = lcm(m_1,m_2) = \frac{m_1m_2}{d}$

$n-1$ 次上述操作,就处理了一般情况

C++ 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

LL exgcd(LL a,LL b,LL& x,LL& y){ // ax + by = gcd(a,b)
if(b==0){
x=1;y=0;return a;
}
LL d=exgcd(b,a%b,y,x);
y-=a/b*x;
return d;
}

pair<LL,LL> crt2(LL a1,LL m1,LL a2,LL m2){ // x = ai mod mi, m_i >0
LL t1,t2,ans = a2-a1;
LL d = exgcd(m1,m2,t1,t2);
assert(ans%d == 0);
LL m = m1/d*m2;
ans = (a1+ans/d*t1%m2*m1)%m; // %m2 是避免溢出
return make_pair(ans>0?ans:ans+m,m);
}
const int N = 22;
LL a[N],m[N];
pair<LL,LL> crt(int n){ // x = a[i] mod m[i], m[i] >0
pair<LL,LL> ans = make_pair(a[0]%m[0],m[0]);
for(int i=1;i<n;++i){
ans = crt2(ans.first,ans.second,a[i],m[i]);
}
return ans;
}

int main(){
LL a1,m1,a2,m2;
while(cin>>a1>>m1>>a2>>m2){
LL ans = crt2(a1,m1,a2,m2).first;
cout<<ans<<endl;
if((ans-a1)%m1 || (ans-a2)%m2){
cout<<"something wrong"<<endl;
}
}
return 0;
}

Python代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# input : a,b natural number
# output: [gcd(a,b), x, y]
# ax + by = gcd(a,b)
# Algorithm: b(a//bx+y) + a%bx = gcd(b,a%b)
def exgcd(a,b):
if(b == 0): return [a,1,0]
[d,y,x] = exgcd(b,a%b)
return [d,x,y-a//b*x]

# input: x = ai mod m_i, mi>0, i=1,2
# output: x = a mod m
def crt2(a1,m1,a2,m2):
[d,t1,t2] = exgcd(m1,m2)
a,m = a2-a1,m1//d*m2
if(a%d): raise ValueError('No solution to crt problem')
return [(a1+a//d*t1*m1)%m,m]

# input: x = ai mod m_i, mi>0
# output: x = a mod m
def crt(a,m):
n = len(a)
if(len(m)!=n): raise ValueError('a and m must have equal length')
aa,mm = a[0],m[0]
for i in range(1,n):
[aa,mm] = crt2(aa,mm,a[i],m[i])
return [aa,mm]

if __name__ == "__main__":
[a,m]=crt([2,-4,5],[3,5,12])
print(a,m)
[a,m]=crt([2,-4,4],[3,5,12])
print(a,m)