算法总览

第一次整体地总结一下我的模板。

目录

数学

高斯消元

拉格朗日插值

原根

bsgs(ex)

lucas(ex)

crt(excrt)

二次剩余

杜教筛,min25筛,洲阁筛

Miller-Rabin,Pollard-rho

类欧几里得

线性基

矩阵求逆

行列式

矩阵树定理

组合计数

莫比乌斯反演

二项式反演

Min-Max容斥

子集反演

伯努利数

斯特林数

欧拉数

数据结构

lct维护子树信息

莫队

带修莫队

回滚莫队

树上莫队

二次离线莫队

Splay

无旋Treap

可持久化平衡树

线段树合并

线段树分裂

势能分析线段树

点分治/树

边分治/树

cdq分治

树状数组套主席树

线段树套平衡树

虚树

圆方树

长链剖分

KD-Tree

LCT

左偏树(可并堆)

猫树

支配树

李超树

划分树

笛卡尔树

分块/树分块

克鲁斯卡尔重构树

扫描线

欧拉游览树 ETT

Top Tree

Sqrt Tree

O(n)-O(1) ST表

析合树

网络流

最大流

最小割树

spfa 费用流

dinic 费用流

有源汇上下界最大流

无源汇上下界可行流

有源汇最小费用可行流

有源汇上下界可行流

有源汇最小流

最大独立集

最小割相关

多项式

多项式全家桶

多项式求逆

多项式开根

多项式除法

多项式求导和积分

多项式ln

多项式exp

泰勒展开

牛顿迭代

任意模数多项式乘法

多项式快速幂

多项式复合

多项式复合逆

拉格朗日反演

有标号连通图计数

边双连通图计数

多项式多点求值

多项式快速插值

下降幂多项式乘法

下降幂多项式转普通多项式

普通多项式转下降幂多项式

Chirp Z-Transform

常系数齐次线性递推

常系数非齐次线性递推

字符串

最小表示法

kmp

exkmp

manacher

ac自动机

回文自动机

后缀数组

后缀自动机

后缀平衡树

杂项

LGV引理

树hash

spfa优化

2-sat

差分约束

prufer序列

竞赛图

一般图最大匹配

一般图最大权匹配

最小树形图

fwt

exgcd

数学

高斯消元

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
// Values with |v| < MIN are treated as zero.
const double MIN=0.0000001;
int n;
// all: set when some column has no pivot, i.e. the system has free variables.
bool all;
double a[114][114],x[114];
/*
 * Solves an n x n linear system given as an n x (n+1) augmented matrix (1-indexed, read
 * from stdin) by Gauss-Jordan elimination with partial pivoting.
 * Prints "-1" if the system is inconsistent, "0" if it has infinitely many solutions,
 * otherwise "x<i>=<value>" (two decimals) for every unknown.
 * Every elimination step is a full row operation, so the zero rows left at the bottom
 * reliably expose inconsistency even when pivots are missing.
 */
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n+1;j++)scanf("%lf",&a[i][j]);
    int row=1; // next row to receive a pivot
    for(int col=1;col<=n;col++){
        // Partial pivoting: pick the largest |a[j][col]| among rows row..n.
        int best=row;
        for(int j=row+1;j<=n;j++)
            if(fabs(a[j][col])>fabs(a[best][col]))best=j;
        if(fabs(a[best][col])<MIN){all=true;continue;} // free column
        for(int k=1;k<=n+1;k++)swap(a[row][k],a[best][k]);
        // Eliminate column col from every other row.
        for(int j=1;j<=n;j++){
            if(j==row)continue;
            double t=a[j][col]/a[row][col];
            for(int k=1;k<=n+1;k++)a[j][k]-=a[row][k]*t;
        }
        row++;
    }
    // Rows row..n have (numerically) zero coefficients; a nonzero RHS is a contradiction.
    for(int i=row;i<=n;i++)
        if(fabs(a[i][n+1])>MIN){printf("-1");return 0;}
    if(all){printf("0");return 0;}
    // Unique solution: row i holds the pivot of column i.
    for(int i=1;i<=n;i++){
        x[i]=a[i][n+1]/a[i][i];
        if(fabs(x[i])<MIN)printf("x%d=0\n",i);
        else printf("x%d=%.2lf\n",i,x[i]);
    }
    return 0;
}

拉格朗日插值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <bits/stdc++.h>
using namespace std;
// Prints sum_{i=1}^{n} i^k mod 1e9+7. The sum is a polynomial in n of degree k+1, so it
// is evaluated from its values at the k+2 nodes 1..k+2 by Lagrange interpolation.
const int N = 1e6 + 5, mod = 1e9 + 7;

// tab: composite marks; p/pcnt: primes from the sieve; f: i^k, later its prefix sums;
// pre/suf: prefix/suffix products of (n - i); fac: factorials; inv: inverses of 1..k+2,
// later turned into inverse factorials; ans: the result.
int n, k, tab[N], p[N], pcnt, f[N], pre[N], suf[N], fac[N], inv[N], ans;

// x^y mod `mod` by binary exponentiation.
int qpow(int x, int y) {
int ans = 1;
for (; y; y >>= 1, x = 1LL * x * x % mod)
if (y & 1) ans = 1LL * ans * x % mod;
return ans;
}

// Linear sieve on [1, lim]. i^k is completely multiplicative, so qpow is only needed at
// primes; afterwards f[i] holds the prefix sum 1^k + ... + i^k.
void sieve(int lim) {
f[1] = 1;
for (int i = 2; i <= lim; i++) {
if (!tab[i]) {
p[++pcnt] = i;
f[i] = qpow(i, k);
}
for (int j = 1; j <= pcnt && 1LL * i * p[j] <= lim; j++) {
tab[i * p[j]] = 1;
f[i * p[j]] = 1LL * f[i] * f[p[j]] % mod;
if (!(i % p[j])) break;
}
}
for (int i = 2; i <= lim; i++) f[i] = (f[i - 1] + f[i]) % mod;
}

int main() {
scanf("%d%d", &n, &k);
sieve(k + 2);
// n is within the tabulated range: answer directly.
if (n <= k + 2) return printf("%d\n", f[n]) & 0;
// pre[i] = prod_{j=1}^{i} (n - j), suf[i] = prod_{j=i}^{k+2} (n - j).
pre[0] = suf[k + 3] = 1;
for (int i = 1; i <= k + 2; i++) pre[i] = 1LL * pre[i - 1] * (n - i) % mod;
for (int i = k + 2; i >= 1; i--) suf[i] = 1LL * suf[i + 1] * (n - i) % mod;
// fac[i] = i!; inv[i] = 1/i via inv[i] = -(mod / i) * inv[mod % i].
fac[0] = inv[0] = fac[1] = inv[1] = 1;
for (int i = 2; i <= k + 2; i++) {
fac[i] = 1LL * fac[i - 1] * i % mod;
inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
}
// Prefix products turn inv[i] into 1/i!.
for (int i = 2; i <= k + 2; i++) inv[i] = 1LL * inv[i - 1] * inv[i] % mod;
// Lagrange basis for node i over nodes 1..k+2:
//   prod_{j != i} (n - j) / ((i - 1)! * (k + 2 - i)! * (-1)^(k + 2 - i)).
for (int i = 1; i <= k + 2; i++) {
int P = 1LL * pre[i - 1] * suf[i + 1] % mod;
int Q = 1LL * inv[i - 1] * inv[k + 2 - i] % mod;
int mul = ((k + 2 - i) & 1) ? -1 : 1;
ans = (ans + 1LL * (Q * mul + mod) % mod * P % mod * f[i] % mod) % mod;
}
printf("%d\n", ans);
return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int p=998244353;
typedef long long ll;
// a^b mod p for any integer a (reduced into [0, p) first).
// Named fpow like the other templates here. A global `pow(ll, ll)` loses overload
// resolution to the std::pow template (an exact match for (ll, int)) whenever <cmath> is
// visible, and its double result breaks the `%` at the call site.
ll fpow(ll a,ll b){
    a%=p;
    if(a<0)a+=p;
    ll res=1;
    while(b){
        if(b&1)res=res*a%p;
        a=a*a%p;
        b>>=1;
    }
    return res;
}
// Residue of v in [0, p).
ll norm(ll v){
    v%=p;
    return v<0?v+p:v;
}
int x[2010],y[2010];
// Reads n points (x_i, y_i) with x_i pairwise distinct mod p, and k. Prints f(k) mod p,
// where f is the polynomial of degree < n through the points, using the direct O(n^2)
// Lagrange formula.
int main(){
    int n,k;
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++)scanf("%d%d",&x[i],&y[i]);
    ll res=0;
    for(int i=1;i<=n;i++){
        ll up=1,down=1; // prod_{j!=i}(k - x_j) and prod_{j!=i}(x_i - x_j)
        for(int j=1;j<=n;j++){
            if(i==j)continue;
            up=up*norm((ll)k-x[j])%p;
            down=down*norm((ll)x[i]-x[j])%p;
        }
        res=(res+norm(y[i])*up%p*fpow(down,p-2))%p;
    }
    printf("%lld\n",res);
    return 0;
}

原根

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e6+10;
#define rint register int
#define rll register long long
typedef long long ll;
// a^b mod p by binary exponentiation.
ll fpow(ll a,ll b,ll p){
rll res=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)res=res*a%p;
return res;
}
// pri[1..pri[0]]: primes up to 1e6 from init(); npri: composite marks.
int pri[N];
bool npri[N];
// p[1..cnt]: proper divisors of phi(n) to test; g[1..tot]: primitive roots found.
int p[N],tot,g[N],cnt;
// Euler's totient of x by trial division with the sieved primes.
int get(int x){
rint res=x;
for(rint i=1;pri[i]*pri[i]<=x;i++){
if(x%pri[i]==0){
res=res/pri[i]*(pri[i]-1);
while(x%pri[i]==0)x/=pri[i];
}
}
if(x>1)res=res/x*(x-1);
return res;
}
// Linear sieve of the primes up to 1e6 into pri.
void init(){
rint n=1e6;
for(rint i=2;i<=n;i++){
if(!npri[i])pri[++pri[0]]=i;
for(rint j=1;j<=pri[0]&&i*pri[j]<=n;j++){
npri[i*pri[j]]=1;
if(i%pri[j]==0)break;
}
}
}
int gcd(int x,int y){
return y?gcd(y,x%y):x;
}
// A unit x is a primitive root mod n iff x^d != 1 for every proper divisor d of phi(n).
// The divisors are read from p[1..cnt], which may list a square-root divisor twice.
bool check(int x,int n){
for(rint i=1;i<=cnt;i++)
if(fpow(x,p[i],n)==1)return 0;
return 1;
}
// For each query (n, d): prints how many primitive roots n has, then every d-th one in
// increasing order.
int main(){
init();
int T;
scanf("%d",&T);
while(T--){
int n,d;
scanf("%d%d",&n,&d);
cnt=tot=0;
int phi=get(n);
// Proper divisors of phi: the pairs (i, phi/i) for 2 <= i <= sqrt(phi), plus 1.
for(rint i=2;i*i<=phi;i++)
if(phi%i==0)p[++cnt]=i,p[++cnt]=phi/i;
if(phi!=1)p[++cnt]=1;
// Smallest primitive root by brute force; res stays 0 if n has none.
rint res=0;
for(rint i=1;i<=n;i++){
if(gcd(i,n)==1&&check(i,n)){
res=i;
break;
}
}
// All primitive roots are res^i with 1 <= i <= phi and gcd(i, phi) = 1.
if(res){
for(rint i=1;i<=phi;i++){
if(gcd(i,phi)==1)g[++tot]=fpow(res,i,n);
}
}
sort(g+1,g+tot+1);
printf("%d\n",tot);
for(rint i=d;i<=tot;i+=d)
printf("%d ",g[i]);
puts("");
}
return 0;
}

bsgs(ex)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include<map>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
// Baby-step table: y * x^j (mod p) -> j. Later j overwrite earlier ones, so the largest j
// is kept and the first giant-step hit gives the minimal exponent.
map<ll,ll> mp;
// a^b mod p by binary exponentiation.
ll fpow(ll a,ll b,ll p){
    ll res=1;
    while(b){
        if(b&1)res=res*a%p;
        a=a*a%p;
        b>>=1;
    }
    return res;
}
// Reads p, x, y and prints the smallest l >= 0 with x^l ≡ y (mod p), or "no solution".
// Plain baby-step giant-step: assumes p is prime and p does not divide x.
int main(){
    ll p,x,y;
    scanf("%lld%lld%lld",&p,&x,&y);
    x%=p;
    y%=p;
    // The giant-step search only produces exponents i*sqr - j >= 1, so l = 0 is checked here.
    if(y==1%p){puts("0");return 0;}
    ll sqr=sqrt(p)+1,now=y,tmp=1;
    for(ll j=0;j<sqr;j++){
        mp[now]=j;
        now=now*x%p;
        tmp=tmp*x%p;
    }
    // tmp = x^sqr. x^(i*sqr) == y * x^j  =>  l = i*sqr - j.
    now=tmp;
    for(ll i=1;i<=sqr;i++){
        map<ll,ll>::iterator it=mp.find(now);
        if(it!=mp.end())return printf("%lld\n",i*sqr-it->second),0;
        now=now*tmp%p;
    }
    puts("no solution");
    return 0;
}


#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <unordered_map>

using namespace std;

typedef long long LL;

// "No solution" sentinel. exBSGS adds at most O(log p) +1s to it, so it stays negative.
const int INF = 0x3f3f3f3f;

int a, b, p;
// Baby-step table: b * a^y (mod p) -> y (largest y kept).
unordered_map<int, int> hs;

// Extended Euclid: finds x, y with a*x + b*y = gcd(a, b) and returns the gcd.
int exgcd(int a, int b, int &x, int &y) {
    if (!b) {
        x = 1, y = 0;
        return a;
    }
    int d = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return d;
}

// Smallest x >= 0 with a^x ≡ b (mod p), assuming gcd(a, p) = 1; -INF if none.
int BSGS(int a, int b, int p) {
    if (1 % p == b % p) return 0;
    int k = sqrt(p) + 1;
    hs.clear();
    for (int y = 0, r = b % p; y < k; y++) {
        hs[r] = y;
        r = (LL)r * a % p;
    }
    int ak = 1;
    for (int i = 1; i <= k; i++) ak = (LL)ak * a % p;
    for (int x = 1, l = ak; x <= k; x++) {
        if (hs.count(l)) return k * x - hs[l];
        l = (LL)l * ak % p;
    }
    return -INF;
}

// Generalized BSGS for arbitrary a, p. While d = gcd(a, p) > 1, divide a^x ≡ b (mod p)
// by d and peel off one factor of a:
// a^(x-1) ≡ (b/d) * (a/d)^{-1} (mod p/d), adding 1 to the answer.
int exBSGS(int a, int b, int p) {
    b = (b % p + p) % p;
    if (1 % p == b % p) return 0;
    int x, y;
    int d = exgcd(a, p, x, y);
    if (d > 1) {
        if (b % d) return -INF;
        exgcd(a / d, p / d, x, y);
        return exBSGS(a, (LL)b / d * x % (p / d), p / d) + 1;
    }
    return BSGS(a, b, p);
}

// Reads "a p b" triples until "0 0 0" or end of input. For each, prints the smallest x
// with a^x ≡ b (mod p), or "No Solution".
int main() {
    // The original `~scanf(...), a || b || p` discarded the EOF test through the comma
    // operator, so input without a 0 0 0 terminator looped forever on stale values.
    while (scanf("%d%d%d", &a, &p, &b) == 3 && (a || b || p)) {
        int res = exBSGS(a, b, p);
        if (res < 0) puts("No Solution");
        else printf("%d\n", res);
    }
    return 0;
}

lucas(ex)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=2e5+10;
// p: prime modulus of the current test case; fac[i] = i! mod p for 0 <= i < p.
ll p,fac[N];

// Binary exponentiation: a^b mod p.
ll fpow(ll a,ll b){
    ll acc=1;
    for(;b;b>>=1){
        if(b&1)acc=acc*a%p;
        a=a*a%p;
    }
    return acc;
}

// C(n, m) mod p for 0 <= n, m < p (0 when m > n). The inverse of m!(n-m)! comes from
// Fermat's little theorem.
ll C(int n,int m){
    if(m>n)return 0;
    return fac[n]*fpow(fac[m]*fac[n-m]%p,p-2)%p;
}

// Lucas' theorem: C(n, m) mod p is the product of C(n_i, m_i) over the base-p digits,
// stopping once m's digits run out.
ll lucas(int n,int m){
    ll acc=1;
    while(m){
        acc=acc*C(n%p,m%p)%p;
        n/=p;
        m/=p;
    }
    return acc;
}

// For each test (n, m, p): prints C(n + m, m) mod p.
int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        int n,m;
        scanf("%d%d%lld",&n,&m,&p);
        fac[0]=fac[1]=1;
        for(int i=2;i<p;++i)fac[i]=fac[i-1]*i%p;
        printf("%lld\n",lucas(n+m,m));
    }
    return 0;
}

#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
// pri[1..pri[0]]: distinct prime factors of the modulus; prik[i]: pri[i]^e, the full
// prime power dividing it. An int modulus has at most 9 distinct prime factors.
ll pri[32],prik[32];
// val[i][j] = product of the integers in 1..j that are not multiples of pri[i], mod prik[i],
// for 0 <= j <= prik[i]. Sized per prime power (total <= modulus) instead of a fixed
// 20 x 1e6 table of long longs (~160 MB).
vector<ll> val[32];

// a^b mod m by binary exponentiation.
ll fpow(ll a,ll b,ll m){
    ll res=1;
    a%=m;
    for(;b;b>>=1,a=a*a%m)
        if(b&1)res=res*a%m;
    return res;
}
// Extended Euclid: a*x + b*y = gcd(a, b).
void exgcd(ll a,ll b,ll &x,ll &y){
    if(!b){
        x=1;y=0;return ;
    }
    exgcd(b,a%b,y,x);
    y-=a/b*x;
}
// Inverse of a modulo m (requires gcd(a, m) = 1).
ll Inv(ll a,ll m){
    ll x,y;
    exgcd(a,m,x,y);
    return (x%m+m)%m;
}
// (x!) with every factor of p = pri[id] removed, mod pk = prik[id]: the block product
// val[pk] repeats x/pk times, then the tail, then recurse on the multiples of p.
ll calc(ll x,int id){
    if(!x)return 1;
    ll p=pri[id],pk=prik[id];
    ll res=fpow(val[id][pk],x/pk,pk)*val[id][x%pk]%pk;
    return calc(x/p,id)*res%pk;
}
// Exponent of p in x! (Legendre's formula).
ll calc2(ll x,ll p){
    ll cnt=0;
    while(x){
        x/=p;
        cnt+=x;
    }
    return cnt;
}
// C(n, m) mod prik[id].
ll exlc(ll n,ll m,int id){
    if(n<m)return 0;
    ll p=pri[id],pk=prik[id];
    ll v1=calc(n,id),v2=calc(m,id)*calc(n-m,id)%pk;
    return v1*Inv(v2,pk)%pk*fpow(p,calc2(n,p)-calc2(m,p)-calc2(n-m,p),pk)%pk;
}
// Extended Lucas: reads n, m (up to 1e18) and a modulus p. Computes C(n, m) modulo each
// prime power of p, then combines the results with the CRT.
int main(){
    ll n,m;
    int p;
    scanf("%lld%lld%d",&n,&m,&p);
    int x=p;
    for(int i=2;(ll)i*i<=x;i++){
        if(x%i==0){
            pri[++pri[0]]=i;prik[pri[0]]=1;
            while(x%i==0)x/=i,prik[pri[0]]*=i;
        }
    }
    if(x>1)pri[++pri[0]]=x,prik[pri[0]]=x;
    for(int i=1;i<=pri[0];i++){
        ll pk=prik[i];
        val[i].assign(pk+1,1);
        for(ll j=1;j<=pk;j++)
            val[i][j]=(j%pri[i])?val[i][j-1]*j%pk:val[i][j-1];
    }
    ll res=0;
    for(int i=1;i<=pri[0];i++){
        ll r=exlc(n,m,i);
        ll M=p/prik[i];
        res=(res+r*Inv(M%prik[i],prik[i])%p*M)%p;
    }
    res=(res%p+p)%p;
    printf("%lld\n",res);
    return 0;
}

crt(excrt)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=20;
typedef long long ll;
// a[i]: pairwise-coprime moduli; b[i]: residues, i.e. x ≡ b[i] (mod a[i]).
ll a[N],b[N];
// Extended Euclid: finds x, y with a*x + b*y = gcd(a, b).
void exgcd(ll a,ll b,ll &x,ll &y){
    if(!b){
        x=1;y=0;
        return ;
    }
    exgcd(b,a%b,y,x);
    y-=a/b*x;
}
// (x * y) mod m for 0 <= x, y < m without overflow (double-and-add). All intermediate
// values stay below 2m, so m may be up to about 4.6e18.
ll mulmod(ll x,ll y,ll m){
    ll res=0;
    while(y){
        if(y&1){res+=x;if(res>=m)res-=m;}
        x+=x;if(x>=m)x-=m;
        y>>=1;
    }
    return res;
}
// Chinese remainder theorem for pairwise-coprime moduli. Prints the smallest non-negative
// solution mod M = prod a[i]. Every product below stays under M, so M may be up to ~1e18
// (the original `b[i]*M/a[i]*x` overflowed long before that).
int main(){
    int n;
    scanf("%d",&n);
    ll M=1;
    for(int i=1;i<=n;i++){
        scanf("%lld%lld",&a[i],&b[i]);
        M*=a[i];
    }
    ll res=0;
    for(int i=1;i<=n;i++){
        ll Mi=M/a[i],x,y;
        exgcd(Mi,a[i],x,y);
        x=(x%a[i]+a[i])%a[i];           // Mi^{-1} mod a[i]
        ll bi=(b[i]%a[i]+a[i])%a[i];
        // (bi * x mod a[i]) * Mi < a[i] * Mi = M, so this product cannot overflow.
        ll term=mulmod(bi,x,a[i])*Mi;
        res=(res+term)%M;
    }
    printf("%lld\n",res);
    return 0;
}


#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+10;
// Every `int` below (including exgcd/Mul parameters) is a 64-bit long long.
#define int long long
// Extended Euclid: a*x + b*y = gcd(a, b); returns the gcd.
int exgcd(int a,int b,int &x,int &y){
if(!b){
x=1;y=0;
return a;
}
int g=exgcd(b,a%b,y,x);
y-=a/b*x;
return g;
}
// (a * b) mod p by double-and-add, avoiding 64-bit overflow. b's sign is split off and
// reapplied, and a may be negative, so the result lies in (-p, p).
int Mul(int a,int b,int p){
int res=0,f=1;
if(b<0)f=-1,b=-b;
while(b){
if(b&1)res=(res+a)%p;
a=(a+a)%p;
b>>=1;
}
return res*f;
}
// Extended CRT, one test case per leading n. The first pair read is (modulus M,
// residue ans). Each following pair (modulus b, residue a) is merged into x ≡ ans (mod M).
// Prints -1 if the system is inconsistent.
signed main(){
int n;
while(~scanf("%lld",&n)){
int ans,M,a,b,x,y,tag=0;
scanf("%lld%lld",&M,&ans);
for(int i=2;i<=n;i++){
scanf("%lld%lld",&b,&a);
// Keep reading the remaining pairs after a contradiction so the next test starts correctly.
if(tag)continue;
// Need t with ans + t*M ≡ a (mod b), i.e. t*M ≡ c (mod b).
int g=exgcd(M,b,x,y);
int c=(a-ans%b+b)%b;
if(c%g)tag=1;
else{
int bg=b/g;
// t = x * (c/g) mod (b/g); then M becomes lcm(M, b) = M * (b/g).
x=Mul(x,c/g,bg);
ans+=x*M;
M=M*bg;
ans=(ans%M+M)%M;
}
}
if(tag)puts("-1");
else printf("%lld\n",ans);
}
return 0;
}

二次剩余

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
// I = a^2 - n, chosen to be a quadratic non-residue. Arithmetic in the struct below lives
// in F_p[w] with w^2 = I. n is the input and p the odd prime modulus.
ll I,n,p;
inline ll Add(ll x,ll y){return (x+y)%p;}
inline ll Del(ll x,ll y){return ((x-y)%p+p)%p;}
inline ll Mul(ll x,ll y){x*=y;return x>=p?x%p:x;}
// Element x + y*w of F_p[w].
struct cp{
    ll x,y;
    cp operator * (const cp&A)const{return cp{Add(Mul(x,A.x),Mul(Mul(I,y),A.y)),Add(Mul(x,A.y),Mul(y,A.x))};}
};
// a^b mod p.
ll fpow(ll a,ll b){
    ll res=1;
    for(;b;b>>=1,a=Mul(a,a))
        if(b&1)res=Mul(res,a);
    return res;
}
// a^b in F_p[w].
cp fpow(cp a,ll b){
    cp res=cp{1,0};
    for(;b;b>>=1,a=a*a)
        if(b&1)res=res*a;
    return res;
}
// Euler's criterion: x is a nonzero quadratic residue mod p iff x^((p-1)/2) == 1.
inline bool check(ll x){
    return fpow(x,(p-1)>>1)==1;
}
// Cipolla's algorithm. For each (n, p), prints the square roots of n mod p in increasing
// order, "0" for n ≡ 0, or "Hola!" if n is a non-residue.
int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        scanf("%lld%lld",&n,&p);
        n%=p; // so that n ≡ 0 (mod p) is recognized below
        if(n==0)printf("0\n");
        else if(!check(n))printf("Hola!\n");
        else{
            ll a=0;I=0;
            // Random a until a^2 - n is a non-residue.
            while(I==0||check(I))a=rand()%p,I=Del(Mul(a,a),n);
            // (a + w)^((p+1)/2) is a square root of n (its w-part vanishes).
            ll res1=fpow(cp{a,1},(p+1)>>1).x;
            ll res2=p-res1;
            // A single root still ends its line (the original omitted the '\n' here).
            if(res1==res2)printf("%lld\n",res1);
            else printf("%lld %lld\n",min(res1,res2),max(res1,res2));
        }
    }
    return 0;
}

杜教筛,min25筛,洲阁筛

约数个数

1
2
3
4
5
// Linear-sieve step for the divisor-count function (f presumably holds d(i), per the
// section title). p = prime[j] is the smallest prime factor of i*p. If p | i, write
// i = p^e * m; then d(i*p) = (e+2) d(m) = 2 d(i) - d(i/p).
if(i%prime[j]==0){
f[i*prime[j]]=f[i]*2ll-f[i/prime[j]];
break;
}
// Otherwise p is a new prime factor: d(i*p) = d(i) * d(p) = 2 d(i).
f[i*prime[j]]=f[i]*2;

常用卷积

$\varphi=\mathrm{id}*\mu$

$\varphi*I=\mathrm{id}$

$\mu*I=\varepsilon$(其中 $I(n)=1$,$\varepsilon(n)=[n=1]$)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<unordered_map>
using namespace std;
// Sieve limit. NOTE(review): the static arrays below are ~210 MB
// (npri 10 MB, pri 40 MB, mu 80 MB, phi 80 MB); check against the memory limit.
const int N=1e7+10;
typedef long long ll;
bool npri[N];
int pri[N];
// After get(): prefix sums of mu and phi up to the sieve limit.
ll mu[N],phi[N];
// Linear sieve of mu and phi on [1, n], then converted to prefix sums.
void get(int n){
mu[1]=phi[1]=1;
for(int i=2;i<=n;i++){
if(!npri[i])pri[++pri[0]]=i,mu[i]=-1,phi[i]=i-1;
for(int j=1;j<=pri[0]&&i*pri[j]<=n;j++){
npri[i*pri[j]]=1;
if(i%pri[j]==0){
phi[i*pri[j]]=phi[i]*pri[j];
break;
}
mu[i*pri[j]]=-mu[i];
phi[i*pri[j]]=phi[i]*phi[pri[j]];
}
}
for(int i=1;i<=n;i++)phi[i]+=phi[i-1],mu[i]+=mu[i-1];
}
// Memoized Du's-sieve results for arguments above the sieve limit.
unordered_map<int,ll> ans1,ans2;
// Sum of phi(i) for i <= n. From phi * I = id:
// S(n) = n(n+1)/2 - sum_{l=2}^{n} S(n/l), grouped over equal values of n/l.
ll getphi(int n){
if(n<=1e7)return phi[n];
if(ans1.find(n)!=ans1.end())return ans1[n];
ll res=0;
for(ll l=2,r;l<=n;l=r+1){
r=n/(n/l);
res-=getphi(n/l)*(r-l+1);
}
return ans1[n]=(unsigned long long)n*(n+1ll)/2+res;
}
// Sum of mu(i) for i <= n. From mu * I = epsilon: S(n) = 1 - sum_{l=2}^{n} S(n/l).
ll getmu(int n){
if(n<=1e7)return mu[n];
if(ans2.find(n)!=ans2.end())return ans2[n];
ll res=1;
for(ll l=2,r;l<=n;l=r+1){
r=n/(n/l);
res-=getmu(n/l)*(r-l+1);
}
return ans2[n]=res;
}
// For each query n (fits in int), prints the prefix sums of phi and mu up to n.
int main(){
get(1e7);
int T;
scanf("%d",&T);
while(T--){
int n;
scanf("%d",&n);
printf("%lld %lld\n",getphi(n),getmu(n));
}
return 0;
}

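杜教筛同样可以用来求 $i^k\varphi(i)$ 的前缀和。

以 $f(i)=i\varphi(i)$ 为例,

写出卷积式 $(f*g)(n)=\sum_{d\mid n}d\varphi(d)\,g\!\left(\frac{n}{d}\right)$。

容易发现令 $g(n)=n$,原式即为 $\sum_{d\mid n}d\varphi(d)\cdot\frac{n}{d}=n\sum_{d\mid n}\varphi(d)=n^2$,其前缀和可以直接算出。

其余次方的也类似:对 $f(i)=i^k\varphi(i)$ 取 $g(n)=n^k$,得 $f*g=n^{k+1}$。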

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e6+10;
const int p=1e9+7;
typedef long long ll;
// a^b mod p.
ll fpow(ll a,ll b){
ll res;
for(res=1;b;b>>=1,a=a*a%p)
if(b&1)res=res*a%p;
return res;
}
ll Add(ll x,ll y){return x+y>=p?x+y-p:x+y;}
ll Del(ll x,ll y){return x-y>=0?x-y:x-y+p;}
// n: input bound; sqr: floor(sqrt(n)) as computed in main.
ll n,sqr;
bool npri[N];
// pri[1..pri[0]]: primes up to 1e6; s1[i], s2[i]: sums of p and p^2 over the first i primes.
ll pri[N],s1[N],s2[N];
void get(int n){
for(int i=2;i<=n;i++){
if(!npri[i]){
pri[++pri[0]]=i;
s1[pri[0]]=Add(s1[pri[0]-1],i);
s2[pri[0]]=Add(s2[pri[0]-1],(ll)i*i%p);
}
for(int j=1;j<=pri[0]&&i*pri[j]<=n;j++){
npri[i*pri[j]]=1;
if(i%pri[j]==0)break;
}
}
}
// w[1..tot]: the distinct values n/l. v1[x] (x <= sqr) and v2[n/x] (x > sqr) map a value
// to its index. g1/g2[j]: after the sieve in main, the sums of p and p^2 over primes <= w[j].
ll g1[N],g2[N],w[N],v1[N],v2[N],tot;
// Min_25 second phase: sum of f(i) over 2 <= i <= x whose least prime factor exceeds
// pri[y]. Here f is multiplicative with f(p^j) = p^j (p^j - 1), as read off the
// `(tmp-1)*tmp` term below.
ll calc(ll x,int y){
if(y&&x<=pri[y])return 0;
ll now=x<=sqr?v1[x]:v2[n/x];
// Prime part: sum of (p^2 - p) over primes in (pri[y], x].
ll res=Add(Del(g2[now],g1[now]),Del(s1[y],s2[y]));
// Composite part: enumerate the least prime pri[i] and its exponent j. The term j != 1
// counts p^j itself when j >= 2.
for(int i=y+1;pri[i]*pri[i]<=x;i++){
ll np=pri[i];
for(int j=1;np<=x;j++,np=np*pri[i]){
ll tmp=np%p;
res=Add(res,(tmp-1)*tmp%p*Add(calc(x/np,i),j!=1)%p);
}
}
return res;
}
// Prints sum_{i=1}^{n} f(i) mod p (f(1) = 1 is added at the end).
int main(){
scanf("%lld",&n);
sqr=sqrt(n);
get(1e6);
ll inv2=fpow(2,p-2),inv6=fpow(6,p-2);
// Initialise g1/g2 with sum_{i=2}^{w} i and sum_{i=2}^{w} i^2 for every distinct w = n/l.
for(ll l=1,r;l<=n;l=r+1){
r=n/(n/l);
ll now=n/l;
w[++tot]=now;
if(now<=sqr)v1[now]=tot;
else v2[n/now]=tot;
now%=p;
g1[tot]=Del((now+1)*now%p*inv2%p,1);
g2[tot]=Del((now+1)*now%p*(now*2+1)%p*inv6%p,1);
}
// Sieve out composites whose least prime factor is pri[i], leaving prime-only sums.
for(int i=1;pri[i]*pri[i]<=n;i++){
for(int j=1;j<=tot&&pri[i]*pri[i]<=w[j];j++){
ll now=w[j]/pri[i];
now=now<=sqr?v1[now]:v2[n/now];
g1[j]=Del(g1[j],pri[i]*Del(g1[now],s1[i-1])%p);
g2[j]=Del(g2[j],pri[i]*pri[i]%p*Del(g2[now],s2[i-1])%p);
}
}
printf("%lld\n",Add(1,calc(n,0)));
return 0;
}

Miller-Rabin,Pollard-rho

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// Miller-Rabin witness bases (the first nine primes). These are commonly cited as
// deterministic for n < 3.8e18; confirm against your input range.
const int pri[]={2,3,5,7,11,13,17,19,23};
// (x * y) mod p using a long-double estimate of the quotient. The unsigned arithmetic
// wraps, and +2p keeps the intermediate non-negative before the final % p.
inline ll Mul(ll x,ll y,ll p){return ((unsigned long long)x*y-(unsigned long long)((long double)x/p*y)*p+2*p)%p;}
// a^b mod p.
inline ll fpow(ll a,ll b,ll p){
    ll res=1;
    for(;b;b>>=1,a=Mul(a,a,p))
        if(b&1)res=Mul(res,a,p);
    return res;
}
// Miller-Rabin primality test.
bool check(ll x){
    // 0 and 1 are not prime. For x = 1, x - 1 = 0 never becomes odd, so the loop that
    // strips factors of two below would not terminate without this guard.
    if(x<2)return 0;
    ll t1=x-1,t2=0; // x - 1 = t1 * 2^t2 with t1 odd
    while(~t1&1)t1>>=1,t2++;
    for(int i=0;i<9;i++){
        if(x==pri[i])return 1;
        if(x<pri[i])return 0;
        ll t=fpow(pri[i],t1,x),lst;
        for(int j=0;j<t2;j++){
            lst=t;t=Mul(t,t,x);
            // A nontrivial square root of 1 proves x composite.
            if(t==1&&lst!=1&&lst!=x-1)return 0;
        }
        if(t!=1)return 0; // Fermat test failed
    }
    return 1;
}
// Prime factors collected by dfs().
set<ll> res;
// Pollard's rho with Brent-style cycle detection and batched gcds (every 127 steps and at
// each power-of-two checkpoint). Returns a divisor of composite x, possibly x itself, in
// which case the caller retries.
inline ll getnxt(ll x){
    if(x==4)return 2;
    ll v1=rand()%x+1,v2=rand()%x+1,v3=0,v4=1,v5=1;
    for(int i=0,j=2;v4==1;i++){
        v1=(Mul(v1,v1,x)+v2)%x;
        v5=Mul(v5,abs(v1-v3),x);
        if(i%127==0&&v5!=1){
            v4=__gcd(x,v5);
            v5=1;
        }
        if(i==j){
            j<<=1;v4=__gcd(x,v5);
            v5=1;v3=v1;
        }
    }
    return v4;
}
// Inserts every prime factor of x into `res`.
void dfs(ll x){
    if(x==1)return ;
    if(check(x)){
        res.insert(x);
        return ;
    }
    ll nxt=x;
    while(nxt==x)nxt=getnxt(nxt);
    while(x%nxt==0)x/=nxt;
    dfs(x);dfs(nxt);
}
// For each input x: prints "Prime" if x is prime, otherwise its largest prime factor.
int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        ll x;
        scanf("%lld",&x);
        if(check(x))puts("Prime");
        else {
            res.clear();
            dfs(x);
            printf("%lld\n",*res.rbegin());
        }
    }
    return 0;
}

类欧几里得

给定 $n,a,b,c$,分别求 $\sum\limits_{i=0}^{n}\left\lfloor \frac{ai+b}{c} \right\rfloor$,$\sum\limits_{i=0}^{n}\left\lfloor \frac{ai+b}{c} \right\rfloor^2$,$\sum\limits_{i=0}^{n}i\left\lfloor \frac{ai+b}{c} \right\rfloor$,答案对 $998244353$ 取模。多组数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=22,P=998244353;
ll t,p,q,r,l;
// Universal-Euclid monoid element summarising a string of U/R steps along the line
// y = (p*x + r)/q. At each R step, x is the number of R's so far (inclusive) and y is the
// number of U's so far.
// cntu = #U, cntr = #R, sumi = sum of x, sums = sum of y, sqrs = sum of y^2,
// prod = sum of x*y (all over R steps, mod P).
// operator+ is concatenation: the right operand's x and y are shifted by cntr and cntu.
struct Node {
Node () {cntu=cntr=sumi=sums=sqrs=prod=0;}
ll cntu,cntr,sumi,sums,sqrs,prod;
Node operator + (Node b) {
Node c;
c.cntu=(cntu+b.cntu)%P,c.cntr=(cntr+b.cntr)%P;
c.sumi=(sumi+b.sumi+cntr*b.cntr)%P;
c.sums=(sums+b.sums+cntu*b.cntr)%P;
c.sqrs=(sqrs+b.sqrs+((cntu*cntu)%P)*b.cntr+(2*cntu*b.sums)%P)%P;
c.prod=((prod+b.prod+((cntu*cntr)%P)*b.cntr)%P+cntu*b.sumi+cntr*b.sums)%P;
return c;
}
}nu,nr,ans;
// a concatenated with itself k times (binary exponentiation on the monoid; valid because
// every factor is a power of a).
Node qpow (Node a,ll k) {
Node res;
while (k) {
if (k&1) {res=res+a;}
a=a+a,k>>=1;
}
return res;
}
// floor((a*b + c) / d) through long double, to avoid overflow of a*b.
ll div (ll a,ll b,ll c,ll d) {return ((long double)1.0*a*b+c)/d;}
// Step string of floor((p*x + r)/q) for x = 1..l with U = a and R = b (0 <= r < q on the
// initial call). It reduces p mod q, then swaps the roles of the axes (Euclid-style recursion).
Node solve (ll p,ll q,ll r,ll l,Node a,Node b) {
if (!l) {return Node();}
if (p>=q) {return solve(p%q,q,r,l,a,qpow(a,p/q)+b);}
ll m=div(l,p,r,q);
if (!m) {return qpow(b,l);}
ll cnt=l-div(q,m,-r-1,p);
return qpow(b,(q-r-1)/p)+a+solve(q,p,(q-r-1)%p,m-1,b,a)+qpow(b,cnt);
}
// Input per test: n a b c (read into l, p, r, q). The i = 0 term floor(b/c) is added
// separately because the step string starts at x = 1.
int main () {
scanf("%lld",&t);
for (int ii=1;ii<=t;ii++) {
scanf("%lld%lld%lld%lld",&l,&p,&r,&q);
nu.cntu=1,nu.cntr=0,nu.sumi=0,nu.sums=0,nu.sqrs=0,nu.prod=0;
nr.cntu=0,nr.cntr=1,nr.sumi=1,nr.sums=0,nr.sqrs=0,nr.prod=0;
ans=qpow(nu,r/q)+solve(p,q,r%q,l,nu,nr);
printf("%lld %lld %lld\n",(ans.sums+r/q)%P,(ans.sqrs+((r/q)%P)*((r/q)%P))%P,ans.prod);
}
return 0;
}

线性基

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
// Linear basis over GF(2) for non-negative 64-bit integers (bits 0..62).
// b[i] is 0 or a vector whose highest set bit is i. Member-initialised to zero, so local
// instances are safe too (the original relied on global zero-initialisation).
struct xor_base{
    static const int BITS=62; // highest bit handled (all bits of a non-negative long long)
    long long b[BITS+1]={};
    // Inserts x (expected to be non-negative). Returns true if x was linearly independent
    // of the basis (the basis grew), false if it reduced to 0. The original always
    // returned true and ignored bits above 50.
    bool insert(long long x){
        for(int i=BITS;i>=0;i--){
            if(!(x>>i&1))continue;
            if(!b[i]){
                b[i]=x;
                return true;
            }
            x^=b[i];
        }
        return false;
    }
    // Maximum XOR over all subsets of the inserted values (greedy from the top bit).
    long long getmax(){
        long long res=0;
        for(int i=BITS;i>=0;i--)
            if((res^b[i])>res)res^=b[i];
        return res;
    }
}b;
// Reads n numbers, inserts each into the global linear basis `b`, and prints the maximum
// subset XOR.
int main(){
    int n;
    scanf("%d",&n);
    for(;n;--n){
        long long v;
        scanf("%lld",&v);
        b.insert(v);
    }
    printf("%lld\n",b.getmax());
    return 0;
}

矩阵求逆

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// Inverse of the 3x3 matrix `a` modulo the prime p, by Gauss-Jordan elimination. Every
// row operation is mirrored on `res`, which starts from res.init() (presumably the
// identity; confirm). Asserts that the matrix is invertible.
// Each pivot's inverse is now computed once per column rather than once per eliminated
// row and once per output entry.
Mat getinv(Mat a){
    Mat res;res.init();
    for(int i=0;i<3;i++){
        // Bring a row with a nonzero entry in column i up to row i.
        for(int j=i;j<3;j++){
            if(a.mt[j][i]){
                swap(a.mt[j],a.mt[i]);
                swap(res.mt[j],res.mt[i]);
                break;
            }
        }
        assert(a.mt[i][i]!=0);
        ll pinv=fpow(a.mt[i][i],p-2);
        // Clear column i in every other row.
        for(int j=0;j<3;j++){
            if(j==i||!a.mt[j][i])continue;
            ll w=a.mt[j][i]*pinv%p;
            for(int k=0;k<3;k++){
                a.mt[j][k]=(a.mt[j][k]-a.mt[i][k]*w%p+p)%p;
                res.mt[j][k]=(res.mt[j][k]-res.mt[i][k]*w%p+p)%p;
            }
        }
    }
    // `a` is now diagonal: scale each row of res by the inverse of its pivot.
    for(int i=0;i<3;i++){
        ll pinv=fpow(a.mt[i][i],p-2);
        for(int j=0;j<3;j++)
            res.mt[i][j]=res.mt[i][j]*pinv%p;
    }
    return res;
}

行列式,矩阵树定理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=310;
const int p=1e9+7;
typedef long long ll;
/* a^b mod p by binary exponentiation. */
ll fpow(ll a,ll b){
    ll res=1;
    while(b){
        if(b&1)res=res*a%p;
        a=a*a%p;
        b>>=1;
    }
    return res;
}
ll a[N][N];
/*
 * Determinant mod p of the minor a[2..n][2..n] (row and column 1 removed, as the
 * matrix-tree theorem requires), by Gaussian elimination. Overwrites a.
 */
ll solve(int n){
    ll sign=1;
    for(int i=2;i<=n;i++){
        if(a[i][i]==0){
            // Find a row below with a nonzero entry in column i; none means det = 0.
            int j=i+1;
            while(j<=n&&!a[j][i])j++;
            if(j>n)return 0;
            swap(a[i],a[j]);
            sign=-sign;
        }
        ll pinv=fpow(a[i][i],p-2);
        for(int j=i+1;j<=n;j++){
            ll r=a[j][i]*pinv%p;
            for(int k=i;k<=n;k++)
                a[j][k]=(a[j][k]+p-a[i][k]*r%p)%p;
        }
    }
    ll res=sign;
    for(int i=2;i<=n;i++)res=res*a[i][i]%p;
    return (res+p)%p;
}
/*
 * Weighted spanning-tree count rooted at vertex 1. t = 1: directed graph, using the
 * in-degree Laplacian (arborescences with edges pointing away from vertex 1).
 * t = 0: undirected graph.
 */
int main(){
    int n,m,t;
    scanf("%d%d%d",&n,&m,&t);
    while(m--){
        int x,y,w;
        scanf("%d%d%d",&x,&y,&w);
        if(t){
            a[y][y]=(a[y][y]+w)%p;
            a[x][y]=(a[x][y]-w+p)%p;
        }else{
            a[x][y]=(a[x][y]+p-w)%p;
            a[y][x]=(a[y][x]+p-w)%p;
            a[x][x]=(a[x][x]+w)%p;
            a[y][y]=(a[y][y]+w)%p;
        }
    }
    printf("%lld\n",solve(n));
    return 0;
}

组合计数

莫比乌斯反演

$n=\sum_{d\mid n}\varphi(d)$

若 $f(x)=\sum_{d\mid x}g(d)$,

那么有

$g(x)=\sum_{d\mid x}\mu(d)f\left(\frac{x}{d}\right)$

二项式反演

若 $F_i=\sum_{j=i}^n\binom{j}{i}G_j$,

那么一定有

$G_i=\sum_{j=i}^n\binom{j}{i}(-1)^{j-i}F_j$

另一种形式:若 $F_i=\sum_{j=0}^i\binom{i}{j}G_j$,则

$G_i=\sum_{j=0}^i\binom{i}{j}(-1)^{i-j}F_j$

Min-Max容斥

很多时候我们能求出若干变量最小值的期望,但要求的是最大值的期望。期望只对加减有线性性,不能直接对期望取 $\max$,因此需要一个只用若干 $\min$ 的加减就能表示 $\max$ 的式子。

具体的,有

$\max(S)=\sum_{T\subseteq S,\,T\neq\emptyset}(-1)^{|T|-1}\min(T)$

由于右边只有加减,两边同时取期望依然成立:$E[\max(S)]=\sum_{T\subseteq S,\,T\neq\emptyset}(-1)^{|T|-1}E[\min(T)]$。

子集反演

若 $F(S)=\sum_{T\subseteq S}G(T)$,则 $G(S)=\sum_{T\subseteq S}(-1)^{|S|-|T|}F(T)$

伯努利数

有一系列数 $B_i$,首先有 $B_0=1$,对于 $k\ge 1$,有

$\sum_{i=0}^{k}\binom{k+1}{i}B_i=0$

这就是伯努利数。看起来没什么用?

由上式可以 $O(n^2)$ 预处理出伯努利数(例如 $B_1=-\frac12$),然后呢?

再次扔出一个式子:

$\sum_{i=0}^{n-1} i^k=\frac{1}{k+1}\sum_{i=0}^{k}\binom{k+1}{i}B_i\,n^{k+1-i}$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
typedef long long ll;
const int maxn = 10000;
const int mod = 1e9 + 7;
// NOTE(review): C below takes maxn*maxn*8 = 800,000,000 bytes (~763 MiB) of static
// storage, and init() writes roughly half of it. Shrink maxn, or derive binomials from
// factorials, for typical memory limits.
ll B[maxn]; // Bernoulli numbers mod `mod` (B[1] = -1/2 convention); B[maxn-1] stays 0
ll C[maxn][maxn]; // binomial coefficients C[i][k] mod `mod`
ll inv[maxn]; // modular inverses 1/i mod `mod` (used by the Bernoulli recurrence)

// Precomputes C, inv and B.
void init() {
// Pascal's triangle.
for (int i = 0; i < maxn; i++) {
C[i][0] = C[i][i] = 1;
for (int k = 1; k < i; k++) {
C[i][k] = (C[i - 1][k] % mod + C[i - 1][k - 1] % mod) % mod;
}
}
// Linear-time inverses: inv[i] = -(mod / i) * inv[mod % i].
inv[1] = 1;
for (int i = 2; i < maxn; i++) {
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
// Bernoulli numbers from sum_{k=0}^{i} C(i+1, k) B_k = 0, i.e.
// B_i = -(1/(i+1)) * sum_{k<i} C(i+1, k) B_k.
// The loop stops before i = maxn-1 because row i+1 of C must exist.
B[0] = 1;
for (int i = 1; i < maxn; i++) {
ll ans = 0;
if (i == maxn - 1) break;
for (int k = 0; k < i; k++) {
ans += C[i + 1][k] * B[k];
ans %= mod;
}
ans = (ans * (-inv[i + 1]) % mod + mod) % mod;
B[i] = ans;
}
}

斯特林数

${n \brack m}={n-1 \brack m-1}+(n-1){n-1 \brack m}$

$n!=\sum_{i=0}^n{n \brack i}$

$x^{\overline n}=\sum_{i=0}^n{n \brack i}x^i$

${n \brace m}={n-1 \brace m-1}+m{n-1 \brace m}$

$m^n=\sum_{i=0}^m\binom{m}{i}{n \brace i}i!$

例:用 $k^i=\sum_{j=0}^i\binom{k}{j}j!{i \brace j}$ 展开并交换求和顺序:

$$\begin{aligned}\sum_{k=0}^n\sum_{i=0}^m a_ik^ix^k\binom{n}{k}&=\sum_{k=0}^n\sum_{i=0}^m a_ix^k\binom{n}{k}\sum_{j=0}^i\binom{k}{j}j!{i \brace j}\\&=\sum_{i=0}^m a_i\sum_{j=0}^i{i \brace j}j!\sum_{k=0}^n x^k\binom{n}{k}\binom{k}{j}\end{aligned}$$

再利用 $\binom{n}{k}\binom{k}{j}=\binom{n}{j}\binom{n-j}{k-j}$ 和二项式定理:

$$\begin{aligned}\sum_{i=0}^m a_i\sum_{j=0}^i{i \brace j}j!\sum_{k=0}^n x^k\binom{n}{k}\binom{k}{j}&=\sum_{i=0}^m a_i\sum_{j=0}^i{i \brace j}j!\sum_{k=0}^n x^k\binom{n}{j}\binom{n-j}{k-j}\\&=\sum_{i=0}^m a_i\sum_{j=0}^i{i \brace j}j!\binom{n}{j}\sum_{k=j}^n x^k\binom{n-j}{k-j}\\&=\sum_{i=0}^m a_i\sum_{j=0}^i{i \brace j}\frac{n!}{(n-j)!}x^j\sum_{k=0}^{n-j}x^k\binom{n-j}{k}\\&=\sum_{i=0}^m a_i\sum_{j=0}^i{i \brace j}\frac{n!}{(n-j)!}x^j(x+1)^{n-j}\end{aligned}$$

第一类(行)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1<<19;
const int p=167772161;
// Short aliases kept; `register` was dropped because C++17 removed it.
#define rint int
#define ll long long
#define rll long long
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline ll Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// r: bit-reversal permutation for transform length pd; w[mid+j]: root table filled in main;
// t: butterfly temporary.
int r[N],w[N],pd,t;
// a^b mod p.
ll fpow(rll a,rll b){
    rll res=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)res=res*a%p;
    return res;
}
// In-place NTT of length n (power of two); `a` is resized to n. typ == -1 gives the
// inverse transform (reverse + scale by 1/n).
void ntt(rint n,vector<int> &a,rint typ){
    a.resize(n);
    if(pd!=n){
        pd=n;
        for(rint i=0;i<n;i++)
            r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
    }
    for(rint i=0;i<n;i++)
        if(i<r[i])swap(a[i],a[r[i]]);
    for(rint mid=1;mid<n;mid<<=1)
        for(rint i=0,R=mid<<1;i<n;i+=R)
            for(rint j=0;j<mid;j++)
                t=Mul(a[i+j+mid],w[mid+j]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
    if(~typ)return ;
    rint finv=fpow(n,p-2);
    reverse(a.begin()+1,a.end());
    for(rint i=0;i<n;i++)
        a[i]=Mul(a[i],finv);
}
// fac: factorials; inv: inverse factorials (after main's second loop).
ll fac[N],inv[N];
vector<int> vec[N<<2];
// vec[x] <- coefficients of (X+l)(X+l+1)...(X+r). Assumes l == 0, the only way main calls it.
// Doubling: the left half Π_{i=0}^{mid}(X+i) is Taylor-shifted by mid+1 to give the right
// half, via one correlation done as ntt(-1) of f_i*i!, a pointwise product with
// ntt(c^i/i!), and a forward ntt. For even r, one extra factor (X + r) is applied.
void dfs(rint x,rint l,rint r){
    if(l==r){
        vec[x].resize(2);
        vec[x][0]=l;
        vec[x][1]=1;
        return ;
    }
    rint mid=l+r>>1;mid-=r%2==0;
    dfs(x<<1,l,mid);
    rint Max=1;while(Max<r-l+2)Max<<=1;
    vector<int> f;f.resize(Max);vec[x<<1|1].resize(Max);
    for(rint i=0,pw=1;i<=mid+1;i++){
        f[i]=Mul(fac[i],vec[x<<1][i]);
        vec[x<<1|1][i]=Mul(pw,inv[i]);
        pw=Mul(pw,mid+1);
    }
    ntt(Max,f,-1);ntt(Max,vec[x<<1|1],1);
    for(rint i=0;i<Max;i++)
        vec[x<<1|1][i]=Mul(vec[x<<1|1][i],f[i]);
    ntt(Max,vec[x<<1|1],1);
    for(rint i=0;i<Max;i++)
        vec[x<<1|1][i]=Mul(vec[x<<1|1][i],inv[i]);
    vec[x<<1|1].resize(mid+2);
    ntt(Max,vec[x<<1],1);ntt(Max,vec[x<<1|1],1);
    vec[x].resize(Max);
    for(rint i=0;i<Max;i++)
        vec[x][i]=Mul(vec[x<<1][i],vec[x<<1|1][i]);
    ntt(Max,vec[x],-1);
    if(r%2==0){
        // Multiply by (X + r). vec[x] has exactly Max entries, so start at Max-1: the
        // original started at i = Max and wrote one element past the end. The resulting
        // degree r+1 <= Max-1 still fits.
        for(rint i=Max-1;i;i--){
            vec[x][i]=Add(Mul(vec[x][i],r),vec[x][i-1]);
        }
        vec[x][0]=Mul(vec[x][0],r);
    }
}
// Prints the unsigned Stirling numbers of the first kind [n, 0..n], i.e. the coefficients
// of X^(rising n) = X(X+1)...(X+n-1).
int main(){
    w[N/2]=1;w[N/2+1]=t=fpow(3,(p-1)/N);
    for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],t);
    for(rint i=N/2-1;i;i--)w[i]=w[i<<1];
    rint n;
    scanf("%d",&n);
    fac[0]=fac[1]=inv[0]=inv[1]=1;
    for(rint i=2;i<=n;i++){
        fac[i]=fac[i-1]*i%p;
        inv[i]=p-p/i*inv[p%i]%p;
    }
    for(rint i=2;i<=n;i++)
        inv[i]=inv[i-1]*inv[i]%p;
    // n == 0 is the empty product 1. n == 1 goes through dfs (leaf X). The original
    // special-cased n == 1 with a size-1 vector and then printed index 1 out of bounds.
    if(n==0)vec[1].assign(1,1);
    else dfs(1,0,n-1);
    for(rint i=0;i<=n;i++)
        printf("%d ",vec[1][i]);
    return 0;
}

第二类(行)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e6+10;
const int p=167772161;
typedef long long ll;
// a^b mod p.
ll fpow(ll a,ll b){
    ll res;
    for(res=1;b;b>>=1,a=a*a%p)
        if(b&1)res=res*a%p;
    return res;
}
// A, B: convolution operands (the product ends up in A). fac: factorials, replaced in
// place by inverse factorials.
ll A[N],B[N],fac[N];
#define Add(x,y) (x+y>=p?x+y-p:x+y)
#define Del(x,y) (x-y>=0?x-y:x-y+p)
// Bit-reversal permutation for the current transform length.
int r[N];
// Iterative NTT with roots derived from base 3. typ == -1 uses inverse roots; the caller
// then scales by 1/n.
void ntt(int n,ll *a,int typ){
    for(int i=0;i<n;i++)
        if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        ll Wn=fpow(3,(p-1)/(mid*2));
        if(typ==-1)Wn=fpow(Wn,p-2);
        for(int i=0,R=mid<<1;i<n;i+=R){
            ll w=1;
            for(int j=0;j<mid;j++,w=w*Wn%p){
                const ll x=a[i+j],y=w*a[i+j+mid]%p;
                a[i+j]=Add(x,y);
                a[i+j+mid]=Del(x,y);
            }
        }
    }
}
// Prints S(n, 0..n), Stirling numbers of the second kind, via
// S(n,k) = sum_{i=0}^{k} (i^n / i!) * ((-1)^(k-i) / (k-i)!), one convolution of A and B.
int main(){
    int n;
    scanf("%d",&n);
    fac[0]=1;
    for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%p;
    fac[n]=fpow(fac[n],p-2);
    for(int i=n;i;i--)fac[i-1]=fac[i]*i%p;
    for(int i=0;i<=n;i++){
        A[i]=fpow(i,n)*fac[i]%p;
        B[i]=(i&1)?p-fac[i]:fac[i];
    }
    int Max=1,l=0;
    while(Max<=n+n)Max<<=1,l++;
    // l == 0 only when Max == 1 (n == 0). The original `(i&1)<<l-1` then shifted by -1,
    // which is undefined behaviour.
    for(int i=0;i<Max;i++)
        r[i]=(r[i>>1]>>1)|(l?((i&1)<<(l-1)):0);
    ntt(Max,A,1);ntt(Max,B,1);
    for(int i=0;i<Max;i++)
        A[i]=A[i]*B[i]%p;
    ntt(Max,A,-1);
    ll inv=fpow(Max,p-2);
    for(int i=0;i<Max;i++)
        A[i]=A[i]*inv%p;
    for(int i=0;i<=n;i++)
        printf("%lld ",A[i]);
    return 0;
}

第二类(列)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1<<18;
const int p=167772161;
// Short aliases kept; `register` was dropped because C++17 removed it.
#define rint int
#define ll long long
#define rll long long
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline ll Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// r: bit-reversal permutation for transform length pd; w[mid+j]: root table filled in main;
// t: butterfly temporary.
int r[N],w[N],pd,t;
// a^b mod p.
ll fpow(rll a,rll b){
    rll res=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)res=res*a%p;
    return res;
}
// In-place NTT of length n (power of two); `a` is resized to n. typ == -1 gives the inverse.
void ntt(rint n,vector<int> &a,rint typ){
    a.resize(n);
    if(pd!=n){
        pd=n;
        for(rint i=0;i<n;i++)
            r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
    }
    for(rint i=0;i<n;i++)
        if(i<r[i])swap(a[i],a[r[i]]);
    for(rint mid=1;mid<n;mid<<=1)
        for(rint i=0;i<n;i+=mid<<1)
            for(rint j=0;j<mid;j++)
                t=Mul(w[mid+j],a[i+j+mid]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
    if(~typ)return ;
    rint inv=fpow(n,p-2);
    reverse(a.begin()+1,a.end());
    for(rint i=0;i<n;i++)
        a[i]=Mul(a[i],inv);
}
vector<int> vec[N<<2];
// vec[x] <- prod_{j=l}^{r} (1 - jX) by divide and conquer. Requires l <= r.
void dfs(rint x,rint l,rint r){
    if(l==r){
        vec[x].resize(2);
        vec[x][0]=1;
        vec[x][1]=p-l;
        return ;
    }
    rint mid=l+r>>1;
    dfs(x<<1,l,mid);
    dfs(x<<1|1,mid+1,r);
    rint Max=1;
    while(Max<r-l+2)Max<<=1;
    ntt(Max,vec[x<<1],1);ntt(Max,vec[x<<1|1],1);
    vec[x].resize(Max);
    for(rint i=0;i<Max;i++)
        vec[x][i]=Mul(vec[x<<1][i],vec[x<<1|1][i]);
    ntt(Max,vec[x],-1);
}
// f <- g^{-1} mod X^n by Newton iteration (g[0] must be invertible). g is taken by value,
// so f may alias the caller's g.
void dfs_inv(vector<int> &f,vector<int> g,rint n){
    g.resize(n);
    if(n==1){
        f.resize(1);
        f[0]=fpow(g[0],p-2);
        return ;
    }
    dfs_inv(f,g,n+1>>1);
    rint Max=1;
    while(Max<n+n)Max<<=1;
    ntt(Max,f,1);ntt(Max,g,1);
    for(rint i=0;i<Max;i++)
        f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));
    ntt(Max,f,-1);f.resize(n);
}
// Prints S(i, k) for i = 0..n (Stirling numbers of the second kind, fixed k), using
// sum_i S(i,k) X^i = X^k / prod_{j=1}^{k} (1 - jX).
int main(){
    w[N/2]=1;w[N/2+1]=t=fpow(3,(p-1)/N);
    for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],t);
    for(rint i=N/2-1;i;i--)w[i]=w[i<<1];
    rint n,k;
    scanf("%d%d",&n,&k);
    // k == 0: S(i,0) = [i == 0], and dfs(1,1,0) would recurse forever.
    // k > n: every requested S(i,k) is 0, so skip the large transforms.
    if(k==0||k>n){
        for(rint i=0;i<=n;i++)
            printf("%d ",i==k?1:0);
        return 0;
    }
    dfs(1,1,k);dfs_inv(vec[1],vec[1],n+1);
    // Multiplying by X^k shifts the series. Printing directly avoids the original
    // fixed-size ans[] buffer, which held k + n + 1 entries.
    for(rint i=0;i<=n;i++)
        printf("%d ",i<k?0:vec[1][i-k]);
    return 0;
}

欧拉数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=1<<19;
const int p=998244353;
int w[N],r[N],pd;
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline ll Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
inline ll fpow(rll a,rll b){
rll res=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)res=res*a%p;
return res;
}
unsigned long long A[N],t;
// In-place NTT of length n on a (n a power of two, n <= N); a is resized to n first.
// typ == 1: forward transform; typ == -1: inverse transform, including the 1/n factor.
// The butterflies reduce lazily in the 64-bit buffer A. After s layers every entry is
// below (s+1)*p, so at layer s the product w*A is below (s+1)*p^2. That exceeds 2^64
// at s = 18 (n = 2^19 = N). A single full reduction when mid reaches 2^16 resets the
// bound, which keeps every product below 17*p^2 < 2^64.
void ntt(rint n,vector<int> &a,rint typ){
    a.resize(n);
    // Rebuild the bit-reversal permutation only when the length changes.
    if(pd!=n){
        for(rint i=0;i<n;i++)
            r[i]=(r[i>>1]>>1|((i&1)?n>>1:0));
        pd=n;
    }
    for(rint i=0;i<n;i++)
        A[i]=a[r[i]];
    for(rint mid=1;mid<n;mid<<=1){
        if(mid==(1<<16))
            for(rint i=0;i<n;i++)
                A[i]%=p;
        // w[mid+j] is the twiddle for butterfly j at this level (table built in main).
        for(rint i=0;i<n;i+=mid<<1)
            for(rint j=0;j<mid;j++)
                t=w[mid+j]*A[i+j+mid]%p,A[i+j+mid]=A[i+j]-t+p,A[i+j]+=t;
    }
    for(rint i=0;i<n;i++)
        a[i]=A[i]%p;
    if(~typ)return ;
    // Inverse transform: reverse a[1..n-1], then scale by 1/n.
    reverse(a.begin()+1,a.end());
    rint inv=fpow(n,p-2);
    for(rint i=0;i<n;i++)
        a[i]=Mul(a[i],inv);
}
// fac[i] = i! and inv[i] = (i!)^{-1} modulo p (filled in main).
ll fac[N],inv[N];
// Binomial coefficient C(n, m) mod p, for 0 <= m <= n.
ll C(rint n,rint m){
    ll top=fac[n];
    ll res=top*inv[m]%p;
    return res*inv[n-m]%p;
}
// Reads n and prints the Eulerian numbers A(n,k) for k = 0..n using
// A(n,k) = sum_{i=0}^{k} (-1)^i C(n+1,i) (k+1-i)^n, evaluated as a single convolution.
int main(){
// NTT twiddle table: w[N/2+j] = t^j; w[i] = w[2i] below N/2.
w[N/2]=1;w[N/2+1]=t=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],t);
for(rint i=N/2-1;i;i--)w[i]=w[i<<1];
rint n;
cin >> n;
// fac[i] = i!, inv[i] = 1/i! for 0 <= i <= n+1.
fac[0]=1;
for(rint i=1;i<=n+1;i++)
fac[i]=fac[i-1]*i%p;
inv[n+1]=fpow(fac[n+1],p-2);
for(rint i=n+1;i;i--)
inv[i-1]=inv[i]*i%p;
// f[i] = (i+1)^n and g[i] = (-1)^i C(n+1,i); (f*g)[k] is the sum above.
vector<int> f,g;
f.resize(n+1);g.resize(n+1);
for(rint i=0;i<=n;i++)
f[i]=fpow(i+1,n);
for(rint i=0;i<=n;i++)
g[i]=(i&1)?p-C(n+1,i):C(n+1,i);
// Transform length > 2n, so multiplying two degree-n polynomials does not wrap around.
rint Max=1;
while(Max<=n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);
for(rint i=0;i<=n;i++)
printf("%d ",f[i]);
return 0;
}

如何优雅的求和

计算 $\sum_{k=0}^{n}f(k)\binom{n}{k}x^k(1-x)^{n-k}$(下面的代码中 $f(k)=\sum_{i=1}^{k}i^m$,$x=A\cdot B^{-1}\bmod 998244353$)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = (l); i <= (r); i++)
#define per(i, r, l) for (int i = (r); i >= (l); i--)
#define mem(a, b) memset(a, b, sizeof a)
#define For(i, l, r) for (int i = (l), i##e = (r); i < i##e; i++)
#define pb push_back

using namespace std;
using ll = long long;

const int N = 2e6 + 5, P = 998244353;

int n, m, x, f[N], a[N], fac[N], ifac[N];

// Returns r * a^n mod P by binary exponentiation (r defaults to 1).
ll Pow(ll a, int n, ll r = 1)
{
    while (n)
    {
        if (n & 1)
            r = r * a % P;
        a = a * a % P;
        n >>= 1;
    }
    return r;
}
// Multi-test driver. Each test reads n, m, A, B, sets x = A/B mod P, and evaluates
// sum_{k=0}^{n} f(k) C(n,k) x^k (1-x)^{n-k} with f(k) = sum_{i=1}^{k} i^m (mod P).
int main()
{
    cin.tie(0)->sync_with_stdio(0);
    int T;
    cin >> T;
    while (T--)
    {
        int A, B;
        cin >> n >> m >> A >> B;
        f[0] = 0;
        x = A * Pow(B, P - 2) % P;
        // f[k] = 1^m + 2^m + ... + k^m (mod P).
        for (int i = 1; i <= m + 2; i++)
        {
            f[i] = (f[i - 1] + Pow(i, m)) % P;
        }
        ++m;            // f is a polynomial of degree m+1
        m = min(n, m);
        // Factorials, then ifac[i] := (i-1)!/i! = 1/i for 1 <= i <= m.
        rep(i, 0, m) fac[i] = i ? (ll)fac[i - 1] * i % P : 1;
        per(i, m + 1, 1) ifac[i - 1] = i > m ? Pow(fac[m], P - 2) : (ll)ifac[i] * i % P;
        rep(i, 1, m) ifac[i] = (ll)ifac[i] * fac[i - 1] % P;
        // Only a[m+1] is read before being written (by the recurrence below). Clearing that
        // one slot replaces the former per-test memset of the whole 2e6-entry array.
        a[m + 1] = 0;
        int p = 1, ans = 0;
        if (n == m)
            per(i, m, 0) a[i] = p, p = p * ll(P + 1 - x) % P;
        else
            per(i, m, 0)
            {
                a[i] = (a[i + 1] * ll(P + 1 - x) + p) % P;
                p = p * ll(n - i) % P * ifac[m - i + 1] % P * (P - x) % P;
            }
        p = 1;
        rep(i, 0, m)
        {
            ans = (ans + (ll)f[i] * p % P * a[i]) % P;
            p = p * ll(n - i) % P * ifac[i + 1] % P * x % P;
        }
        cout << ans << '\n';
    }
    return 0;
}

数据结构

lct维护子树信息

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=2e5+10;
// Link-cut tree that also tracks subtree sizes through virtual (non-preferred) children.
// siz[x]: nodes represented by x's splay subtree, virtual subtrees included;
// siz2[x]: total size of the virtual subtrees attached to x; tag[x]: pending reversal flag.
struct Lct{
int ch[N][2],fa[N],siz[N],siz2[N],tag[N];
// Recompute siz[x] from both splay children, the virtual subtrees, and x itself.
void up(int x){siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1+siz2[x];}
// True when x is not the root of its splay tree (its parent lists it as a real child).
bool nrt(int x){
return ch[fa[x]][0]==x||ch[fa[x]][1]==x;
}
// Rotate x above its parent y. The grandparent z is re-pointed only if y was not a splay root;
// otherwise fa[x]=z simply keeps the path-parent pointer.
void rotate(int x){
int y=fa[x];int z=fa[y];
int k=ch[y][1]==x;
if(nrt(y))ch[z][ch[z][1]==y]=x;
fa[x]=z;ch[y][k]=ch[x][!k];
fa[ch[x][!k]]=y;ch[x][!k]=y;
fa[y]=x;up(y);up(x);
}
// Apply a pending reversal: swap x's children and pass the flag down to them.
void down(int x){
if(tag[x]){
swap(ch[x][0],ch[x][1]);
tag[ch[x][0]]^=1;
tag[ch[x][1]]^=1;
tag[x]=0;
}
}
// Push reversal flags down from the splay root to x.
void update(int x){
if(nrt(x))update(fa[x]);
down(x);
}
// Splay x to the root of its splay tree. Zig-zig (same side) rotates the parent first.
void splay(int x){
update(x);
for(int y=fa[x],z=fa[y];nrt(x);rotate(x),y=fa[x],z=fa[y])
if(nrt(y))(ch[y][1]==x)==(ch[z][1]==y)?rotate(y):rotate(x);
}
// Make the root-to-x path preferred. When x's preferred child changes from ch[x][1] to t,
// the old child's size moves into siz2[x] and t's size leaves it.
void access(int x){
for(int t=0;x;t=x,x=fa[x]){
splay(x);
siz2[x]+=siz[ch[x][1]]-siz[t];
ch[x][1]=t;
up(x);
}
}
// Re-root the represented tree at x by reversing the root-to-x path.
void mkrt(int x){
access(x);
splay(x);
tag[x]^=1;
}
// Add edge x-y. After mkrt(y), y is the root of its tree and of its splay tree, so attaching x
// as a virtual child only requires updating siz2[y] and siz[y].
void link(int x,int y){
mkrt(x);mkrt(y);
fa[x]=y;
siz2[y]+=siz[x];
up(y);
}
// Expose the x-y path: x becomes the tree root and y the root of the path's splay tree.
void split(int x,int y){
mkrt(x);
access(y);
splay(y);
}
}T;
int main(){
int n,q;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)T.siz[i]=1;
while(q--){
char op;
int x,y;
scanf(" %c%d%d",&op,&x,&y);
if(op=='A')T.link(x,y);
else {
T.split(x,y);
printf("%lld\n",1ll*(T.siz2[x]+1)*(T.siz2[y]+1));
}
}
return 0;
}

莫队

带奇偶性优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=5e4+10;
// col[i]: block of position i; cnt[v]: occurrences of v in the window;
// up/down: numerator and denominator of each query's answer.
int col[N],a[N],cnt[N];
ll up[N],down[N];
// Mo's query ordered by the block of l. Within a block, r ascends on odd blocks and
// descends on even ones (odd-even optimisation).
struct Node{
    int l,r,id;
    bool operator < (const Node&A)const{
        if(col[l]!=col[A.l])return col[l]<col[A.l];
        if(col[l]&1)return r<A.r;
        return r>A.r;
    }
}q[N];
// For each query [l,r], prints the probability that two distinct positions drawn from the
// range hold equal values, as a reduced fraction (0/1 when no equal pair exists).
int main(){
rint n,m;
scanf("%d%d",&n,&m);
// Block size sqrt(n); col[i] is the 1-based block index of position i.
rint sqr=sqrt(n);
for(rint i=1;i<=n;i++){
scanf("%d",&a[i]);
col[i]=(i-1)/sqr+1;
}
for(rint i=1;i<=m;i++){
scanf("%d%d",&q[i].l,&q[i].r);
q[i].id=i;
}
sort(q+1,q+m+1);
// res = sum over values of cnt*(cnt-1) (ordered equal pairs): adding an element whose value
// already occurs cnt times adds 2*cnt; removing one leaves cnt-1 copies and subtracts 2*(cnt-1).
rll res=0;
for(rint i=1,cl=1,cr=0;i<=m;i++){
// Expand first, then shrink; otherwise elements could be removed before they were added
// and the counts would go negative.
while(cr<q[i].r)res+=2*cnt[a[++cr]]++;
while(cl>q[i].l)res+=2*cnt[a[--cl]]++;
while(cr>q[i].r)res-=2*--cnt[a[cr--]];
while(cl<q[i].l)res-=2*--cnt[a[cl++]];
if(res){
up[q[i].id]=res;
down[q[i].id]=1ll*(cr-cl+1)*(cr-cl);
}else{
up[q[i].id]=0;
down[q[i].id]=1;
}
}
for(rint i=1;i<=m;i++){
rll g=__gcd(up[i],down[i]);
printf("%lld/%lld\n",up[i]/g,down[i]/g);
}
return 0;
}

带修莫队

本质上是三维莫队

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=1e6+10;
// Block size used by the query ordering (set in main to n^(2/3)).
int sqr;
// Query of Mo's algorithm with modifications: window [l,r] observed after the first t
// modifications. Sorted by block of l, then block of r, then the raw time t.
struct Node{
    int l,r,t,id;
    bool operator < (const Node&A)const{
        if(l/sqr!=A.l/sqr)return l/sqr<A.l/sqr;
        if(r/sqr!=A.r/sqr)return r/sqr<A.r/sqr;
        // Compare raw time rather than time blocks. Within an (l-block, r-block) group the
        // time pointer then moves monotonically instead of bouncing by up to sqr per query.
        return t<A.t;
    }
}q[N];
// Modification record: assign colour x to position idx (applied by swapping, so the same
// call undoes it the second time).
struct xg{
    int idx,x;
}r[N];
int cnt1,cnt2,col[N],cl,cr,ct,cnt[N],res,ans[N];
// Toggles modification x. If the position lies inside the current window [cl,cr], the
// distinct-colour count res is updated for removing the old colour and adding the new one.
void Add(rint x){
    int pos=r[x].idx;
    if(pos>=cl&&pos<=cr){
        if(--cnt[col[pos]]==0)--res;
        if(cnt[r[x].x]++==0)++res;
    }
    swap(r[x].x,col[pos]);
}
// Mo's algorithm with modifications: counts distinct colours in [l,r] while point
// assignments are interleaved with the queries.
int main(){
rint n,m;
scanf("%d%d",&n,&m);
// Block size n^(2/3), the usual choice for the three-dimensional variant.
sqr=pow(n,2.0/3);
for(rint i=1;i<=n;i++)
scanf("%d",&col[i]);
for(rint i=1;i<=m;i++){
char s[10];
scanf("%s",s);
// 'Q l r' is a query tagged with the number of modifications read so far;
// any other command records a modification (position, new colour).
if(s[0]=='Q'){
++cnt1;
scanf("%d%d",&q[cnt1].l,&q[cnt1].r);
q[cnt1].t=cnt2;
q[cnt1].id=cnt1;
}else {
++cnt2;
scanf("%d%d",&r[cnt2].idx,&r[cnt2].x);
}
}
sort(q+1,q+cnt1+1);cr=ct=0;cl=1;
for(rint i=1;i<=cnt1;i++){
// Move the window first (expanding before shrinking)...
while(cr<q[i].r)res+=!cnt[col[++cr]]++;
while(cl>q[i].l)res+=!cnt[col[--cl]]++;
while(cr>q[i].r)res-=!--cnt[col[cr--]];
while(cl<q[i].l)res-=!--cnt[col[cl++]];
// ...then apply/undo modifications until exactly q[i].t are in effect.
while(ct<q[i].t)Add(++ct);
while(ct>q[i].t)Add(ct--);
ans[q[i].id]=res;
}
for(rint i=1;i<=cnt1;i++)
printf("%d\n",ans[i]);
return 0;
}

回滚莫队

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<bits/stdc++.h>
using namespace std;
#define rint register int
const int N=2e5+10;
// Value compression map (value -> 1..cnt).
unordered_map<int,int> mp;
// L[c], R[c]: first and last position of block c; col[i]: block of position i.
int sqr,cnt,L[N],R[N],col[N];
// Queries sorted by block of l, then by r ascending (the rollback sweep relies on this).
struct Node{
    int l,r,id;
    bool operator < (const Node&A)const{
        if(col[l]!=col[A.l])return col[l]<col[A.l];
        return r<A.r;
    }
}q[N];
int a[N],pos1[N],pos2[N],ans[N];
// Rollback Mo's algorithm: for each query [l,r], the maximum distance i-j between two
// positions holding equal values (0 if none).
int main(){
rint n;
scanf("%d",&n);
sqr=sqrt(n);
// Compress values to 1..cnt and record each block's [L, R].
for(rint i=1;i<=n;i++){
scanf("%d",&a[i]);
if(!mp[a[i]])mp[a[i]]=++cnt;
a[i]=mp[a[i]];
col[i]=(i-1)/sqr+1;
if(!L[col[i]])L[col[i]]=i;
R[col[i]]=i;
}
rint m;
scanf("%d",&m);
for(rint i=1;i<=m;i++){
scanf("%d%d",&q[i].l,&q[i].r);
q[i].id=i;
}
sort(q+1,q+m+1);
// Handle queries grouped by the block c of their left endpoint.
for(rint c=1,k=1;c<=col[n];c++){
// The right part (R[c], r] only grows; l starts at the rollback boundary R[c]+1.
rint l=R[c]+1,r=R[c],nowr=R[c],now=0;
// q[m+1] is zero-initialised and col[0]==0 never equals c, so this stops after the last query.
for(;col[q[k].l]==c;k++){
if(col[q[k].r]==c){
// Both endpoints inside block c: brute force with first occurrences, then clear them.
// These queries sort before any query of this block that reaches past R[c], so pos1 is still clean.
for(rint i=q[k].l;i<=q[k].r;i++){
if(!pos1[a[i]])pos1[a[i]]=i;
now=max(now,i-pos1[a[i]]);
}
for(rint i=q[k].l;i<=q[k].r;i++)
pos1[a[i]]=0;
ans[q[k].id]=now;
now=0;
}else {
// Extend the right part: pos1 = first and pos2 = last occurrence in (R[c], r].
while(r<q[k].r){
++r;
if(!pos1[a[r]])pos1[a[r]]=r;
pos2[a[r]]=r;
now=max(now,r-pos1[a[r]]);
}
// Temporarily extend into block c. Values absent from the right part get pos2 = their
// right-most occurrence on the left side.
rint tmp=now;
while(l>q[k].l){
--l;
if(!pos2[a[l]])pos2[a[l]]=l;
else now=max(now,pos2[a[l]]-l);
}
ans[q[k].id]=now;
// Roll back: restore the answer and clear pos2 entries created on the left side.
now=tmp;
while(l<=nowr){
if(pos2[a[l]]==l)pos2[a[l]]=0;
l++;
}
}
}
// Reset occurrence tables before the next block.
memset(pos1,0,sizeof(pos1));
memset(pos2,0,sizeof(pos2));
}
for(rint i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}

树上莫队

下面的树上莫队支持单点修改(即树上带修莫队)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=3e5+10;
// Forward-star adjacency list: h[u] is the latest edge out of u, nxt chains to the previous one.
struct Edge{
    int to,nxt;
}e[N<<1];
int h[N],idx;
// Adds directed edge a -> b at the head of a's list.
void Ins(int a,int b){
    idx++;
    e[idx].to=b;
    e[idx].nxt=h[a];
    h[a]=idx;
}
int fir[N],sec[N],rk[N],fa[N],tp[N],d[N],son[N],siz[N],v[N],w[N],c[N],Time;
// Euler-tour DFS: fir/sec are entry/exit timestamps, rk maps a timestamp back to its node.
// Also computes parent, depth, subtree size and heavy child.
void dfs1(rint u){
    fir[u]=++Time;
    rk[Time]=u;
    siz[u]=1;
    for(rint i=h[u];i;i=e[i].nxt){
        rint to=e[i].to;
        if(to==fa[u])continue;
        fa[to]=u;
        d[to]=d[u]+1;
        dfs1(to);
        siz[u]+=siz[to];
        if(siz[to]>siz[son[u]])son[u]=to;
    }
    sec[u]=++Time;
    rk[Time]=u;
}
// Heavy-light decomposition: tp[u] is the top node of u's heavy chain.
void dfs2(rint u,rint t){
    tp[u]=t;
    if(son[u])dfs2(son[u],t);
    for(rint i=h[u];i;i=e[i].nxt){
        rint to=e[i].to;
        if(to!=fa[u]&&to!=son[u])dfs2(to,to);
    }
}
// Lowest common ancestor via heavy-light chains: keep lifting the endpoint whose chain top
// is deeper until both share a chain, then return the shallower node.
int lca(rint x,rint y){
    for(;tp[x]!=tp[y];){
        if(d[tp[x]]>=d[tp[y]])x=fa[tp[x]];
        else y=fa[tp[y]];
    }
    return d[x]<=d[y]?x:y;
}
int sqr,bel[N];
// Query on the Euler sequence [l,r] at time h (modifications applied); lc is the extra lca
// node to include, or -1. Sorted by block of l, block of r, then time.
struct Ask{
    int l,r,h,id,lc;
    bool operator < (const Ask&A)const{
        if(bel[l]!=bel[A.l])return l<A.l;
        if(bel[r]!=bel[A.r])return r<A.r;
        return h<A.h;
    }
}q[N];
int c1,c2,v1[N],v2[N],cnt[N];
long long ans[N],res;
bool vis[N];
// Toggles node x in/out of the current set. The j-th occurrence of colour col
// contributes w[j]*v[col]; removal subtracts that same term.
void upd(rint x){
    int col=c[x];
    if(!vis[x]){
        cnt[col]++;
        res+=1ll*w[cnt[col]]*v[col];
    }else{
        res-=1ll*w[cnt[col]]*v[col];
        cnt[col]--;
    }
    vis[x]^=1;
}
// Applies/undoes modification x by swapping node v1[x]'s colour with v2[x]. A node that is
// currently in the set is removed before the swap and re-added after it.
void xg(rint x){
    rint node=v1[x];
    bool inside=vis[node];
    if(inside)upd(node);
    swap(c[node],v2[x]);
    if(inside)upd(node);
}
// Tree Mo's algorithm with modifications over the Euler (entry/exit) sequence.
int main(){
rint n,m,ask;
scanf("%d%d%d",&n,&m,&ask);
// v[1..m]: value of each colour; w[j]: multiplier of a colour's j-th occurrence (see upd).
for(rint i=1;i<=m;i++)scanf("%d",&v[i]);
for(rint i=1;i<=n;i++)scanf("%d",&w[i]);
for(rint i=2;i<=n;i++){
rint a,b;
scanf("%d%d",&a,&b);
Ins(a,b);Ins(b,a);
}
for(rint i=1;i<=n;i++)scanf("%d",&c[i]);
dfs1(1);
dfs2(1,1);
// Block size Time^(2/3) over the Euler sequence of length Time = 2n.
sqr=pow(Time,2.0/3.0);
for(rint i=1;i<=Time;i++)bel[i]=(i-1)/sqr+1;
for(rint i=1;i<=ask;i++){
rint opt;
scanf("%d",&opt);
// opt 0: modification (node v1 gets colour v2); otherwise a path query x-y.
if(opt==0){
c1++;
scanf("%d%d",&v1[c1],&v2[c1]);
}else {
c2++;
rint x,y;
scanf("%d%d",&x,&y);
rint lc=lca(x,y);
// Map the path to a range of the Euler sequence in which path nodes appear exactly once
// (nodes appearing twice cancel out through the vis toggle). If one endpoint is an
// ancestor, the range is [fir[anc], fir[desc]]; otherwise [sec[earlier], fir[later]],
// and the lca is added separately while answering.
if(lc==x||lc==y){
if(lc==x){
q[c2].l=fir[x];
q[c2].r=fir[y];
q[c2].h=c1;
q[c2].id=c2;
q[c2].lc=-1;
}else {
q[c2].l=fir[y];
q[c2].r=fir[x];
q[c2].h=c1;
q[c2].id=c2;
q[c2].lc=-1;
}
}else {
if(fir[x]>sec[y]){
q[c2].l=sec[y];
q[c2].r=fir[x];
q[c2].h=c1;
q[c2].id=c2;
q[c2].lc=lc;
}else {
q[c2].l=sec[x];
q[c2].r=fir[y];
q[c2].h=c1;
q[c2].id=c2;
q[c2].lc=lc;
}
}
}
}
sort(q+1,q+c2+1);
for(rint i=1,x=1,y=0,z=0;i<=c2;i++){
while(y<q[i].r)upd(rk[++y]);
while(x>q[i].l)upd(rk[--x]);
while(y>q[i].r)upd(rk[y--]);
while(x<q[i].l)upd(rk[x++]);
while(z<q[i].h)xg(++z);
while(z>q[i].h)xg(z--);
// Include the lca only for the duration of this answer.
if(q[i].lc!=-1)upd(q[i].lc);
ans[q[i].id]=res;
if(q[i].lc!=-1)upd(q[i].lc);
}
for(rint i=1;i<=c2;i++)
printf("%lld\n",ans[i]);
return 0;
}

二次离线莫队

二次离线是指,莫队是询问的一次离线,由于移动莫队端点时更新这个点的贡献可能比较困难,所以把移动的贡献重新算。

所有的贡献都可以表示为"点 $r+1$ 对区间 $[l,r]$ 产生的贡献"这种形式,把它转化为"$r+1$ 对 $[1,r]$ 的贡献"减去"$r+1$ 对 $[1,l-1]$ 的贡献",之后把这些询问离线下来,统一处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
#define rint register int
#define ll long long
#define rll register long long
int cnt[N],a[N],col[N];
// Mo's query with odd-even ordering: by block of l, then r ascending on odd blocks and
// descending on even ones.
struct Ask{
    int l,r,id;
    bool operator < (const Ask&A)const{
        if(col[l]!=col[A.l])return col[l]<col[A.l];
        return (col[l]&1)?r<A.r:r>A.r;
    }
}q[N];
// Masks with exactly k set bits (filled in main).
vector<int> vec;
// Second-offline batch: positions [l,r] evaluated against a prefix; typ 1 adds, typ 0 subtracts;
// id is the sorted query index.
struct Node{
    int l,r,typ,id;
};
vector<Node> vec2[N];
ll ans[N],tmpans[N],tmpans2[N];
// Second-offline Mo's algorithm: for each query [l,r], counts pairs i<j in the range with
// a[i]^a[j] having exactly k set bits (values are below 16384).
int main(){
rint n,m,k;
scanf("%d%d%d",&n,&m,&k);
// vec: every 14-bit value with exactly k set bits; a[i] matches a[j] iff a[i]^a[j] is in vec.
for(rint i=0;i<16384;i++){
rint c=0;
for(rint x=i;x;x&=x-1)c++;
if(c==k)vec.push_back(i);
}
rint sqr=sqrt(n);
for(rint i=1;i<=n;i++){
col[i]=(i-1)/sqr+1;
scanf("%d",&a[i]);
}
for(rint i=1;i<=m;i++){
scanf("%d%d",&q[i].l,&q[i].r);
q[i].id=i;
}
sort(q+1,q+m+1);
// After this loop, cnt[y] = #j <= i matching y.
// tmpans[i+1] = #j <= i matching a[i+1]; tmpans2[i] = #j <= i matching a[i]
// (these differ when k == 0, where a[i] matches itself). Both are then prefix-summed.
for(rint i=1;i<=n;i++){
for(rint v:vec)
cnt[v^a[i]]++;
tmpans[i+1]=cnt[a[i+1]];
tmpans2[i]=cnt[a[i]];
}
for(rint i=1;i<=n;i++)
tmpans[i]+=tmpans[i-1],tmpans2[i]+=tmpans2[i-1];
// Mo's sweep. Each pointer move splits its contribution into a prefix part (added from
// tmpans/tmpans2 right away) and a "range against prefix [1,cl-1] or [1,cr]" part stored
// in vec2 for the second pass. ans[i] holds only the change relative to query i-1.
for(rint i=1,cl=1,cr=0;i<=m;i++){
if(cr<q[i].r){
vec2[cl-1].push_back((Node){cr+1,q[i].r,0,i});
ans[i]+=tmpans[q[i].r]-tmpans[cr];
cr=q[i].r;
}
if(cl>q[i].l){
vec2[cr].push_back((Node){q[i].l,cl-1,1,i});
ans[i]-=tmpans2[cl-1]-tmpans2[q[i].l-1];
cl=q[i].l;
}
if(cr>q[i].r){
vec2[cl-1].push_back((Node){q[i].r+1,cr,1,i});
ans[i]-=tmpans[cr]-tmpans[q[i].r];
cr=q[i].r;
}
if(cl<q[i].l){
vec2[cr].push_back((Node){cl,q[i].l-1,0,i});
ans[i]+=tmpans2[q[i].l-1]-tmpans2[cl-1];
cl=q[i].l;
}
}
// Second pass: sweep prefixes [1,i] and evaluate every batch attached to i
// (batches attached to prefix 0 contribute nothing and are skipped).
memset(cnt,0,sizeof(cnt));
for(rint i=1;i<=n;i++){
for(rint v:vec)
cnt[v^a[i]]++;
for(Node v:vec2[i]){
if(v.typ){
for(rint j=v.l;j<=v.r;j++)
ans[v.id]+=cnt[a[j]];
}else {
for(rint j=v.l;j<=v.r;j++)
ans[v.id]-=cnt[a[j]];
}
}
}
// Prefix-sum the per-query deltas (sorted order), then map back to input order;
// tmpans is reused as the inverse permutation.
for(rint i=1;i<=m;i++){
ans[i]+=ans[i-1];
tmpans[q[i].id]=i;
}
for(rint i=1;i<=m;i++)
printf("%lld\n",ans[tmpans[i]]);
return 0;
}

Splay

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=1e5+10;
// Splay tree over distinct values; cnt[x] copies of val[x] are stored in node x.
int fa[N],ch[N][2],rt,tot,siz[N],val[N],cnt[N];
// Recomputes siz[x]: x's own copies plus the sizes of both children.
void up(rint x){
    siz[x]=cnt[x]+siz[ch[x][0]]+siz[ch[x][1]];
}
// Rotates x above its parent y (z is the grandparent), relinking children/parents and
// refreshing the sizes of y and then x.
// NOTE(review): when z == 0 this also writes a child slot of the null node 0; that is harmless
// as long as no traversal starts from node 0 — confirm.
void rotate(rint x){
rint y=fa[x],z=fa[y];
rint k=ch[y][1]==x;
ch[z][ch[z][1]==y]=x;
fa[x]=z;
ch[y][k]=ch[x][k^1];
fa[ch[x][k^1]]=y;
ch[x][k^1]=y;
fa[y]=x;
up(y);up(x);
}
// Splays x until its parent is t (t == 0 moves it to the root and updates rt).
// If x and its parent are same-side children (zig-zig), the parent is rotated first;
// otherwise x is rotated twice (zig-zag).
void splay(rint x,rint t){
for(rint y=fa[x],z=fa[y];fa[x]!=t;rotate(x),y=fa[x],z=fa[y])
if(z!=t)((ch[z][1]==y)==(ch[y][1]==x))?rotate(y):rotate(x);
if(!t)rt=x;
}
// Splays the node holding x to the root. If x is absent, the last node on the search path
// (x's predecessor or successor) ends up at the root. Does nothing on an empty tree.
void find(rint x){
rint u=rt;
if(!u)return ;
while(ch[u][x>val[u]]&&x!=val[u])
u=ch[u][x>val[u]];
splay(u,0);
}
// Inserts one copy of x: bumps cnt of an existing node, or attaches a new leaf under the
// last visited node f. The touched node is then splayed to the root.
void insert(rint x){
rint u=rt,f=0;
while(u&&val[u]!=x){
f=u;
u=ch[u][x>val[u]];
}
if(u)cnt[u]++;
else {
u=++tot;
if(f)ch[f][x>val[f]]=u;
fa[tot]=f;
val[tot]=x;
cnt[tot]=siz[tot]=1;
ch[tot][0]=ch[tot][1]=0;
}
// splay refreshes the sizes along the path.
splay(u,0);
}
// Returns the node holding the largest value strictly less than x.
int pre(rint x){
    find(x);
    if(val[rt]<x)return rt;
    // Otherwise: the right-most node of the root's left subtree.
    int u=ch[rt][0];
    for(;ch[u][1];u=ch[u][1]);
    return u;
}
// Returns the node holding the smallest value strictly greater than x.
int nxt(rint x){
    find(x);
    if(val[rt]>x)return rt;
    // Otherwise: the left-most node of the root's right subtree.
    int u=ch[rt][1];
    for(;ch[u][0];u=ch[u][0]);
    return u;
}
// Removes one copy of x (x must be present). The predecessor is splayed to the root and the
// successor directly below it, so x's node is the left child of the successor.
void del(rint x){
    int l_x=pre(x),r_x=nxt(x);
    splay(l_x,0);
    splay(r_x,l_x);
    rint u=ch[r_x][0];
    if(cnt[u]>1){
        // splay(u,0) rotates through r_x and l_x, which recomputes their sizes.
        cnt[u]--;splay(u,0);
    }else {
        ch[r_x][0]=0;
        // Detaching u shrinks r_x and its parent l_x. Refresh both; otherwise later rotations
        // propagate the stale sizes and rank/k-th queries go wrong.
        up(r_x);up(l_x);
    }
}
// Returns the value of rank x (1-based, counting copies), or -1 if the tree holds fewer
// than x elements.
int qrk(rint x){
    rint u=rt;
    if(siz[u]<x)return -1;
    for(;;){
        int lsz=siz[ch[u][0]];
        if(x<=lsz){u=ch[u][0];continue;}
        if(x<=lsz+cnt[u])return val[u];
        x-=lsz+cnt[u];
        u=ch[u][1];
    }
}
const int INF=0x3f3f3f3f;
int main(){
int n;
scanf("%d",&n);
insert(INF);insert(-INF);
while(n--){
rint opt,x;
scanf("%d%d",&opt,&x);
if(opt==1)insert(x);
else if(opt==2)del(x);
else if(opt==3){find(x);printf("%d\n",siz[ch[rt][0]]);}
else if(opt==4){printf("%d\n",qrk(x+1));}
else if(opt==5){printf("%d\n",val[pre(x)]);}
else printf("%d\n",val[nxt(x)]);
}
return 0;
}

无旋Treap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=1e5+10;
int cnt,rt,ch[N][2],siz[N],val[N],rnd[N];
mt19937 trnd(time(0));
uniform_int_distribution<int> dis(1,1e9);
// Allocates a leaf with key x and a random heap priority; returns its index.
int New(rint x){
    int id=++cnt;
    val[id]=x;
    rnd[id]=dis(trnd);
    siz[id]=1;
    return id;
}
// Recomputes the subtree size of x.
void up(rint x){siz[x]=1+siz[ch[x][0]]+siz[ch[x][1]];}
// Splits the treap rooted at rt into x (keys <= k) and y (keys > k).
void split(rint rt,rint k,rint &x,rint &y){
    if(!rt){x=y=0;return;}
    if(val[rt]>k){
        y=rt;
        split(ch[rt][0],k,x,ch[rt][0]);
    }else{
        x=rt;
        split(ch[rt][1],k,ch[rt][1],y);
    }
    up(rt);
}
// Merges treaps x and y (all keys of x <= all keys of y); the smaller priority becomes the root.
int merge(rint x,rint y){
    if(!x||!y)return x+y;
    if(rnd[x]>=rnd[y]){
        ch[y][0]=merge(x,ch[y][0]);
        up(y);
        return y;
    }
    ch[x][1]=merge(ch[x][1],y);
    up(x);
    return x;
}
// Returns the node holding the k-th smallest key (1-based) in the treap rooted at x.
int kth(rint x,rint k){
    for(;;){
        int lsz=siz[ch[x][0]];
        if(k<=lsz){x=ch[x][0];continue;}
        if(k==lsz+1)return x;
        k-=lsz+1;
        x=ch[x][1];
    }
}
// Driver: opt 1 insert, 2 delete one copy, 3 rank, 4 k-th smallest, 5 predecessor, 6 successor.
int main(){
rint T,x,y,z;
scanf("%d",&T);
while(T--){
rint opt,w;
scanf("%d%d",&opt,&w);
if(opt==1){
split(rt,w,x,y);
rt=merge(x,merge(New(w),y));
}else if(opt==2){
// Isolate all copies of w in z, then drop z's root (one copy).
split(rt,w,x,y);
split(x,w-1,x,z);
z=merge(ch[z][0],ch[z][1]);
rt=merge(merge(x,z),y);
}else if(opt==3){
// Rank = (#keys < w) + 1.
split(rt,w-1,x,y);
printf("%d\n",siz[x]+1);
rt=merge(x,y);
}else if(opt==4){
printf("%d\n",val[kth(rt,w)]);
}else if(opt==5){
// Largest key < w: last node of the "< w" part.
split(rt,w-1,x,y);
printf("%d\n",val[kth(x,siz[x])]);
rt=merge(x,y);
}else {
// Smallest key > w: first node of the "> w" part.
split(rt,w,x,y);
printf("%d\n",val[kth(y,1)]);
rt=merge(x,y);
}
}
return 0;
}

可持久化平衡树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<bits/stdc++.h>
using namespace std;
const int N=5e5+10;
const int maxn=1e7+10;
#define rint register int
// Node pool of the persistent FHQ treap. split/merge never modify a node in place; they
// clone it first, so older versions stay valid.
// The former `fa` member was never read or written anywhere in this program. Dropping it
// shrinks each node from 24 to 20 bytes, about 160 MB less static storage for 4*maxn nodes.
struct fhq{
    int ch[2],siz,val,rnd;
}t[maxn<<2];
// rt[i]: root of version i; cnt: number of allocated nodes.
int rt[N],cnt;
mt19937 rnd(time(0));
uniform_int_distribution<int> dis(1,1e9);
// Allocates a single-node treap holding v with a random priority.
int New(rint v){
    t[++cnt].val=v;t[cnt].rnd=dis(rnd);t[cnt].siz=1;
    return cnt;
}
// Recomputes the subtree size of x.
void up(rint x){
    t[x].siz=t[t[x].ch[0]].siz+t[t[x].ch[1]].siz+1;
}
// Persistent split: x receives keys <= k, y keys > k. Every node on the split path is
// copied before its child pointer changes, so the tree rooted at rt is left intact.
void split(rint rt,rint k,rint &x,rint &y){
if(!rt){x=y=0;return;}
if(t[rt].val<=k){
x=++cnt;t[x]=t[rt];
split(t[x].ch[1],k,t[x].ch[1],y);
up(x);
}else {
y=++cnt;t[y]=t[rt];
split(t[y].ch[0],k,x,t[y].ch[0]);
up(y);
}
}
// Persistent merge (all keys of x <= all keys of y): the node with the larger priority
// becomes the root and is copied before its child pointer changes.
int merge(rint x,rint y){
if(!x||!y)return x+y;
if(t[x].rnd>t[y].rnd){
rint p=++cnt;t[p]=t[x];
t[p].ch[1]=merge(t[p].ch[1],y);
up(p);return p;
}else {
rint p=++cnt;t[p]=t[y];
t[p].ch[0]=merge(x,t[p].ch[0]);
up(p);return p;
}
}
// Scratch roots shared by the split/merge helpers below.
int x,y,z;
// Inserts v into the version rooted at rt (rt is replaced by the new root).
void insert(int &rt,int v){
    split(rt,v,x,y);
    int node=New(v);
    rt=merge(merge(x,node),y);
}
// Removes one copy of v (if present) from the version rooted at rt.
void del(int &rt,int v){
    split(rt,v,x,y);
    split(x,v-1,x,z);
    // z now holds every copy of v; dropping its root removes exactly one.
    if(z!=0)z=merge(t[z].ch[0],t[z].ch[1]);
    rt=merge(merge(x,z),y);
}
// Returns the node holding the k-th smallest key (1-based) in the treap rooted at rt.
int kth(int rt,int k){
    for(;;){
        int lsz=t[t[rt].ch[0]].siz;
        if(k<=lsz){rt=t[rt].ch[0];continue;}
        if(k==lsz+1)return rt;
        k-=lsz+1;
        rt=t[rt].ch[1];
    }
}
int per(int &rt,int k){
split(rt,k-1,x,y);
rint res=kth(x,t[x].siz);
rt=merge(x,y);
return res;
}
int nxt(int &rt,int k){
split(rt,k,x,y);
rint res=kth(y,1);
rt=merge(x,y);
return res;
}
// Prints the rank of k: the number of keys < k, plus 1 supplied by the -inf sentinel that
// main inserts. rt is left unchanged: split() is non-destructive, so no re-merge is needed.
void getk(int &rt,int k){
    split(rt,k-1,x,y);
    printf("%d\n",t[x].siz);
}
// Each operation "v opt a" builds version tim from version v, then applies
// opt 1 insert, 2 delete, 3 rank, 4 k-th smallest, 5 predecessor, 6 successor.
int main(){
rint n,tim=0;
scanf("%d",&n);
// Version 0 holds only the sentinels -2^31+1 and 2^31-1 (returned when no predecessor/successor exists).
rt[0]=merge(New(-0x7fffffff),New(0x7fffffff));
while(n--){
rint v,opt,a;
scanf("%d%d%d",&v,&opt,&a);
rt[++tim]=rt[v];
if(opt==1)insert(rt[tim],a);
else if(opt==2)del(rt[tim],a);
else if(opt==3)getk(rt[tim],a);
// a+1 skips the lower sentinel.
else if(opt==4)printf("%d\n",t[kth(rt[tim],a+1)].val);
else if(opt==5)printf("%d\n",t[per(rt[tim],a)].val);
else printf("%d\n",t[nxt(rt[tim],a)].val);
}
return 0;
}

线段树合并

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
// Forward-star adjacency list.
struct Edge{
    int to,nxt;
}e[N<<1];
int h[N],idx;
// Pushes directed edge a -> b onto the front of a's list.
void Ins(int a,int b){
    ++idx;
    e[idx]=Edge{b,h[a]};
    h[a]=idx;
}
int dep[N],fa[N],siz[N],son[N];
// Rooted DFS: parent, depth, subtree size and heavy child of every node.
void dfs1(int u){
    siz[u]=1;
    for(int i=h[u];i;i=e[i].nxt){
        int to=e[i].to;
        if(to==fa[u])continue;
        fa[to]=u;
        dep[to]=dep[u]+1;
        dfs1(to);
        siz[u]+=siz[to];
        if(siz[son[u]]<siz[to])son[u]=to;
    }
}
// top[u]: head of u's heavy chain.
int top[N];
void dfs2(int u,int tt){
    top[u]=tt;
    if(son[u])dfs2(son[u],tt);
    for(int i=h[u];i;i=e[i].nxt){
        int to=e[i].to;
        if(to!=fa[u]&&to!=son[u])dfs2(to,to);
    }
}
// Lowest common ancestor via heavy-light chains: lift the endpoint whose chain head is
// deeper until both share a chain, then return the shallower node.
int lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]>=dep[top[y]])x=fa[top[x]];
        else y=fa[top[y]];
    }
    return dep[x]<=dep[y]?x:y;
}
int cnt,lc[N*80],rc[N*80],val[N*80],ide[N*80];
void pushup(int rt){
int ls=lc[rt],rs=rc[rt];
if(val[ls]>=val[rs]){
val[rt]=val[ls];
ide[rt]=ide[ls];
}else{
val[rt]=val[rs];
ide[rt]=ide[rs];
}
}
// Adds w to the count at position pos, creating nodes as needed.
void modify(int &rt,int l,int r,int pos,int w){
    if(!rt)rt=++cnt;
    if(l==r){
        val[rt]+=w;
        ide[rt]=l;
        return;
    }
    int mid=(l+r)>>1;
    if(pos>mid)modify(rc[rt],mid+1,r,pos,w);
    else modify(lc[rt],l,mid,pos,w);
    pushup(rt);
}
// Merges tree rs into tree ls over [l,r] (ls's nodes are reused); leaf counts are added.
int merge(int ls,int rs,int l,int r){
    if(!ls||!rs)return ls+rs;
    if(l==r){
        val[ls]+=val[rs];
        return ls;
    }
    int mid=(l+r)>>1;
    lc[ls]=merge(lc[ls],lc[rs],l,mid);
    rc[ls]=merge(rc[ls],rc[rs],mid+1,r);
    // pushup rebuilds val/ide of ls from its merged children, so no separate accumulation is needed.
    pushup(ls);
    return ls;
}
int ans[N];
int rt[N];
// Post-order: merge every child's tree into u's, then record the index with the largest
// count (0 when that maximum is 0).
void dfs(int u){
    for(int i=h[u];i;i=e[i].nxt){
        int to=e[i].to;
        if(to!=fa[u]){
            dfs(to);
            rt[u]=merge(rt[u],rt[to],1,N-10);
        }
    }
    ans[u]=val[rt[u]]==0?0:ide[rt[u]];
}
// Reads a (possibly negative) decimal integer from stdin.
inline int read(){
    int x=0,sign=1;
    char c=getchar();
    for(;c<'0'||c>'9';c=getchar())
        if(c=='-')sign=-1;
    for(;c>='0'&&c<='9';c=getchar())
        x=x*10+c-'0';
    return x*sign;
}
// Each event (x, y, z) adds one item of type z to every node on the path x-y. For each node,
// prints the type with the most items (0 if none).
int main(){
int n=read(),m=read();
for(int i=1;i<n;i++){
int a=read(),b=read();
Ins(a,b);Ins(b,a);
}
dfs1(1);
dfs2(1,1);
// Tree difference: +1 at x and y, -1 at their lca and at the lca's parent. Subtree sums
// (segment-tree merging in dfs) then recover per-node counts.
for(int i=1;i<=m;i++){
int x=read(),y=read(),z=read();
modify(rt[x],1,N-10,z,1);
modify(rt[y],1,N-10,z,1);
int t=lca(x,y);
modify(rt[t],1,N-10,z,-1);
modify(rt[fa[t]],1,N-10,z,-1);
}
dfs(1);
for(int i=1;i<=n;i++)
printf("%d\n",ans[i]);
return 0;
}

线段树分裂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=1e6+10;
const int maxn=N<<2;
// rt[i]: root of multiset i; lc/rc: children; cnt: allocated nodes; val: total count in a node's range.
int rt[N],lc[maxn],rc[maxn],cnt;
ll val[maxn];
// A node's count is the sum of its children's counts.
void up(rint x){
    val[x]=val[rc[x]]+val[lc[x]];
}
// Builds a full tree over [l,r]; leaf i reads its initial count from stdin.
void Build(rint &rt,rint l,rint r){
    if(!rt)rt=++cnt;
    if(l!=r){
        rint mid=l+r>>1;
        Build(lc[rt],l,mid);
        Build(rc[rt],mid+1,r);
        up(rt);
        return;
    }
    scanf("%lld",&val[rt]);
}
// Moves positions [cl,cr] from tree r1 into tree r2. Nodes fully inside the range are
// detached and moved as a whole; partially covered nodes get a fresh counterpart in r2.
void split(rint &r1,rint &r2,rint l,rint r,rint cl,rint cr){
if(!r1){
r2=0;
return;
}
if(cl<=l&&r<=cr){
r2=r1;r1=0;
return ;
}
r2=++cnt;
rint mid=l+r>>1;
if(cl<=mid)split(lc[r1],lc[r2],l,mid,cl,cr);
if(cr>mid)split(rc[r1],rc[r2],mid+1,r,cl,cr);
up(r1);up(r2);
}
// Merges tree y into tree x (x's nodes are reused) and returns the root.
int merge(rint x,rint y){
    if(!x)return y;
    if(!y)return x;
    lc[x]=merge(lc[x],lc[y]);
    rc[x]=merge(rc[x],rc[y]);
    val[x]+=val[y];
    return x;
}
// Adds v copies of value p to the tree rooted at rt, allocating nodes on the way.
void update(rint &rt,rint l,rint r,rint p,rint v){
    if(rt==0)rt=++cnt;
    val[rt]+=v;
    if(l<r){
        rint mid=l+r>>1;
        if(p>mid)update(rc[rt],mid+1,r,p,v);
        else update(lc[rt],l,mid,p,v);
    }
}
// Number of stored values lying in [cl,cr].
ll query(rint rt,rint l,rint r,rint cl,rint cr){
    if(rt==0)return 0;
    if(l>=cl&&r<=cr)return val[rt];
    rint mid=l+r>>1;
    ll fromLeft=cl<=mid?query(lc[rt],l,mid,cl,cr):0;
    ll fromRight=cr>mid?query(rc[rt],mid+1,r,cl,cr):0;
    return fromLeft+fromRight;
}
// k-th smallest stored value, or -1 if fewer than k values are stored.
int kth(rint rt,rint l,rint r,rint k){
    for(;;){
        if(k>val[rt])return -1;
        if(l==r)return l;
        int mid=l+r>>1;
        if(k<=val[lc[rt]]){
            rt=lc[rt];
            r=mid;
        }else{
            k-=val[lc[rt]];
            rt=rc[rt];
            l=mid+1;
        }
    }
}
int main(){
rint n,m,cnt=1;
scanf("%d%d",&n,&m);
Build(rt[1],1,n);
while(m--){
rint opt;
scanf("%d",&opt);
if(opt==0){
rint p,x,y;
scanf("%d%d%d",&p,&x,&y);++cnt;
split(rt[p],rt[cnt],1,n,x,y);
}else if(opt==1){
rint x,y;
scanf("%d%d",&x,&y);
rt[x]=merge(rt[x],rt[y]);
}else if(opt==2){
rint p,x,y;
scanf("%d%d%d",&p,&x,&y);
update(rt[p],1,n,y,x);
}else if(opt==3){
rint p,x,y;
scanf("%d%d%d",&p,&x,&y);
printf("%lld\n",query(rt[p],1,n,x,y));
}else if(opt==4){
rint p,k;
scanf("%d%d",&p,&k);
printf("%d\n",kth(rt[p],1,n,k));
}
}
return 0;
}

势能分析线段树

区间与/或一个数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include<bits/stdc++.h>
using namespace std;
namespace lin4xu{

#define re register
#define rint re int
#define ll long long
#define rll re ll
#define db double
#define rdb re db
#define rch re char

// Reads a decimal integer of type T from stdin (a '-' right before the digits makes it negative).
template <typename T> inline T read()
{
re T ans=0;rint f=0;rch ch=getchar();
for(;!isdigit(ch);ch=getchar()) f=ch=='-';
for(;isdigit(ch);ch=getchar()) ans=ans*10+(ch&15);
return f?-ans:ans;
}

const int maxn=5e5+10;

// Segment tree over a[1..n]: sum1 = bitwise AND of the range, sum2 = bitwise OR,
// mn = minimum, tag = pending addition. A bit operation that changes every element of a
// node by the same amount is stored as that common difference.
int sum1[maxn<<2];
int sum2[maxn<<2];
int tag[maxn<<2];
int mn[maxn<<2];

inline void push_up(rint x)
{
rint lx=x<<1,rx=lx|1;
mn[x]=min(mn[lx],mn[rx]);
sum1[x]=sum1[lx]&sum1[rx];
sum2[x]=sum2[lx]|sum2[rx];
}

// Forwards the pending difference. This is valid because a uniform AND only clears bits that
// every element has, and a uniform OR only sets bits that no element has; either shifts
// AND, OR and min by exactly the same amount.
inline void push_down(rint x)
{
if(!tag[x]) return;
rint lx=x<<1,rx=lx|1;
sum1[lx]+=tag[x],sum2[lx]+=tag[x],tag[lx]+=tag[x],mn[lx]+=tag[x];
sum1[rx]+=tag[x],sum2[rx]+=tag[x],tag[rx]+=tag[x],mn[rx]+=tag[x];
tag[x]=0;
}

// Current operation: range [left,right] and operand val.
int left,right,val;

// Reads the initial values into the leaves.
void build(rint x,rint l,rint r)
{
if(l==r) return sum1[x]=sum2[x]=mn[x]=read<int>(),void();
rint mid=l+r>>1;
build(x<<1,l,mid);
build(x<<1|1,mid+1,r);
push_up(x);
}

// a[i] &= val on [left,right]. The node is updated lazily when AND and OR agree on the bits
// val clears (sum1&~val == sum2&~val), i.e. every element loses the same bits.
void modify1(rint x,rint l,rint r)
{
if(l>=left&&r<=right&&sum1[x]-(sum1[x]&val)==sum2[x]-(sum2[x]&val)) return tag[x]+=(sum1[x]&val)-sum1[x],mn[x]+=(sum1[x]&val)-sum1[x],sum1[x]&=val,sum2[x]&=val,void();
rint mid=l+r>>1;push_down(x);
if(left<=mid) modify1(x<<1,l,mid);
if(right>mid) modify1(x<<1|1,mid+1,r);
push_up(x);
}

// a[i] |= val on [left,right]. Lazy when every element gains the same bits
// (val&~sum1 == val&~sum2).
void modify2(rint x,rint l,rint r)
{
if(l>=left&&r<=right&&sum1[x]-(sum1[x]|val)==sum2[x]-(sum2[x]|val)) return tag[x]+=(sum1[x]|val)-sum1[x],mn[x]+=(sum1[x]|val)-sum1[x],sum1[x]|=val,sum2[x]|=val,void();
rint mid=l+r>>1;push_down(x);
if(left<=mid) modify2(x<<1,l,mid);
if(right>mid) modify2(x<<1|1,mid+1,r);
push_up(x);
}

// Minimum over [left,right].
int query(rint x,rint l,rint r)
{
if(l>=left&&r<=right) return mn[x];
rint mid=l+r>>1,ans=0x7fffffff;push_down(x);
if(left<=mid) ans=min(ans,query(x<<1,l,mid));
if(right>mid) ans=min(ans,query(x<<1|1,mid+1,r));
return ans;
}

// Ops: 1 = AND with val, 2 = OR with val, otherwise print the range minimum.
int main()
{
rint n=read<int>(),m=read<int>();
build(1,1,n);
for(rint tt=1;tt<=m;tt++)
{
rint opt=read<int>();left=read<int>(),right=read<int>();
if(opt==1) val=read<int>(),modify1(1,1,n);
else if(opt==2) val=read<int>(),modify2(1,1,n);
else printf("%d\n",query(1,1,n));
}
return 0;
}};

// Program entry point: delegates to the solution inside namespace lin4xu.
int main(){
    return lin4xu::main();
}

区间最值

以区间对一个数 $v$ 取 $\min$ 为例,每个节点记录最大值 $Max$ 和次大值 $sMax$:如果 $v\ge Max$,那么可以直接返回;如果 $sMax<v<Max$,那么只需要修改最大值;否则直接递归。这样做的复杂度是对的,原因在于每次往下递归一定会把至少两个不同的数合并成一个,合并次数最多为 $n$,单次合并的代价是 $O(\log n)$。

点分治/树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=4e4+10;
// Weighted forward-star adjacency list.
struct Edge{
    int to,val,nxt;
}e[N<<1];
int h[N],idx;
// Adds directed edge a -> b with weight c at the head of a's list.
void Ins(int a,int b,int c){
    ++idx;
    e[idx]=Edge{b,c,h[a]};
    h[a]=idx;
}
int Mx[N],siz[N],rt,Tsiz;
bool vis[N];
// Centroid search in the unvisited component containing u, whose size is Tsiz. Computes
// subtree sizes and the largest piece Mx[u] left after removing u; the node with the
// smallest Mx is kept in the global rt.
void get(int u,int fa){
    siz[u]=1;
    Mx[u]=0;
    for(int i=h[u];i;i=e[i].nxt){
        int to=e[i].to;
        if(to==fa||vis[to])continue;
        get(to,u);
        siz[u]+=siz[to];
        if(siz[to]>Mx[u])Mx[u]=siz[to];
    }
    if(Tsiz-siz[u]>Mx[u])Mx[u]=Tsiz-siz[u];
    if(Mx[u]<Mx[rt])rt=u;
}
// d[1..tt]: collected distances; k: distance limit; ans: number of pairs found.
int d[N],tt,k;
ll ans;
// Appends the distance of every node reachable from u without crossing visited nodes.
void dfs(int u,int fa,int dis){
    d[++tt]=dis;
    for(int i=h[u];i;i=e[i].nxt){
        int to=e[i].to;
        if(to!=fa&&!vis[to])dfs(to,u,dis+e[i].val);
    }
}
// Collects the distances of all nodes reachable from rt (starting at st, not crossing visited
// nodes), sorts them, and counts pairs i<j with d[i]+d[j] <= k using two pointers.
ll getans(int rt,int st){
tt=0;
dfs(rt,0,st);
sort(d+1,d+tt+1);
int i=1,j=tt;
ll tmp=0;
while(i<=j){
// Shrink j until d[i]+d[j] <= k; then every j' in (i, j] pairs with i.
while(i<j&&d[i]+d[j]>k)j--;
tmp+=j-i;
i++;
}
return tmp;
}
void div(int rt){
ans+=getans(rt,0);
vis[rt]=1;
for(int i=h[rt];i;i=e[i].nxt){
int v=e[i].to;
if(vis[v])continue;
ans-=getans(v,e[i].val);
Tsiz=siz[v];rt=0;
get(v,0);
div(rt);
}
}
// Reads a weighted tree and k, then prints the number of node pairs at distance <= k.
int main(){
    int n;
    scanf("%d",&n);
    for(int edge=1;edge<n;edge++){
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        Ins(a,b,c);
        Ins(b,a,c);
    }
    scanf("%d",&k);
    // Mx[0] is the "worst" sentinel, so the first node examined replaces rt = 0.
    Mx[0]=0x7fffffff;
    Tsiz=n;
    get(1,0);
    div(rt);
    printf("%lld\n",ans);
    return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
typedef long long ll;
// dis[c][u]: distance from centroid c to node u of its component.
unordered_map<int,int> dis[N];
// Weighted forward-star adjacency list.
struct Edge{
    int to,nxt,val;
}e[N<<1];
int h[N],idx;
// Adds directed edge a -> b with weight c at the head of a's list.
void Ins(int a,int b,int c){
    ++idx;
    e[idx].to=b;
    e[idx].val=c;
    e[idx].nxt=h[a];
    h[a]=idx;
}
// Entry of a centroid's sorted list: node x, key y, distance w. Ordered by y only, which is
// what sort / lower_bound / upper_bound need.
struct Node{
    int x,y,w;
    Node(){}
    Node(int a,int b,int c):x(a),y(b),w(c){}
    bool operator < (const Node&A)const{return y<A.y;}
};
// Per centroid u: vec[u] lists the nodes of its component with their distance to u, sorted by key
// (prefix sums in s[u]). rc[u] lists the same nodes with their distance to u's parent centroid
// (prefix sums in s2[u]).
vector<Node> vec[N],rc[N];
vector<ll> s[N],s2[N];
int x[N],rt,Tsiz,Max[N],siz[N];
bool vis[N];
// Centroid search in the unvisited component containing u (size Tsiz); the best node is stored in rt.
void getrt(int u,int fa){
siz[u]=1;Max[u]=0;
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa||vis[v])continue;
getrt(v,u);
siz[u]+=siz[v];
Max[u]=max(Max[u],siz[v]);
}
Max[u]=max(Max[u],Tsiz-siz[u]);
if(Max[rt]>Max[u])rt=u;
}
// fa[c]: parent of centroid c in the centroid tree (0 for the root).
int fa[N];
// Walks the component of centroid `now` from u. Records each node with its distance to `now`
// and, if `now` has a parent centroid, with its distance to that parent. Also fills dis[now].
void cal(int u,int ff,int now){
vec[now].push_back(Node(u,x[u],dis[now][u]));
s[now].push_back(0);
if(fa[now]){
rc[now].push_back(Node(u,x[u],dis[fa[now]][u]));
s2[now].push_back(0);
}
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==ff||vis[v])continue;
dis[now][v]=dis[now][u]+e[i].val;
cal(v,u,now);
}
}
// Builds the centroid tree below centroid u. Sentinel entries with keys -1 and 2e9 bound
// the later binary searches; after sorting by key, s/s2 become prefix sums of the distances.
void dfs(int u){
vis[u]=1;
dis[u][u]=0;
vec[u].push_back(Node(0,-1,0));
vec[u].push_back(Node(0,2e9,0));
s[u].push_back(0);
s[u].push_back(0);
rc[u].push_back(Node(0,-1,0));
rc[u].push_back(Node(0,2e9,0));
s2[u].push_back(0);
s2[u].push_back(0);
cal(u,0,u);
sort(vec[u].begin(),vec[u].end());
for(int i=1;i<vec[u].size();i++)s[u][i]=s[u][i-1]+vec[u][i].w;
if(fa[u]){
sort(rc[u].begin(),rc[u].end());
// Bounded by vec[u].size(): when fa[u] != 0, rc[u] gets the same two sentinels plus one entry
// per component node as vec[u], so both vectors have equal length.
for(int i=1;i<vec[u].size();i++)s2[u][i]=s2[u][i-1]+rc[u][i].w;
}
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(vis[v])continue;
Tsiz=siz[v];rt=0;
getrt(v,u);
fa[rt]=u;
dfs(rt);
}
}
// Online queries (decoded with the previous answer, modulo A): the sum of distances from node x
// to all nodes whose key x[] lies in [l, r].
int main(){
int n,m,A;
scanf("%d%d%d",&n,&m,&A);
for(int i=1;i<=n;i++)
scanf("%d",&x[i]);
for(int i=2;i<=n;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
Ins(a,b,c);Ins(b,a,c);
}
Max[0]=1e9;Tsiz=n;
getrt(1,0);
dfs(rt);
ll res=0;
while(m--){
int x,l,r;
scanf("%d%d%d",&x,&l,&r);
l=(l+res)%A,r=(r+res)%A;
if(l>r)swap(l,r);
// posl: last index with key < l; posr: last index with key <= r. The difference of prefix sums
// gives the total distance (and posr-posl the count) of keys in [l, r].
ll tmp1,posl=lower_bound(vec[x].begin(),vec[x].end(),Node(0,l,0))-vec[x].begin()-1;
ll tmp2,posr=upper_bound(vec[x].begin(),vec[x].end(),Node(0,r,0))-vec[x].begin()-1;
res=s[x][posr]-s[x][posl];
// Climb the centroid tree. Add the parent's component (paths through fa[i]), then subtract the
// part already counted inside i's component via rc[i]/s2[i].
for(int i=x;fa[i];i=fa[i]){
posl=lower_bound(vec[fa[i]].begin(),vec[fa[i]].end(),Node(0,l,0))-vec[fa[i]].begin()-1;
posr=upper_bound(vec[fa[i]].begin(),vec[fa[i]].end(),Node(0,r,0))-vec[fa[i]].begin()-1;
tmp1=s[fa[i]][posr]-s[fa[i]][posl];
tmp2=posr-posl;
res+=tmp1+tmp2*dis[fa[i]][x];
posl=lower_bound(rc[i].begin(),rc[i].end(),Node(0,l,0))-rc[i].begin()-1;
posr=upper_bound(rc[i].begin(),rc[i].end(),Node(0,r,0))-rc[i].begin()-1;
tmp1=s2[i][posr]-s2[i][posl];
tmp2=posr-posl;
res-=tmp1+tmp2*dis[fa[i]][x];
}
printf("%lld\n",res);
}
return 0;
}

边分治/树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define rll register long long
typedef long long ll;
const int N=3e5+10;
// NOTE(review): this program appears to maximise, over vertex pairs (x,y),
// dist1(x,y)+dist2(x,y)+dist3(x,y) on three trees sharing vertices 1..n
// (edge divide on tree 1, virtual trees on tree 2, farthest pairs on tree 3).
// Adjacency-list edge; vis marks an edge already cut by the edge divide
// (only T1 reads it).
struct Edge{
int to,nxt;bool vis;
ll val;
};
// A pair of vertices (x,y) with its score dis; x==0 means "empty set".
struct Node{
int x,y;
ll dis;
};
struct T3{
Edge e[N<<1];
int h[N],idx;
void Ins(rint a,rint b,rll c){
e[++idx].to=b;e[idx].nxt=h[a];
h[a]=idx;e[idx].val=c;
}
int tp[N],son[N],fa[N],siz[N],d[N];
ll dis[N];
void dfs1(rint u){
siz[u]=1;
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa[u])continue;
dis[v]=dis[u]+e[i].val;
d[v]=d[u]+1;fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v])son[u]=v;
}
}
void dfs2(rint u,rint t){
tp[u]=t;
if(son[u])dfs2(son[u],t);
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
int lca(rint x,rint y){
while(tp[x]^tp[y]){
if(d[tp[x]]<d[tp[y]])y=fa[tp[y]];
else x=fa[tp[x]];
}
return d[x]<d[y]?x:y;
}
ll dist(rint x,rint y){
return dis[x]+dis[y]-2*dis[lca(x,y)];
}
void solve(){
dfs1(1);
dfs2(1,1);
}
}tr3;
// val[u]: per-vertex offset added when scoring pairs (filled per edge-divide
// step); ext: weight of the tree-1 edge currently cut.
ll val[N],ext;
// col[u]: side (1 or 2) of the current cut edge, 0 otherwise;
// dfn[u]: DFS entry time of u in tree 2.
int col[N],dfn[N];
// Orders vertices by their DFS entry time in tree 2.
bool cmp(rint a,rint b){
    return dfn[b]>dfn[a];
}
// Best pair (tree-3 distance plus val[] of both ends) over the union of the
// vertex sets summarised by u and v, assuming each is the best pair of its
// own set (standard farthest-pair merging). An empty side returns the other.
Node merge(Node u,Node v){
    if(u.x==0)return v;
    if(v.x==0)return u;
    // Candidates are checked in this fixed order; only strict improvements win.
    const int cand[6][2]={
        {u.x,v.x},{u.x,v.y},{u.x,u.y},
        {v.x,v.y},{v.y,u.y},{v.x,u.y}
    };
    Node best=u;
    for(int c=0;c<6;c++){
        const int p=cand[c][0],q=cand[c][1];
        const ll score=tr3.dist(p,q)+val[p]+val[q];
        if(score>best.dis)best=Node{p,q,score};
    }
    return best;
}
// Best cross pair: one end from u's pair, one from v's, scored by tree-3
// distance plus val[] of both ends, never below 0. Returns -1e16 when either
// side is empty.
ll merge2(Node u,Node v){
    if(u.x==0||v.x==0)return -1e16;
    const int cand[4][2]={{u.x,v.x},{u.x,v.y},{v.y,u.y},{v.x,u.y}};
    ll best=0;
    for(int c=0;c<4;c++){
        const int p=cand[c][0],q=cand[c][1];
        const ll score=tr3.dist(p,q)+val[p]+val[q];
        if(score>best)best=score;
    }
    return best;
}
// Best score found so far (printed by main).
ll res;
// Tree 2: heavy-light decomposition for LCA/depth, plus a virtual tree built
// per edge-divide step of tree 1, on which pairs of opposite colours meet.
struct T2{
Edge e[N<<1];
int h[N],idx;
void Ins(rint a,rint b,rll c){
e[++idx].to=b;e[idx].nxt=h[a];
h[a]=idx;e[idx].val=c;
}
int Time,top[N],son[N],fa[N],siz[N],d[N];
ll dis[N];
// HLD pass 1; also stores the DFS entry time in the global dfn[].
void dfs1(rint u){
siz[u]=1;dfn[u]=++Time;
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa[u])continue;
dis[v]=dis[u]+e[i].val;
d[v]=d[u]+1;fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v])son[u]=v;
}
}
// HLD pass 2. Clears h[u] afterwards so the edge storage can be reused for
// virtual trees (solve() also resets idx).
void dfs2(rint u,rint t){
top[u]=t;
if(son[u])dfs2(son[u],t);
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
h[u]=0;
}
int lca(rint x,rint y){
while(top[x]^top[y]){
if(d[top[x]]<d[top[y]])y=fa[top[y]];
else x=fa[top[x]];
}
return d[x]<d[y]?x:y;
}
ll dist(rint x,rint y){
return dis[x]+dis[y]-2*dis[lca(x,y)];
}
int stk[N],tp;
// Stack-based virtual-tree insertion (vertices arrive in dfn order); edges
// are added parent->child with weight 0, inserting LCAs as needed.
void insert(rint x){
if(!tp){
stk[++tp]=x;
return;
}
int lc=lca(x,stk[tp]);
if(lc==stk[tp])stk[++tp]=x;
else {
while(tp>1&&d[stk[tp-1]]>=d[lc]){
Ins(stk[tp-1],stk[tp],0);
tp--;
}
if(stk[tp]!=lc){
Ins(lc,stk[tp],0);
stk[tp]=lc;
}
stk[++tp]=x;
}
}
// f[c][u]: best pair among colour-c vertices in u's virtual subtree (x==0: none).
Node f[3][N];
// Bottom-up over the virtual tree. u is the tree-2 LCA of any pair joined
// here, so -2*dis[u] completes their tree-2 distance (the depths are already
// in val[]), and ext adds the weight of the cut tree-1 edge. Every visited
// vertex is also pushed on stk so calc() can reset it afterwards.
void dfs(rint u){
stk[++tp]=u;
if(col[u])f[col[u]][u]=(Node){u,u,val[u]};
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
dfs(v);
res=max(res,merge2(f[1][u],f[2][v])+ext-2*dis[u]);
res=max(res,merge2(f[2][u],f[1][v])+ext-2*dis[u]);
f[1][u]=merge(f[1][u],f[1][v]);
f[2][u]=merge(f[2][u],f[2][v]);
}
}
// One edge-divide step: a[1..n] are the collected vertices, whose val[]
// already holds their tree-1 distance to the cut edge. Adds the tree-2 depth,
// builds the virtual tree rooted at 1, runs the DP, then clears the state.
void calc(rint *a,rint n){
sort(a+1,a+n+1,cmp);
if(a[1]!=1)stk[++tp]=1;
for(rint i=1;i<=n;i++)insert(a[i]),val[a[i]]+=dis[a[i]];
while(tp>1)Ins(stk[tp-1],stk[tp],0),tp--;
tp=0;
dfs(1);
for(rint i=1;i<=tp;i++)
h[stk[i]]=f[2][stk[i]].x=f[2][stk[i]].y=f[2][stk[i]].dis=f[1][stk[i]].x=f[1][stk[i]].y=f[1][stk[i]].dis=0;
tp=idx=0;
}
void solve(){
dfs1(1);
dfs2(1,1);
idx=0;
}
}tr2;
// Initial "infinitely bad" value for the balanced-edge search.
const int INF=0x7f7f7f7f;
// Tree 1: the input tree is binarised, then split by edge divide and
// conquer; each cut hands the vertices of both sides to tree 2 (tr2.calc).
struct T1{
// Binarised tree. idx starts at 1 (see solve), so the two directions of an
// edge sit at indices i and i^1.
Edge e[N<<1];
int h[N],idx;
void Ins(rint a,rint b,rll c){
e[++idx].to=b;e[idx].nxt=h[a];
h[a]=idx;e[idx].val=c;e[idx].vis=0;
}
// Original tree 1 as read from input; cnt counts binarised vertices
// (virtual ones receive ids > n).
Edge e2[N<<1];
int h2[N],idx2,cnt,n;
void Ins2(rint a,rint b,rll c){
e2[++idx2].to=b;e2[idx2].nxt=h2[a];
h2[a]=idx2;e2[idx2].val=c;
}
// Binarisation: u's first child hangs directly from u; every further child
// hangs from a fresh virtual vertex chained below the previous attachment
// point by a 0-weight edge, so degrees are <= 3 and distances between
// original vertices are unchanged.
void rebuild(rint u,rint fa){
rint lst=0;
for(rint i=h2[u];i;i=e2[i].nxt){
rint v=e2[i].to;
if(v==fa)continue;
rebuild(v,u);
if(!lst){
lst=u;
}else {
++cnt;
Ins(lst,cnt,0);Ins(cnt,lst,0);
lst=cnt;
}
Ins(lst,v,e2[i].val);Ins(v,lst,e2[i].val);
}
}
int siz[N],Tsiz,Max,rt;
// Subtree sizes inside the current component (cut edges are skipped).
void getsiz(rint u,rint fa){
siz[u]=1;
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa||e[i].vis)continue;
getsiz(v,u);
siz[u]+=siz[v];
}
}
// Picks the edge (its index goes to rt) minimising the larger of the two
// parts in a component of Tsiz vertices.
void getrt(rint u,rint fa){
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa||e[i].vis)continue;
getrt(v,u);
rint now=max(siz[v],Tsiz-siz[v]);
if(Max>now)Max=now,rt=i;
}
}
int s[N],tt;
// Collects the original vertices (id <= n) reachable from u without crossing
// cut edges: appends them to f[1..tt], colours them typ, and stores their
// distance from the start vertex in g[].
void dfs(rint u,rint fa,rint *f,rint typ,rint &tt,rll *g,rll dis){
if(u<=n){
f[++tt]=u;col[u]=typ;g[u]=dis;
}
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa||e[i].vis)continue;
dfs(v,u,f,typ,tt,g,dis+e[i].val);
}
}
// One edge-divide step: cut edge rt, score every pair across it via tree 2,
// reset the per-vertex data, then recurse into both sides.
void dfs(){
if(!rt)return ;
e[rt].vis=e[rt^1].vis=1;
rint v1=e[rt].to,v2=e[rt^1].to;
tt=0;ext=e[rt].val;
dfs(v1,0,s,1,tt,val,0);
dfs(v2,0,s,2,tt,val,0);
tr2.calc(s,tt);
for(rint i=1;i<=tt;i++)col[s[i]]=0,val[s[i]]=0;
getsiz(v1,0);Tsiz=siz[v1];Max=INF;rt=0;getrt(v1,0);dfs();
getsiz(v2,0);Tsiz=siz[v2];Max=INF;rt=0;getrt(v2,0);dfs();
}
// Entry point; x is the number of original vertices.
void solve(rint x){
cnt=n=x;idx=1;
rebuild(1,0);
getsiz(1,0);
// The binarised tree has cnt >= n vertices. Use its real size (siz[1]), as
// the recursive steps do with siz[v1]/siz[v2], so the first cut is balanced.
rt=0;Tsiz=siz[1];Max=INF;
getrt(1,0);
dfs();
}
}tr1;
int main(){
rint n;
scanf("%d",&n);
for(rint i=2;i<=n;i++){
rint a,b;rll w;
scanf("%d%d%lld",&a,&b,&w);
tr1.Ins2(a,b,w);tr1.Ins2(b,a,w);
}
for(rint i=2;i<=n;i++){
rint a,b;rll w;
scanf("%d%d%lld",&a,&b,&w);
tr2.Ins(a,b,w);tr2.Ins(b,a,w);
}
for(rint i=2;i<=n;i++){
rint a,b;rll w;
scanf("%d%d%lld",&a,&b,&w);
tr3.Ins(a,b,w);tr3.Ins(b,a,w);
}
tr2.solve();
tr3.solve();
tr1.solve(n);
printf("%lld\n",res);
return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include<bits/stdc++.h>
#define int long long
using namespace std;
int const N=366671;
// Minimal adjacency-list graph: per edge target/pre/w, last[] = head of each
// vertex's list. Edge ids start at 1, so edges added in pairs occupy
// (2k-1, 2k).
template<unsigned M>struct graph
{
int target[2*M],pre[2*M],last[M],tot,w[2*M];
void add(int x,int y,int z)
{
target[++tot]=y;
pre[tot]=last[x];
last[x]=tot;
w[tot]=z;
}
};
// g1: first input tree; g2: its binarised copy; g3: second input tree.
graph<N>g1,g3;graph<N<<1>g2;
// now: vertex counter of g2 (virtual vertices > n); siz/nows/ed/li: balanced
// edge search; cnt/mx/ch/last/rt: per-vertex tries over the edge-divide
// levels; dis: weighted depth in tree 1; tmp: size of one side of a cut;
// ans: twice the answer (main prints ans>>1).
int n,now,siz[N<<1],nows,ed,li,cnt,mx[N<<5|1][2],ch[N<<5|1][2],last[N],dis[N],
tmp,ans=-1e18,rt[N];
// del[i]: edge i of g2 has been cut.
bool del[N<<2];
// Binarises tree 1 into g2: each child tar of x gets a fresh virtual vertex
// (++now) chained below x by a 0-weight edge, and tar hangs from it with the
// original weight. Also records dis[] = weighted depth in tree 1.
void dfs(int x,int fa)
{
int las=x;
for(int i=g1.last[x];i;i=g1.pre[i])
{
int tar=g1.target[i];
if(tar==fa)continue;
g2.add(las,++now,0),g2.add(now,las,0),g2.add(now,tar,g1.w[i]),g2.add(tar,now,g1.w[i]);
dis[tar]=dis[x]+g1.w[i];las=now;dfs(tar,x);
}
}
// Records one edge-divide level for every vertex on one side of the cut
// edge. op is the side (0: r1's side, 1: r2's side); nowd is the distance to
// r1 (side 0) or to r1 through the cut edge (side 1 starts at its weight).
// Each real vertex x (x <= n) extends its trie path (root rt[x], created on
// first visit): the tail node gets child `op` and mx[tail][op] = dis[x]+nowd.
// tmp counts the vertices (real and virtual) on side 0.
void dfs2(int x,int fa,int nowd,int op)
{
if(!op)tmp++;
if(x<=n)
{
cnt++;
if(!last[x])rt[x]=last[x]=cnt,cnt++;
ch[last[x]][op]=cnt;
mx[last[x]][op]=dis[x]+nowd;
last[x]=cnt;
}
for(int i=g2.last[x];i;i=g2.pre[i])
{
int tar=g2.target[i];
if(tar==fa||del[i])continue;
dfs2(tar,x,nowd+g2.w[i],op);
}
}
void get(int x,int fa)
{
siz[x]=1;
for(int i=g2.last[x];i;i=g2.pre[i])
{
int tar=g2.target[i];
if(tar==fa||del[i])continue;
get(tar,x);siz[x]+=siz[tar];
if(max(siz[tar],nows-siz[tar])<li)li=max(siz[tar],nows-siz[tar]),ed=(i+1)>>1;
}
}
// Edge divide on the component of g2 containing x (s vertices): get() finds
// the most balanced edge ed (directed ids 2ed-1 and 2ed), the edge is cut,
// both sides are recorded by dfs2, and each side is processed recursively.
void solve(int x,int s)
{
if(s==1)return;
ed=li=1e9;nows=s;
get(x,0);
int r1=g2.target[(ed<<1)-1],r2=g2.target[ed<<1];
del[(ed<<1)-1]=del[ed<<1]=1;tmp=0;
dfs2(r1,r2,0,0),dfs2(r2,r1,g2.w[ed<<1],1);
// siz[x] is the whole component size, so this is the size of r2's side.
int tt=siz[x]-tmp;
solve(r1,tmp);solve(r2,tt);
}
// Merges the edge-divide tries x and y (child 0/1 = side of that level's cut
// edge; mx[node][side] = best dis[v]+distance over the merged vertices v).
// At each shared node a side-0 vertex of one trie and a side-1 vertex of the
// other were separated by that cut, so the sum is dis[u]+dis[v]+dist1(u,v);
// adding 2*t (t = minus the tree-2 depth of their LCA) gives a doubled
// candidate answer.
int merge(int x,int y,int t)
{
if((!x)||(!y))return x+y;
ans=max(ans,max(mx[x][0]+mx[y][1],mx[y][0]+mx[x][1])+2*t);
mx[x][0]=max(mx[x][0],mx[y][0]),mx[x][1]=max(mx[x][1],mx[y][1]);
ch[x][0]=merge(ch[x][0],ch[y][0],t),ch[x][1]=merge(ch[x][1],ch[y][1],t);
return x;
}
// DFS over the second tree (g3); nowd is the weighted depth of x there.
// Scores the pair (x,x), then merges each child's trie into x's with x as
// their LCA (t = -nowd).
void dfs3(int x,int fa,int nowd)
{
    ans=max(ans,2*(dis[x]-nowd));
    for(int e=g3.last[x];e;e=g3.pre[e])
    {
        const int y=g3.target[e];
        if(y==fa)continue;
        dfs3(y,x,nowd+g3.w[e]);
        rt[x]=merge(rt[x],rt[y],-nowd);
    }
}
signed main()
{
    // 0xc0 bytes make every mx entry a large negative value ("no vertex").
    memset(mx,0xc0,sizeof(mx));
    int x,y,z;
    scanf("%lld",&n);
    now=n;
    // First tree: binarise it, then run the edge divide.
    for(int i=1;i<=n-1;i++)
    {
        scanf("%lld%lld%lld",&x,&y,&z);
        g1.add(x,y,z);
        g1.add(y,x,z);
    }
    dfs(1,0);
    solve(1,now);
    // Second tree: merge the tries bottom-up.
    for(int i=1;i<=n-1;i++)
    {
        scanf("%lld%lld%lld",&x,&y,&z);
        g3.add(x,y,z);
        g3.add(y,x,z);
    }
    dfs3(1,0,0);
    printf("%lld",ans>>1);
    return 0;
}

cdq分治

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=5e5+10;
// Fenwick tree over the third coordinate m; maxn is its upper bound.
int tr[N],maxn;
// Point update: adds w at position x.
void Add(int x,int w){
    while(x<=maxn){
        tr[x]+=w;
        x+=x&(-x);
    }
}
int ans[N],tmp[N],tt;
// Prefix sum over positions 1..x.
int query(int x){
    int total=0;
    while(x){
        total+=tr[x];
        x-=x&(-x);
    }
    return total;
}
// A point (s,c,m) with multiplicity num; val accumulates how many points it
// dominates (filled by cdq).
struct Node{
    int s,c,m,num,val;
    // Lexicographic order on (s,c,m).
    bool operator < (const Node&A)const{
        if(s!=A.s)return s<A.s;
        if(c!=A.c)return c<A.c;
        return m<A.m;
    }
}p1[N],p2[N],t[N];
// CDQ divide and conquer on p2[l..r] (sorted by s). After both halves are
// sorted by c, each right-half point gains the total weight of left-half
// points with c' <= c and m' <= m (Fenwick tree over m); equal c takes the
// left element first. The range is left sorted by c via the buffer t.
void cdq(int l,int r){
if(l==r)return;
int mid=l+r>>1;
cdq(l,mid);cdq(mid+1,r);
int cl=l,cr=mid+1,nl=l;
while(cl<=mid&&cr<=r){
if(p2[cl].c<=p2[cr].c)Add(p2[cl].m,p2[cl].num),t[nl++]=p2[cl++];
else p2[cr].val+=query(p2[cr].m),t[nl++]=p2[cr++];
}
while(cl<=mid)Add(p2[cl].m,p2[cl].num),t[nl++]=p2[cl++];
while(cr<=r)p2[cr].val+=query(p2[cr].m),t[nl++]=p2[cr++];
// Undo this level's Fenwick insertions.
for(int i=l;i<=mid;i++)Add(p2[i].m,-p2[i].num);
for(int i=l;i<=r;i++)p2[i]=t[i];
}
int main(){
int n;
// maxn is the Fenwick range for m (assumes 1 <= m <= maxn).
scanf("%d%d",&n,&maxn);
for(int i=1;i<=n;i++)
scanf("%d%d%d",&p1[i].s,&p1[i].c,&p1[i].m);
int x=0;
sort(p1+1,p1+n+1);
// Collapse identical triples into one weighted point (num = multiplicity).
// A group ends at i==n or where the next triple differs. Testing i==n
// explicitly avoids comparing against p1[n+1], whose zero fill would
// otherwise act as an implicit sentinel.
for(int i=1,s=0;i<=n;i++){
s++;
if(i==n||p1[i].s!=p1[i+1].s||p1[i].c!=p1[i+1].c||p1[i].m!=p1[i+1].m){
p2[++x]=p1[i];
p2[x].num=s;
s=0;
}
}
cdq(1,x);
// val counts points of other groups that this group dominates (<= in all
// three coordinates); num-1 adds its own duplicates.
for(int i=1;i<=x;i++)
ans[p2[i].val+p2[i].num-1]+=p2[i].num;
// ans[d]: number of points dominating exactly d other points.
for(int i=0;i<n;i++)
printf("%d\n",ans[i]);
return 0;
}

树状数组套主席树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
const int maxn=9e6+10;
#define rint register int
// rt[i]: root of the value segment tree kept in Fenwick slot i.
// val/lc/rc: node pool (count, left child, right child); id 0 is the null node.
// stk/tp: free list of recycled node ids. It must be as large as the pool
// (maxn), not N: each assignment can free roughly log2(n)*log2(1e8) nodes,
// and the number of freed-but-unused ids is bounded only by the ids ever
// allocated.
int rt[N],val[maxn],lc[maxn],rc[maxn],a[N],stk[maxn],tp,cnt;
// Allocates a node, reusing a recycled id when one is available.
int New(){
if(tp)return stk[tp--];
return ++cnt;
}
// Adds v to the count of value p in the dynamic segment tree rt over [l,r],
// creating nodes on demand. A child whose count drops to 0 is detached and
// its id is pushed onto the free list stk for reuse by New().
void update(int &rt,int l,int r,int p,int v){
if(!rt)rt=New();
val[rt]+=v;
if(l==r)return;
rint mid=l+r>>1;
if(p<=mid){
update(lc[rt],l,mid,p,v);
if(!val[lc[rt]])stk[++tp]=lc[rt],lc[rt]=0;
}else {
update(rc[rt],mid+1,r,p,v);
if(!val[rc[rt]])stk[++tp]=rc[rt],rc[rt]=0;
}
}
// Cursors into the Fenwick trees: prefix r (q1) minus prefix l-1 (q2).
int q1[N],tt1,q2[N],tt2;
// Returns the k-th smallest value in [l,r] of the multiset
// (sum of trees q1[1..tt1]) - (sum of trees q2[1..tt2]); the cursors are
// moved down the trees in place.
int kth(int l,int r,int k){
if(l==r)return l;
rint now=0;
for(rint i=1;i<=tt1;i++)now+=val[lc[q1[i]]];
for(rint i=1;i<=tt2;i++)now-=val[lc[q2[i]]];
rint mid=l+r>>1;
if(k<=now){
for(rint i=1;i<=tt1;i++)q1[i]=lc[q1[i]];
for(rint i=1;i<=tt2;i++)q2[i]=lc[q2[i]];
return kth(l,mid,k);
}else {
for(rint i=1;i<=tt1;i++)q1[i]=rc[q1[i]];
for(rint i=1;i<=tt2;i++)q2[i]=rc[q2[i]];
return kth(mid+1,r,k-now);
}
}
// Walks toward value k in the multiset (sum of q1 trees) - (sum of q2 trees),
// adding to res the counts of every skipped left part; returns res plus the
// number of stored values < k. q1/q2 end at the leaf for k.
int query(rint l,rint r,rint k,rint res){
if(l==r)return res;
rint mid=l+r>>1;
if(k<=mid){
for(rint i=1;i<=tt1;i++)q1[i]=lc[q1[i]];
for(rint i=1;i<=tt2;i++)q2[i]=lc[q2[i]];
return query(l,mid,k,res);
}else {
for(rint i=1;i<=tt1;i++)res+=val[lc[q1[i]]],q1[i]=rc[q1[i]];
for(rint i=1;i<=tt2;i++)res-=val[lc[q2[i]]],q2[i]=rc[q2[i]];
return query(mid+1,r,k,res);
}
}
int n,m;
// Rank of k in a[l..r]: 1 + number of elements smaller than k.
// Leaves q1/q2 at the leaf for value k (main relies on this for successor).
int getk(rint l,rint r,rint k){
    tt1=0;
    tt2=0;
    for(rint j=r;j!=0;j-=j&-j)q1[++tt1]=rt[j];
    for(rint j=l-1;j!=0;j-=j&-j)q2[++tt2]=rt[j];
    return query(0,1e8,k,1);
}
int main(){
scanf("%d%d",&n,&m);
// Fenwick slot j holds a value segment tree over [0,1e8] of its range.
for(rint i=1;i<=n;i++){
scanf("%d",&a[i]);
for(rint j=i;j<=n;j+=j&-j)update(rt[j],0,1e8,a[i],1);
}
while(m--){
rint opt;
scanf("%d",&opt);
if(opt==1){
// 1: rank of k in a[l..r].
rint l,r,k;
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",getk(l,r,k));
}else if(opt==2){
// 2: k-th smallest in a[l..r].
rint l,r,k;
scanf("%d%d%d",&l,&r,&k);
tt1=tt2=0;
for(rint j=r;j;j-=j&-j)q1[++tt1]=rt[j];
for(rint j=l-1;j;j-=j&-j)q2[++tt2]=rt[j];
printf("%d\n",kth(0,1e8,k));
}else if(opt==3){
// 3: a[pos]=k.
rint pos,k;
scanf("%d%d",&pos,&k);
for(rint j=pos;j<=n;j+=j&-j){
update(rt[j],0,1e8,a[pos],-1);
update(rt[j],0,1e8,k,1);
}
a[pos]=k;
}else if(opt==4){
// 4: predecessor of k (largest element < k), or -2147483647.
rint l,r,k;
scanf("%d%d%d",&l,&r,&k);
k=getk(l,r,k);
tt1=tt2=0;
for(rint j=r;j;j-=j&-j)q1[++tt1]=rt[j];
for(rint j=l-1;j;j-=j&-j)q2[++tt2]=rt[j];
if(k==1)puts("-2147483647");
else printf("%d\n",kth(0,1e8,k-1));
}else {
// 5: successor of k (smallest element > k), or 2147483647.
rint l,r,k;
scanf("%d%d%d",&l,&r,&k);
k=getk(l,r,k);
// getk left q1/q2 at the leaf for the queried value; adding those counts
// turns "1 + #smaller" into "1 + #(<= value)", the successor's rank.
for(rint i=1;i<=tt1;i++)k+=val[q1[i]];
for(rint i=1;i<=tt2;i++)k-=val[q2[i]];
tt1=tt2=0;
for(rint j=r;j;j-=j&-j)q1[++tt1]=rt[j];
for(rint j=l-1;j;j-=j&-j)q2[++tt2]=rt[j];
if(k>r-l+1)puts("2147483647");
else printf("%d\n",kth(0,1e8,k));
}
}
return 0;
}

线段树套平衡树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#include<bits/stdc++.h>
using namespace std;

#define re register
#define rint re int
#define ll long long
#define rll re ll
#define db double
#define rdb re db
#define rch re char
#define un unsigned
#define ull un ll
#define rull re ull
#define lx t[x].l
#define rx t[x].r
#define ly t[y].l
#define ry t[y].r
#define lz t[z].l
#define rz t[z].r

template <typename T> T read()
{
re T ans=0;rint f=1;rch ch=getchar();
for(;!isdigit(ch);ch=getchar())
ch=='-'?f=-1:0;
for(;isdigit(ch);ch=getchar())
ans=ans*10+(ch&15);
return ans*f;
}

const int maxn=5e4+10;
const int maxnlg=1e6+10;

// Treap node: children l/r, stored value val, subtree size, random priority key.
struct Node
{
    int l,r,val,size,key;
}t[maxnlg<<2];
// rt[x]: treap of segment-tree node x (holds the values of x's range).
int rt[maxn<<2];
// a[i]: current array value.
int a[maxn];
// Number of treap nodes allocated (node 0 is the empty treap).
int t_cnt;

// Creates a single-node treap holding val and returns its id.
inline int get_point(rint val)
{
    const int id=++t_cnt;
    t[id].l=0;
    t[id].r=0;
    t[id].val=val;
    t[id].size=1;
    t[id].key=rand();
    return id;
}

// Recomputes the subtree size of treap node x.
inline void push_up(rint x)
{
t[x].size=t[lx].size+t[rx].size+1;
}

// Splits treap z by value: x receives nodes with val <= val, y the rest.
// (lx/rx/ly/ry/lz/rz are macros for t[x].l, t[x].r, ... .)
void spilt(rint z,rint val,rint &x,rint &y)
{
if(!z) return x=y=0,void();
if(t[z].val<=val) x=z,spilt(rz,val,rx,y),push_up(x);
else y=z,spilt(lz,val,x,ly),push_up(y);
}

// Merges treaps x and y, where every value in x is <= every value in y; the
// node with the larger random priority (key) becomes the root.
int merge(rint x,rint y)
{
if(!x||!y) return x|y;
if(t[x].key>t[y].key) return rx=merge(rx,y),push_up(x),x;
else return ly=merge(x,ly),push_up(y),y;
}

// Inserts val into treap rt (duplicates allowed).
void insert(rint &rt,rint val)
{
    rint lo,hi;
    spilt(rt,val,lo,hi);
    const int fresh=get_point(val);
    rt=merge(lo,merge(fresh,hi));
}

// Removes one occurrence of val from treap rt; if val is absent, y is empty
// and rt is rebuilt unchanged.
void del(rint &rt,rint val)
{
rint x,y,z;
spilt(rt,val-1,x,y);
// y holds values >= val; split again so y holds exactly the values == val.
spilt(y,val,y,z);
// Drop y's root (one copy of val) by merging its two children.
rt=merge(x,merge(merge(ly,ry),z));
}

// Builds segment-tree node x over [l,r]: its treap receives a[l..r].
void build(rint x,rint l,rint r)
{
    for(rint i=l;i<=r;i++) insert(rt[x],a[i]);
    if(l<r)
    {
        rint mid=(l+r)>>1;
        build(x<<1,l,mid);
        build(x<<1|1,mid+1,r);
    }
}

// Number of values < val in treap rt (the treap is split and re-merged).
int query(rint &rt,rint val)
{
    rint lo,hi;
    spilt(rt,val-1,lo,hi);
    const int smaller=t[lo].size;
    rt=merge(lo,hi);
    return smaller;
}

// Number of values < val among a[left..right] (node x covers [l,r]).
int query(rint x,rint l,rint r,rint left,rint right,rint val)
{
    if(left<=l&&r<=right) return query(rt[x],val);
    rint mid=(l+r)>>1;
    int total=0;
    if(left<=mid) total+=query(x<<1,l,mid,left,right,val);
    if(mid<right) total+=query(x<<1|1,mid+1,r,left,right,val);
    return total;
}

// Point assignment a[pos]=val: every treap on the root-to-leaf path swaps one
// copy of the old value for val; the leaf then updates a[].
void modify(rint x,rint l,rint r,rint pos,rint val)
{
del(rt[x],a[pos]),insert(rt[x],val);
if(l==r) return a[l]=val,void();
rint mid=l+r>>1;
if(pos<=mid) modify(x<<1,l,mid,pos,val);
else modify(x<<1|1,mid+1,r,pos,val);
}

// kth-smallest value in treap x; returns -2147483647 for kth==0 and
// 2147483647 when kth exceeds the treap size.
int find_kth(rint x,rint kth)
{
    if(kth==0) return -2147483647;
    if(kth>t[x].size) return 2147483647;
    for(;;)
    {
        const int leftSize=t[lx].size;
        if(kth<=leftSize)
        {
            x=lx;
            continue;
        }
        if(kth==leftSize+1) return t[x].val;
        kth-=leftSize+1;
        x=rx;
    }
}

// Largest value < val in treap rt, or -2147483647 if none.
int query_pre(rint &rt,rint val)
{
    rint lo,hi;
    spilt(rt,val-1,lo,hi);
    const int found=find_kth(lo,t[lo].size);
    rt=merge(lo,hi);
    return found;
}

// Smallest value > val in treap rt, or 2147483647 if none.
int query_nxt(rint &rt,rint val)
{
    rint lo,hi;
    spilt(rt,val,lo,hi);
    const int found=find_kth(hi,1);
    rt=merge(lo,hi);
    return found;
}

// Predecessor of val over a[left..right] (node x covers [l,r]).
int query_pre(rint x,rint l,rint r,rint left,rint right,rint val)
{
    if(left<=l&&r<=right) return query_pre(rt[x],val);
    rint mid=(l+r)>>1;
    int best=-2147483647;
    if(left<=mid) best=max(best,query_pre(x<<1,l,mid,left,right,val));
    if(mid<right) best=max(best,query_pre(x<<1|1,mid+1,r,left,right,val));
    return best;
}

// Successor of val over a[left..right].
int query_nxt(rint x,rint l,rint r,rint left,rint right,rint val)
{
    if(left<=l&&r<=right) return query_nxt(rt[x],val);
    rint mid=(l+r)>>1;
    int best=2147483647;
    if(left<=mid) best=min(best,query_nxt(x<<1,l,mid,left,right,val));
    if(mid<right) best=min(best,query_nxt(x<<1|1,mid+1,r,left,right,val));
    return best;
}

int main()
{
rint n=read<int>(),m=read<int>();
for(rint i=1;i<=n;i++) a[i]=read<int>();
build(1,1,n);
for(rint tt=1;tt<=m;tt++)
{
rint opt=read<int>();
if(opt==1)
{
rint l=read<int>(),r=read<int>();
printf("%d\n",query(1,1,n,l,r,read<int>())+1);
}
if(opt==2)
{
rint l=read<int>(),r=read<int>(),K=read<int>();
rint left=0,right=1e8;
while(left<=right)
{
rint mid=left+right>>1;
if(query(1,1,n,l,r,mid)+1<=K) left=mid+1;
else right=mid-1;
}
printf("%d\n",right);
}
if(opt==3)
{
rint pos=read<int>();
modify(1,1,n,pos,read<int>());
}
if(opt==4)
{
rint l=read<int>(),r=read<int>();
printf("%d\n",query_pre(1,1,n,l,r,read<int>()));
}
if(opt==5)
{
rint l=read<int>(),r=read<int>();
printf("%d\n",query_nxt(1,1,n,l,r,read<int>()));
}
}
return 0;
}

虚树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include<bits/stdc++.h>
using namespace std;
const int N=5e5+10;
typedef long long ll;
// Edge storage shared by the input tree and, after a reset, the virtual trees.
struct Edge{
int to,nxt,val;
}e[N<<1];
int h[N],idx;
// Prepends the directed edge a->b with weight c to a's list.
void Ins(int a,int b,int c){
    ++idx;
    e[idx]=Edge{b,h[a],c};
    h[a]=idx;
}
int top[N],dep[N],fa[N],dfn[N],Time,son[N],siz[N];
ll Min[N];
void dfs1(int u){
dfn[u]=++Time;
siz[u]=1;
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u])continue;
fa[v]=u;
dep[v]=dep[u]+1;
Min[v]=min(Min[u],1ll*e[i].val);
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
// HLD pass 2: chain heads; the heavy child continues u's chain.
void dfs2(int u,int t){
    top[u]=t;
    if(son[u]==0)return;
    dfs2(son[u],t);
    for(int j=h[u];j!=0;j=e[j].nxt){
        const int v=e[j].to;
        if(v!=fa[u]&&v!=son[u])dfs2(v,v);
    }
}
// Lowest common ancestor: climb from the endpoint with the deeper chain head.
int lca(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<=dep[top[y]])swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]<dep[y]?x:y;
}
int arr[N],stk[N],tp;
ll f[N];
bool cmp(int a,int b){
return dfn[a]<dfn[b];
}
// Stack-based virtual-tree insertion: vertices arrive in dfn order and stk[1]
// is the root 1. Finished branches are popped and attached to their parent
// with 0-weight edges; the LCA replaces the stack top when it is not on it.
void insert(int x){
int lc=lca(x,stk[tp]);
if(lc==stk[tp]){
stk[++tp]=x;
return;
}
while(dep[lc]<dep[stk[tp-1]]){
Ins(stk[tp-1],stk[tp],0);
tp--;
}
if(lc==stk[tp-1])Ins(stk[tp-1],stk[tp],0),tp--;
else Ins(lc,stk[tp],0),stk[tp]=lc;
stk[++tp]=x;
}
bool vis[N];
void dfs3(int u){
ll s=0;
f[u]=Min[u];
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
dfs3(v);
s+=f[v];
}
if(!vis[u])f[u]=min(s,f[u]);
h[u]=vis[u]=0;
}
int main(){
    int n;
    scanf("%d",&n);
    for(int k=1;k<n;k++){
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        Ins(u,v,w);
        Ins(v,u,w);
    }
    Min[1]=1e18;
    dfs1(1);
    dfs2(1,1);
    int m;
    scanf("%d",&m);
    // The edge storage is reused for the per-query virtual trees.
    memset(h,0,sizeof(h));
    idx=0;
    while(m--){
        int k;
        scanf("%d",&k);
        for(int i=1;i<=k;i++){
            scanf("%d",&arr[i]);
            vis[arr[i]]=1;
        }
        sort(arr+1,arr+k+1,cmp);
        stk[++tp]=1;
        int start=1;
        if(arr[1]==1)start=2;
        for(int i=start;i<=k;i++)insert(arr[i]);
        for(;tp>1;tp--)Ins(stk[tp-1],stk[tp],0);
        dfs3(1);
        printf("%lld\n",f[1]);
        tp=0;
        idx=0;
    }
    return 0;
}

圆方树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<bits/stdc++.h>
using namespace std;
const int N=2e6+10;
#define rint register int
// e/h/idx: input graph; e2/h2/idx2: the round-square (block-cut) tree.
struct Edge{
int to,nxt;
}e[N<<1],e2[N<<1];
int h[N],idx;
// Prepends a->b to the input graph.
void Ins(rint a,rint b){
    ++idx;
    e[idx]=Edge{b,h[a]};
    h[a]=idx;
}
int h2[N],idx2;
// Prepends a->b to the round-square tree.
void Ins2(rint a,rint b){
    ++idx2;
    e2[idx2]=Edge{b,h2[a]};
    h2[a]=idx2;
}
// fang: id of the last square vertex created (main starts it at n).
int stk[N],tp,low[N],dfn[N],Time,fang;
// Tarjan for vertex-biconnected components: when low[v] >= dfn[u], the
// vertices popped down to v plus u form one block, which gets a new square
// vertex joined to all of them in the e2 tree.
void tarjan(rint u,rint fa){
stk[++tp]=u;
low[u]=dfn[u]=++Time;
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa)continue;
if(!dfn[v]){
tarjan(v,u);
low[u]=min(low[u],low[v]);
if(low[v]>=dfn[u]){
++fang;
while(1){
rint x=stk[tp--];
Ins2(x,fang);Ins2(fang,x);
if(x==v)break;
}
// u stays on the stack: it may belong to further blocks.
Ins2(fang,u);Ins2(u,fang);
}
}else low[u]=min(low[u],dfn[v]);
}
}
// HLD over the round-square tree. With main's initial dis (1 for round
// vertices, 0 for square ones) dfs1 turns dis into root prefix sums.
int top[N],fa[N],son[N],siz[N],d[N],dis[N];
void dfs1(rint u){
    siz[u]=1;
    for(rint j=h2[u];j!=0;j=e2[j].nxt){
        const int w=e2[j].to;
        if(w==fa[u])continue;
        d[w]=d[u]+1;
        fa[w]=u;
        dis[w]=dis[w]+dis[u];
        dfs1(w);
        siz[u]+=siz[w];
        if(siz[son[u]]<siz[w])son[u]=w;
    }
}
// HLD pass 2: chain heads.
void dfs2(rint u,rint t){
    top[u]=t;
    if(son[u]==0)return;
    dfs2(son[u],t);
    for(rint j=h2[u];j!=0;j=e2[j].nxt){
        const int w=e2[j].to;
        if(w!=fa[u]&&w!=son[u])dfs2(w,w);
    }
}
// Lowest common ancestor in the round-square tree.
int lca(rint x,rint y){
    while(top[x]!=top[y]){
        if(d[top[x]]<=d[top[y]])swap(x,y);
        x=fa[top[x]];
    }
    return d[x]<d[y]?x:y;
}
int main(){
rint n,m;
scanf("%d%d",&n,&m);fang=n;
for(rint i=1;i<=m;i++){
rint a,b;
scanf("%d%d",&a,&b);
Ins(a,b);Ins(b,a);
}
tarjan(1,0);
for(rint i=1;i<=n;i++)dis[i]=1;
dfs1(1);
dfs2(1,1);
rint q;
scanf("%d",&q);
while(q--){
rint a,b;
scanf("%d%d",&a,&b);
rint lc=lca(a,b);
printf("%d\n",dis[a]+dis[b]-dis[lc]-dis[fa[lc]]);
}
return 0;
}

长链剖分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
#define rint register int
#define rll register long long
// Tree stored as two directed edges per undirected edge.
struct Edge{
int to,nxt;
}e[N<<1];
int h[N],idx;
// Prepends the directed edge a->b to a's list.
void Ins(rint a,rint b){
    ++idx;
    e[idx].nxt=h[a];
    e[idx].to=b;
    h[a]=idx;
}
// buf: shared storage; f[u] points into it (f[u][j] = vertices at depth j
// below u, filled by dfs2); p: next free cell of buf.
int buf[N<<2],*p,*f[N];
// len[u]: vertex count of the longest downward path from u; son[u]: the
// child that realises it.
int len[N],son[N];
void dfs1(rint u,rint fa){
    len[u]=1;
    for(rint j=h[u];j!=0;j=e[j].nxt){
        const int v=e[j].to;
        if(v==fa)continue;
        dfs1(v,u);
        if(len[v]+1>len[u])len[u]=len[v]+1;
        if(len[son[u]]<len[v])son[u]=v;
    }
}
// ans[u]: smallest depth j maximising f[u][j].
int ans[N];
// Long-path DP: f[u][j] counts the vertices at depth j in u's subtree. The
// long child shares u's array shifted by one (f[son]=f[u]+1) and so is
// inherited for free; light children get fresh slices of buf and are merged.
void dfs2(rint u,rint fa){
if(son[u]){
f[son[u]]=f[u]+1;
dfs2(son[u],u);
}else {
f[u][0]=1;
return;
}
f[u][0]=1;
// Inherited best depth shifts by one, unless its count is 1: then depth 0
// (u itself, count 1) ties and wins as the smaller index.
if(f[son[u]][ans[son[u]]]==1)ans[u]=0;
else ans[u]=ans[son[u]]+1;
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(v==fa||v==son[u])continue;
f[v]=p;p+=len[v];
dfs2(v,u);
for(rint j=0;j<len[v];j++){
f[u][j+1]+=f[v][j];
if(f[u][j+1]>f[u][ans[u]]||(f[u][j+1]==f[u][ans[u]]&&j+1<ans[u]))ans[u]=j+1;
}
}
}
int main(){
rint n;
scanf("%d",&n);
for(rint i=2;i<=n;i++){
rint a,b;
scanf("%d%d",&a,&b);
Ins(a,b);Ins(b,a);
}
dfs1(1,0);p=buf;
f[1]=p;p+=len[1];
dfs2(1,0);
for(rint i=1;i<=n;i++)
printf("%d\n",ans[i]);
return 0;
}

KD-Tree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=(5e5+10)*32;
// Balance factor: a subtree is rebuilt when one child holds >= alh of its nodes.
const double alh=0.75;
// Axis compared by Node::operator< (set before each nth_element call).
int typ;
// A 2-D point; a[] is the scratch buffer used when rebuilding a subtree.
struct Node{
int x[2];
bool operator < (const Node&A)const{
return x[typ]<A.x[typ];
}
}a[N];
// KD-tree node: children, subtree size, bounding box, stored point.
struct KdT{
int ls,rs,siz,Max[2],Min[2];
Node v;
}T[N];
// Child shorthands; they expand using a variable named rt in scope.
#define ls T[rt].ls
#define rs T[rt].rs
// cnt: KD nodes allocated; krt[s]: KD-tree root of segment-tree node s;
// lc/rc: segment-tree children.
int cnt,krt[N],lc[N],rc[N];
void up(int rt){
for(int i=0;i<2;i++){
T[rt].Max[i]=T[rt].Min[i]=T[rt].v.x[i];
if(ls){
T[rt].Max[i]=max(T[rt].Max[i],T[ls].Max[i]);
T[rt].Min[i]=min(T[rt].Min[i],T[ls].Min[i]);
}
if(rs){
T[rt].Max[i]=max(T[rt].Max[i],T[rs].Max[i]);
T[rt].Min[i]=min(T[rt].Min[i],T[rs].Min[i]);
}
}
T[rt].siz=T[ls].siz+T[rs].siz+1;
}
int stk[N],tt;
void get(int rt){
stk[++tt]=rt;
a[tt]=T[rt].v;
if(ls)get(ls);
if(rs)get(rs);
}
// Rebuilds a balanced KD-tree from a[l..r]: the median on axis d
// (nth_element) becomes the node, axes alternate by level, and node ids are
// popped from stk.
void Build(int &rt,int l,int r,int d){
rt=0;
if(l>r)return;
rt=stk[tt--];
int mid=l+r>>1;typ=d;
nth_element(a+l,a+mid,a+r+1);
T[rt].v=a[mid];
Build(ls,l,mid-1,d^1);
Build(rs,mid+1,r,d^1);
up(rt);
}
// Inserts point v (comparing on axis d at this level). An unbalanced rt is
// first flattened and rebuilt in place (scapegoat-style); the balance check
// uses the sizes before v is added.
void insert(int &rt,Node v,int d){
if(!rt){
rt=++cnt;
T[rt].v=v;
up(rt);
return;
}
if(max(T[ls].siz,T[rs].siz)>=T[rt].siz*alh){
get(rt);
Build(rt,1,tt,d);
}
if(v.x[d]<=T[rt].v.x[d])insert(ls,v,d^1);
else insert(rs,v,d^1);
up(rt);
}
int dfs(int rt,Node v1,Node v2){
if(!rt)return 0;
if(T[rt].Max[0]<v1.x[0])return 0;
if(T[rt].Max[1]<v1.x[1])return 0;
if(T[rt].Min[0]>v2.x[0])return 0;
if(T[rt].Min[1]>v2.x[1])return 0;
if(v1.x[0]<=T[rt].Min[0]&&T[rt].Max[0]<=v2.x[0]
&& v1.x[1]<=T[rt].Min[1]&&T[rt].Max[1]<=v2.x[1])
return T[rt].siz;
int res=0;
if(v1.x[0]<=T[rt].v.x[0]&&T[rt].v.x[0]<=v2.x[0]
&& v1.x[1]<=T[rt].v.x[1]&&T[rt].v.x[1]<=v2.x[1])
res++;
if(ls)res+=dfs(ls,v1,v2);
if(rs)res+=dfs(rs,v1,v2);
return res;
}
int tot;
// Value segment tree over [l,r]: point v (with value pos) goes into the
// KD-tree of every node on the path to leaf pos; nodes are created on demand.
void update(int &rt,int l,int r,int pos,Node v){
    int *cur=&rt;
    for(;;){
        if(!*cur)*cur=++tot;
        insert(krt[*cur],v,1);
        if(l==r)return;
        const int mid=(l+r)>>1;
        if(pos<=mid){
            cur=&lc[*cur];
            r=mid;
        }else{
            cur=&rc[*cur];
            l=mid+1;
        }
    }
}
// k-th largest value among points inside the rectangle [v1,v2]: go right
// while the right child holds at least k such points.
int query(int rt,int l,int r,int k,Node v1,Node v2){
    while(l!=r){
        const int inRight=dfs(krt[rc[rt]],v1,v2);
        const int mid=(l+r)>>1;
        if(k<=inRight){
            rt=rc[rt];
            l=mid+1;
        }else{
            k-=inRight;
            rt=lc[rt];
            r=mid;
        }
    }
    return l;
}
int main(){
int n,q,rt=0,res=0;
scanf("%d%d",&n,&q);
// Online input: each number is XORed with the previous answer res.
while(q--){
int opt;
scanf("%d",&opt);
if(opt==1){
// Insert point v carrying value val.
Node v;int val;
scanf("%d%d%d",&v.x[0],&v.x[1],&val);
v.x[0]^=res;v.x[1]^=res;val^=res;
update(rt,1,1e9,val,v);
}else{
// k-th largest value among points in [v1,v2]; res stays 0 when fewer than
// k points lie in the rectangle.
Node v1,v2;int k;
scanf("%d%d%d%d%d",&v1.x[0],&v1.x[1],&v2.x[0],&v2.x[1],&k);
v1.x[0]^=res;v1.x[1]^=res;v2.x[0]^=res;v2.x[1]^=res;k^=res;
res=0;
if(dfs(krt[rt],v1,v2)<k)puts("NAIVE!ORZzyz.");
else printf("%d\n",res=query(rt,1,1e9,k,v1,v2));
}
}
return 0;
}

LCT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+10;
const int p=51061;
typedef long long ll;
// Splay state: children, parent (or path-parent), pending reverse flag, size.
int ch[N][2],fa[N],tag[N],siz[N];
// sum: subtree sum mod p; taga/tagm: pending add/multiply tags; val: value.
ll sum[N],taga[N],tagm[N],val[N];
// Recomputes sum and size of x from its children.
void up(int x){
    const int l=ch[x][0],r=ch[x][1];
    siz[x]=siz[l]+siz[r]+1;
    sum[x]=(val[x]+sum[l]+sum[r])%p;
}
// Multiplies every value in x's subtree by v (mod p) and records the tag.
void upd1(int x,ll v){
    if(x==0)return;
    tagm[x]=tagm[x]*v%p;
    taga[x]=taga[x]*v%p;
    sum[x]=sum[x]*v%p;
    val[x]=val[x]*v%p;
}
// Adds v to every value in x's subtree (mod p) and records the tag.
void upd2(int x,ll v){
    if(x==0)return;
    val[x]=(val[x]+v)%p;
    sum[x]=(sum[x]+siz[x]*v%p)%p;
    taga[x]=(taga[x]+v)%p;
}
// Pushes x's pending tags to its children: a reversal swaps x's children and
// flips theirs, then the multiply tag (tagm) and add tag (taga) are applied,
// in that order.
void down(int x){
if(tag[x]){
tag[ch[x][0]]^=1;tag[ch[x][1]]^=1;
swap(ch[x][0],ch[x][1]);
tag[x]=0;
}
if(tagm[x]!=1){
upd1(ch[x][0],tagm[x]);
upd1(ch[x][1],tagm[x]);
tagm[x]=1;
}
if(taga[x]){
upd2(ch[x][0],taga[x]);
upd2(ch[x][1],taga[x]);
taga[x]=0;
}
}
// True if x is not the root of its splay tree (it is a real child of fa[x]).
bool nrt(int x){return ch[fa[x]][1]==x||ch[fa[x]][0]==x;}
// Rotates x above its parent y; z's child slot is rewritten only when y was
// a real child of z (otherwise fa is a path-parent pointer).
void rotate(int x){
int y=fa[x];int z=fa[y];
int k=ch[y][1]==x;
if(nrt(y))ch[z][ch[z][1]==y]=x;
fa[x]=z;ch[y][k]=ch[x][!k];
fa[ch[x][!k]]=y;ch[x][!k]=y;
fa[y]=x;up(y);up(x);
}
// Pushes pending tags down along the splay path from its root to x.
void update(int x){
    if(!nrt(x)){
        down(x);
        return;
    }
    update(fa[x]);
    down(x);
}
// Splays x to the root of its splay tree after pushing tags down; when x and
// its parent are on the same side the parent rotates first (zig-zig),
// otherwise x rotates twice (zig-zag).
void splay(int x){
update(x);
for(int y=fa[x],z=fa[y];nrt(x);rotate(x),y=fa[x],z=fa[y])
if(nrt(y))(ch[y][1]==x)==(ch[z][1]==y)?rotate(y):rotate(x);
}
// Makes the root-to-x path preferred: x's splay tree then holds exactly
// that path (x's previous preferred child is detached).
void access(int x){
    int below=0;
    while(x){
        splay(x);
        ch[x][1]=below;
        up(x);
        below=x;
        x=fa[x];
    }
}
// Re-roots x's tree at x (reverses the root-to-x path).
void mkrt(int x){
    access(x);
    splay(x);
    tag[x]^=1;
}
// Exposes the x-y path as the splay tree rooted at y.
void split(int x,int y){
    mkrt(x);
    access(y);
    splay(y);
}
// Adds the edge x-y (x and y assumed to be in different trees).
void link(int x,int y){
    mkrt(x);
    fa[x]=y;
}
// Removes the edge x-y if present: after split(x,y) they are adjacent
// exactly when x is y's left child and has no right child.
void cut(int x,int y){
split(x,y);
if(ch[y][0]==x&&ch[x][1]==0)
ch[y][0]=0,fa[x]=0,up(y);
}
int main(){
    int n,m;
    scanf("%d%d",&n,&m);
    // Every vertex starts with value 1; tagm=1 is the identity multiply tag.
    for(int i=1;i<=n;i++){
        tagm[i]=1;
        val[i]=1;
        siz[i]=1;
    }
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        link(x,y);
    }
    while(m--){
        char str[10];int x,y;
        scanf("%s%d%d",str,&x,&y);
        switch(str[0]){
        case '+':{   // add v on path x-y
            int v;
            scanf("%d",&v);
            split(x,y);
            upd2(y,v);
            break;
        }
        case '-':    // cut x-y, then link the next pair
            cut(x,y);
            scanf("%d%d",&x,&y);
            link(x,y);
            break;
        case '*':{   // multiply path x-y by v
            int v;
            scanf("%d",&v);
            split(x,y);
            upd1(y,v);
            break;
        }
        default:     // path sum x-y (mod p)
            split(x,y);
            printf("%lld\n",sum[y]);
        }
    }
    return 0;
}

左偏树(可并堆)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
// dis: leftist distance; val: key (-1 once popped); lc/rc: children;
// f: union-find parent used to locate a node's heap root.
int dis[N],val[N],lc[N],rc[N],f[N];
// Returns the representative (heap root) of x, compressing the path.
int find(int x){
    if(f[x]!=x)f[x]=find(f[x]);
    return f[x];
}
// Melds leftist min-heaps lt and rt (equal keys: smaller index on top) and
// returns the new root. Keeps dis[left] >= dis[right] and points f of the
// root and its children at the root (this also writes f[0] when a child is
// empty).
int merge(int lt,int rt){
if(!lt||!rt)return lt+rt;
if(val[lt]>val[rt]||(val[lt]==val[rt]&&lt>rt))swap(lt,rt);
rc[lt]=merge(rc[lt],rt);
if(dis[lc[lt]]<dis[rc[lt]])swap(lc[lt],rc[lt]);
dis[lt]=dis[rc[lt]]+1;
f[lc[lt]]=f[rc[lt]]=f[lt]=lt;
return lt;
}
void pop(int x){
x=find(x);
printf("%d\n",val[x]);
f[lc[x]]=lc[x];
f[rc[x]]=rc[x];
f[x]=merge(lc[x],rc[x]);
val[x]=-1;
}
int main(){
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%d",&val[i]);
        f[i]=i;
    }
    for(int q=0;q<m;q++){
        int op,x;
        scanf("%d%d",&op,&x);
        if(op!=1){
            // Pop the minimum of x's heap, or print -1 if x was already removed.
            if(val[x]==-1)printf("-1\n");
            else pop(x);
            continue;
        }
        // Meld the heaps of x and y unless they coincide or either was removed.
        int y;
        scanf("%d",&y);
        if(find(x)==find(y)||val[x]==-1||val[y]==-1)continue;
        f[find(x)]=f[find(y)]=merge(find(x),find(y));
    }
    return 0;
}

猫树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
const int MAXN=(1<<19)+7;

// Log2 is indexed by segment-tree heap positions and filled up to len<<1
// (len can reach 1<<19), so it needs about twice MAXN entries; the other
// arrays are indexed by array positions <= len.
static int n,m,a[MAXN],len,Log2[MAXN<<1],pos[MAXN];

namespace Meow_Tree
{
// Cat tree over a[1..len]. For the segment-tree node at depth dps (root = 1)
// with split point mid: p[dps][i] is the maximum subarray sum between i and
// the split, and s[dps][i] is the maximum sum of a run that starts at the
// split (mid going left, mid+1 going right) and stays within i.
static int p[21][MAXN],s[21][MAXN];
// Fills both tables for node h covering [l,r] and records in pos[] the heap
// index of each leaf. Repe/Rep are loop macros defined outside this view
// (presumably inclusive descending / ascending ranges).
void build_tree(int h,int l,int r,int dps)
{
if(l==r){pos[l]=h;return;}
int mid=(l+r)>>1,prep,sm;
// Left half, scanning from mid down to l (Kadane for p, running sum for s).
p[dps][mid]=s[dps][mid]=sm=a[mid];
prep=a[mid];
if(sm<0)sm=0;
Repe(i,mid-1,l)
{
prep+=a[i];sm+=a[i];
s[dps][i]=max(s[dps][i+1],prep);
p[dps][i]=max(p[dps][i+1],sm);
if(sm<0)sm=0;
}

// Right half, scanning from mid+1 up to r.
p[dps][mid+1]=s[dps][mid+1]=sm=a[mid+1];
prep=a[mid+1];
if(sm<0)sm=0;
Rep(i,mid+2,r)
{
prep+=a[i];sm+=a[i];
s[dps][i]=max(s[dps][i-1],prep);
p[dps][i]=max(p[dps][i-1],sm);
if(sm<0)sm=0;
}
build_tree(h<<1,l,mid,dps+1);
build_tree(h<<1|1,mid+1,r,dps+1);
}

// Maximum subarray sum of a[l..r] in O(1). For l<r, the node whose split
// separates l and r is the common prefix of the leaves' heap indices, at
// table depth Log2[pos[l]]-Log2[pos[l]^pos[r]]. The answer is the best
// within either half or a left run plus a right run across the split.
inline int query(int l,int r)
{
if(l==r)return a[l];
static int dps;dps=Log2[pos[l]]-Log2[pos[l]^pos[r]];
return max(max(p[dps][l],p[dps][r]),s[dps][l]+s[dps][r]);
}
}

using namespace Meow_Tree;

// Reads the array, pads the length to a power of two len (at least 2), fills
// Log2 for indices 2..len<<1 (leaf heap indices reach 2*len-1), and builds
// the tables. Log2 therefore needs more than 2*len entries.
// NOTE(review): read/Rep are defined outside this view.
inline void init()
{
read(n);
Rep(i,1,n)read(a[i]);
len=2;while(len<n)len<<=1;
Rep(i,2,len<<1)Log2[i]=Log2[i>>1]+1;
build_tree(1,1,len,1);
}

支配树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
namespace dominator{

// Lengauer-Tarjan dominator tree of the vertices reachable from a root S.
const int N=2e5+10,M=3e5+10;
struct Edge{
int to,nxt;
}e[3][M];
int h[3][N],idx[3];
// 0: original graph
// 1: reversed graph
// 2: semi-dominator buckets (filled and cleared inside work)
void Ins(rint t,rint x,rint y){
e[t][++idx[t]].to=y;e[t][idx[t]].nxt=h[t][x];h[t][x]=idx[t];
}
int dfn[N],rk[N],Time,fa[N],f[N],mn[N],sdom[N],idom[N],ans[N];
// sdom: semi-dominator
// idom: immediate dominator
// dfn/rk: DFS order and its inverse; fa: DFS-tree parent;
// f/mn: link-eval forest (parent / vertex of minimal dfn[sdom] on the path);
// ans[u]: number of vertices dominated by u (its dominator-subtree size).
// Recursive DFS from u numbering the reachable vertices.
void dfs(int u){
f[u]=mn[u]=sdom[u]=u;
dfn[u]=++Time,rk[Time]=u;
for(rint i=h[0][u];i;i=e[0][i].nxt){
rint v=e[0][i].to;
if(!dfn[v])
fa[v]=u,dfs(v);
}
}
// Path-compressing eval: returns the forest root above u and updates mn[u]
// to the vertex with minimal dfn[sdom] on the compressed path.
int find(rint u){
if(u==f[u])return u;
int ret=find(f[u]);
if(dfn[sdom[mn[f[u]]]]<dfn[sdom[mn[u]]])
mn[u]=mn[f[u]];
return f[u]=ret;
}
void work(rint S){
dfs(S);
// Reverse DFS order, skipping the root rk[1].
for(rint i=Time;i>1;i--){
rint u=rk[i];
for(rint j=h[1][u];j;j=e[1][j].nxt){
rint v=e[1][j].to;
if(!dfn[v])continue;
// v may be unreachable from S
find(v);
if(dfn[sdom[mn[v]]]<dfn[sdom[u]])
sdom[u]=sdom[mn[v]];
}
f[u]=fa[u];
Ins(2,sdom[u],u);
u=fa[u];
// Resolve the parent's bucket: idom is the parent itself or is deferred
// to mn[v] and fixed in the pass below.
for(rint j=h[2][u];j;j=e[2][j].nxt){
rint v=e[2][j].to;
find(v);
idom[v]=(u==sdom[mn[v]]?u:mn[v]);
}
h[2][u]=0;
}
// Increasing DFS order: resolve deferred idoms.
for(int i=2;i<=Time;i++){
int u=rk[i];
if(idom[u]!=sdom[u])idom[u]=idom[idom[u]];
}
// Decreasing DFS order: accumulate dominator-subtree sizes.
for(int i=Time;i;i--){
int u=rk[i];ans[u]++;
ans[idom[u]]+=ans[u];
}
}

}
int main(){
rint n,m;
scanf("%d%d",&n,&m);
for(rint i=1;i<=m;i++){
rint x,y;
scanf("%d%d",&x,&y);
dominator::Ins(0,x,y);
dominator::Ins(1,y,x);
}
dominator::work(1);
for(int i=1;i<=n;i++)
printf("%d ",dominator::ans[i]);
return 0;
}

李超树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include<map>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+10;
typedef double db;
// `register` is ill-formed since C++17, so the legacy macro now expands to plain int.
#define rint int
// Li Chao segment tree over integer x. Segment i is stored as the line
// y = k[i]*x + b[i]; tr[rt] is the id (0 = none) of the line kept at node rt.
int tr[N<<2],cnt;
db k[N],b[N];

// Registers the segment with endpoints (x,y) and (xx,yy) as line id ++cnt.
// A vertical segment (x==xx) becomes the constant line at its higher endpoint.
void New(int x,int y,int xx,int yy){
    ++cnt;
    if(x==xx){
        k[cnt]=0;
        b[cnt]=max(y,yy);
    }else{
        k[cnt]=1.0*(yy-y)/(xx-x);
        b[cnt]=yy-k[cnt]*xx;
    }
}
#define ls rt<<1
#define rs rt<<1|1
// Inserts line v on [cl,cr]; rt covers [l,r]. On a fully covered node, the line that is
// higher at mid stays and the other is pushed into the one child where it can still win.
// On ties the line already stored (the smaller id) is kept.
void update(int rt,int l,int r,int cl,int cr,int v){
    int mid=(l+r)>>1;
    if(cl<=l&&r<=cr){
        int id=tr[rt];
        if(!id){
            tr[rt]=v;
            return;
        }
        db v1=mid*k[id]+b[id],v2=mid*k[v]+b[v];
        if(l==r){
            if(v2>v1)tr[rt]=v;
        }else if(k[id]==k[v]){
            if(v2>v1)tr[rt]=v;
        }else if(k[id]<k[v]){
            // New line is steeper: it dominates to the right of any crossing.
            if(v2>v1){
                tr[rt]=v;
                update(ls,l,mid,cl,cr,id);
            }else update(rs,mid+1,r,cl,cr,v);
        }else{
            // New line is flatter: it dominates to the left of any crossing.
            if(v2>v1){
                tr[rt]=v;
                update(rs,mid+1,r,cl,cr,id);
            }else update(ls,l,mid,cl,cr,v);
        }
        return;
    }
    if(cl<=mid)update(ls,l,mid,cl,cr,v);
    if(cr>mid)update(rs,mid+1,r,cl,cr,v);
}
int res1;
db res2;
// Point query at x=p. Walks root to leaf and keeps the id (res1) and value (res2) of the
// highest line seen; on exact ties the smaller id wins. Reset res1=0 before calling.
// NOTE(review): ties use exact double equality; an epsilon may be needed for some inputs.
void query(int rt,int l,int r,int p){
    int v=tr[rt];
    if(v){
        db now=k[v]*p+b[v];
        if(!res1)res1=v,res2=now;
        else if(now>res2)res1=v,res2=now;
        else if(now==res2&&res1>v)res1=v;
    }
    if(l==r)return;
    int mid=(l+r)>>1;
    if(p<=mid)query(ls,l,mid,p);
    else query(rs,mid+1,r,p);
}
const int Max=39989;
// Online Li Chao driver. op 0 asks for the id of the highest segment at x; any other op
// inserts a segment. Inputs are decoded with the previous answer `lst`.
// Uses plain int: `rint` expands to `register`, which is ill-formed since C++17.
int main(){
    int n;
    scanf("%d",&n);
    int lst=0;
    while(n--){
        int opt;
        scanf("%d",&opt);
        if(opt==0){
            int x;
            scanf("%d",&x);
            x=(x+lst-1)%Max+1;
            res1=0;
            query(1,1,Max,x);
            printf("%d\n",lst=res1);
        }else{
            int x,y,xx,yy;
            scanf("%d%d%d%d",&x,&y,&xx,&yy);
            x=(x+lst-1)%Max+1,xx=(xx+lst-1)%Max+1;
            y=(y+lst-1)%1000000000+1;
            yy=(yy+lst-1)%1000000000+1;
            New(x,y,xx,yy);
            update(1,1,Max,min(x,xx),max(x,xx),cnt);
        }
    }
    return 0;
}

划分树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
using lld = long long;

// Reads a non-negative decimal integer from stdin (this problem has no negative input).
// Characters below '!' (blanks, control characters) are skipped first; any character
// not above '!' terminates the number.
int getint()
{
    char c = getchar();
    while (c < '!') c = getchar();
    int value = c ^ '0';
    for (c = getchar(); c > '!'; c = getchar()) value = (value * 10) + (c ^ '0');
    return value;
}

// Writes x to stdout in decimal with no separator (intended for non-negative values).
void putll(lld x) {
    const lld rest = x / 10;
    if (rest != 0) putll(rest);
    putchar(static_cast<int>((x % 10) ^ '0'));
}

// Per-level counter arrays: ti[0][k-1] and ti[1][k-1] for levels k = 1..17, indexed by position.
// The "up" walks step x += lowbit(x) while lowbit(x) < 2^k, so they stop on the first
// index whose lowbit reaches 2^k.
// The "down" walks step x ^= lowbit(x) while the next index stays in the same aligned
// block of 2^k positions (same (x-1) >> k).
// Updates and reads use opposite tables:
//   fixup writes ti[0], read by querydown;
//   fixdown writes ti[1], read by queryup.
int ti[2][17][131073];

// Lowest set bit of x.
int lowbit(int x) { return x & (-x); }

// Increments ti[0][k-1] at x and along its upward chain (the final index included).
void fixup(int k, int x) {
for (; lowbit(x) < 1 << k; x += lowbit(x)) {
++ti[0][k - 1][x];
}
++ti[0][k - 1][x];
}

// Increments ti[1][k-1] at x and along its downward chain, staying inside x's 2^k block.
void fixdown(int k, int x) {
for (; ((x ^ lowbit(x)) - 1) >> k == (x - 1) >> k; x ^= lowbit(x)) {
++ti[1][k - 1][x];
}
++ti[1][k - 1][x];
}

// Sums ti[1][k-1] along the upward chain from x (same stopping rule as fixup).
int queryup(int k, int x) {
int res = 0;
for (; lowbit(x) < 1 << k; x += lowbit(x)) {
res += ti[1][k - 1][x];
}
return res + ti[1][k - 1][x];
}

// Sums ti[0][k-1] along the downward chain from x (same stopping rule as fixdown).
int querydown(int k, int x) {
int res = 0;
for (; ((x ^ lowbit(x)) - 1) >> k == (x - 1) >> k; x ^= lowbit(x)) {
res += ti[0][k - 1][x];
}
return res + ti[0][k - 1][x];
}

// ai: working array, re-sorted level by level in main; tx: snapshot of ai that mgx reads from.
int ai[100005];
int tx[100005];

// Merges the sorted runs tx[al..ar) and tx[bl..br) into nai[] and records per output slot:
//   npi: source index of the element;
//   rxi: for left-run elements, 1 + number of right-run elements emitted before it;
//   lxi: for right-run elements, number of left-run elements not yet emitted.
// Each slot writes only one of rxi/lxi; the other keeps its previous value
// (main relies on rxi being 0 for right-run slots).
// On ties the right element is emitted first.
// Returns the number of pairs (left element >= right element).
lld mgx(int* npi, int* nai, int* rxi, int* lxi, int al, int ar, int bl,
int br) {
int rx = 1;
int lx = ar - al;
lld res = 0;
for (; al != ar || bl != br; ++npi, ++nai, ++rxi, ++lxi) {
if (al != ar && (bl == br || tx[al] < tx[bl])) {
--lx;
*rxi = rx;
*nai = tx[al];
*npi = al;
++al;
} else {
++rx;
*lxi = lx;
// every remaining left element is >= tx[bl]
res += ar - al;
*nai = tx[bl];
*npi = bl;
++bl;
}
}
return res;
}

// Per-level merge records, level i stored at offset n*i.
// Sized for about n*(logn+1) <= 1.7e6 entries, e.g. n = 1e5.
int npi[1700005];
int nri[1700005];
int nli[1700005];

// Reads n values and m. Prints the initial pair count, then one updated count after
// each of the next m-1 inputs x (m lines in total).
// NOTE(review): despite the section title, this appears to maintain an inversion count
// under deletions. x seems to index the fully merged (sorted) order, i.e. the value when
// ai is a permutation of 1..n — confirm.
int main() {
const int n = getint();
const int m = getint();
for (int i = 1; i <= n; ++i) {
ai[i] = getint();
}

// A single element has no pairs: all m answers are 0.
if (n == 1) {
for (int i = 1; i <= m; ++i) {
putchar('0');
putchar('\n');
}
return 0;
}

// floor(log2(n-1)) via a GCC builtin.
const int logn = 31 - __builtin_clz(n - 1);

// Bottom-up merge sort: level i merges runs of 2^(logn-i) into runs of 2^(logn-i+1),
// summing the cross-run pair counts returned by mgx.
lld ans = 0;
for (int i = logn; i >= 0; --i) {
memcpy(tx, ai, (n + 1) * sizeof(int));
for (int j = 1; j <= n; j += 1 << (logn - i + 1)) {
ans += mgx(npi + n * i + j, ai + j, nri + n * i + j, nli + n * i + j, j,
min(n + 1, j + (1 << (logn - i))),
min(n + 1, j + (1 << (logn - i))),
min(n + 1, j + (1 << (logn - i + 1))));
}
}

putll(ans);
putchar('\n');

// For each update, walk the merge levels from the top (i = 0) down. At each level,
// subtract the pairs x still forms across its merge and mark x in that level's
// counters; npi maps x to its slot on the next level.
for (int asdf = 1; asdf < m; ++asdf) {
int x = getint();
for (int i = 0, p = 0; i <= logn; ++i, p += n) {
if (nri[p + x]) {
ans -= nri[p + x] - querydown(logn - i + 1, x) - 1;
fixdown(logn - i + 1, x);
} else {
ans -= nli[p + x] - queryup(logn - i + 1, x);
fixup(logn - i + 1, x);
}
x = npi[p + x];
}
putll(ans);
putchar('\n');
}
}

笛卡尔树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define rint register int
const int N=1e5+10;
// Reads a possibly negative decimal integer from stdin, skipping non-digit characters.
// A '-' seen while skipping makes the result negative.
// Returns 0 if input ends before any digit (previously this looped forever at EOF).
// Uses plain int instead of `rint` (`register` is ill-formed since C++17).
inline int read(){
    int x=0,f=1;
    int ch=getchar();   // int, so EOF stays distinct from every byte value
    while(ch<'0'||ch>'9'){
        if(ch==EOF)return 0;
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*f;
}
int stk[N],tp,a[N],ch[N][2];
// Builds the Cartesian tree of a[1..n] with a monotonic stack: max-heap on values,
// BST on indices, and among equal values the later index ends up higher.
// Each range-max query [l,r] walks down from the root; the first node whose index lies
// in [l,r] holds the maximum.
// Uses plain int: `rint` expands to `register`, which is ill-formed since C++17.
int main(){
    int n=read(),m=read();
    for(int i=1;i<=n;i++){
        a[i]=read();
        // Pop every index whose value is <= a[i]; the last one popped becomes i's left child.
        while(tp&&a[stk[tp]]<=a[i])ch[i][0]=stk[tp--];
        ch[stk[tp]][1]=i;   // with an empty stack this writes ch[0][1], which the walk never reads for valid l,r
        stk[++tp]=i;
    }
    int rt=stk[1];
    while(m--){
        int l=read(),r=read();
        for(int x=rt;;x=x>r?ch[x][0]:ch[x][1])
            if(l<=x&&x<=r){
                printf("%d\n",a[x]);
                break;
            }
    }
}

分块/树分块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include<bits/stdc++.h>
using namespace std;
namespace lin4xu{

// Shorthand macros.
// NOTE(review): `re` expands to `register`, which C++17 removed; some compilers reject it.
#define re register
#define rint re int
#define ll long long
#define rll re ll
#define db double
#define rdb re db
#define rch re char
#define ldb long db
#define rldb re ldb
#define un unsigned
#define ull un ll
#define rull re ull

// Reads a possibly negative integer of type T from stdin, skipping non-digits.
template <typename T> inline T read()
{
re T ans=0;rint f=0;rch ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=1;
for(;isdigit(ch);ch=getchar()) ans=ans*10+(ch&15);
return f?-ans:ans;
}

// maxn: vertex/color capacity. S: height threshold for marking key vertices.
// nS: capacity for key-vertex ids.
const int maxn=4e4+10;
const int S=1000;
const int nS=maxn/S+10;

struct edge
{
int to,next;
}e[maxn<<1];
// mapp[i][j]: colors on the path from key vertex rk[i] up to its key ancestor rk[j], inclusive.
// b: scratch color set for the current query.
bitset<maxn> mapp[nS][nS];
bitset<maxn> b;
int head[maxn];
// sta: stack of key-vertex ids on the current root path; rk[id]: vertex of key id.
int sta[nS],rk[nS];
int fa[maxn];
int top[maxn];
int deep[maxn];
// id[x]: key id of x (0 if x is not a key vertex).
int id[maxn];
// a: discretized color per vertex; lsh: sorted distinct colors.
int a[maxn];
int lsh[maxn];
// up[x]: key id of the nearest key vertex strictly above x (0 if none).
int up[maxn];
int size[maxn];
int son[maxn];
int cnt,s_cnt,sta_top;

// Adds directed edge x->y to the adjacency list.
inline void add(rint x,rint y)
{
e[++cnt].to=y;
e[cnt].next=head[x];
head[x]=cnt;
}

// First DFS: depths, parents, subtree sizes and heavy children.
// Marks x as a key vertex when the height counter returned from its children reaches S,
// then resets the counter.
inline int dfs(rint x)
{
rint tmp=0;size[x]=1;
for(rint i=head[x];i;i=e[i].next)
{
rint v=e[i].to;
if(v==fa[x]) continue;
deep[v]=deep[x]+1,fa[v]=x,tmp=max(tmp,dfs(v)),size[x]+=size[v];
if(size[son[x]]<size[v]) son[x]=v;
}
if(++tmp==S) rk[id[x]=++s_cnt]=x,tmp=0;
return tmp;
}

// Second DFS: heavy-path tops (top[]) and up[].
// For every key vertex x, it also fills mapp[id[x]][i] for each key ancestor i by
// extending the previous key's set with the vertices in between.
void dfs(rint x,rint t)
{
top[x]=t;
if(id[x])
{
up[x]=sta[sta_top];
sta[++sta_top]=id[x];
mapp[id[x]][id[x]][a[x]]=1;
for(rint i=up[x],pre=id[x];i;pre=i,i=up[rk[i]])
{
mapp[id[x]][i]=mapp[id[x]][pre],mapp[id[x]][i][a[rk[i]]]=1;
for(rint j=fa[rk[pre]];j!=rk[i];j=fa[j]) mapp[id[x]][i][a[j]]=1;
}
}
if(son[x]) dfs(son[x],t);
for(rint i=head[x];i;i=e[i].next)
{
rint v=e[i].to;
if(v==fa[x]||v==son[x]) continue;
dfs(v,v);
}
if(id[x]) sta_top--;
else up[x]=sta[sta_top];
}

// LCA of x and y via heavy-path decomposition.
inline int query(rint x,rint y)
{
while(top[x]!=top[y]) deep[top[x]]<deep[top[y]]?y=fa[top[y]]:x=fa[top[x]];
return deep[x]<deep[y]?x:y;
}

// Reads n, m, the vertex colors and n-1 tree edges. Answers m online queries
// (x is XOR-ed with the previous answer) with the number of distinct colors on path x..y.
// Each side walks to its nearest key vertex, jumps across key vertices with the
// precomputed bitsets, then walks the remainder to the LCA.
int main()
{
rint n=read<int>(),m=read<int>();
for(rint i=1;i<=n;i++) a[i]=lsh[i]=read<int>();
sort(lsh+1,lsh+n+1);rint N=unique(lsh+1,lsh+n+1)-lsh-1;
for(rint i=1;i<=n;i++) a[i]=lower_bound(lsh+1,lsh+N+1,a[i])-lsh;
for(rint i=1;i<n;i++)
{
rint x=read<int>(),y=read<int>();
add(x,y),add(y,x);
}
deep[1]=1,dfs(1),dfs(1,1);
rint lastans=0;
for(rint tt=1;tt<=m;tt++)
{
rint x=read<int>()^lastans,y=read<int>(),lca=query(x,y);
// x side, including the LCA's color.
if(deep[rk[up[x]]]>=deep[lca])
{
rint fax=rk[up[x]];
for(rint i=x;i!=fax;i=fa[i]) b[a[i]]=1;
for(;deep[rk[up[fax]]]>=deep[lca];fax=rk[up[fax]]);
b|=mapp[up[x]][id[fax]];
if(fax!=lca)
{
for(rint i=fa[fax];i!=lca;i=fa[i]) b[a[i]]=1;
b[a[lca]]=1;
}
}
else
{
for(rint i=x;i!=lca;i=fa[i]) b[a[i]]=1;
b[a[lca]]=1;
}
// y side, strictly below the LCA.
if(deep[rk[up[y]]]>deep[lca])
{
rint fay=rk[up[y]];
for(rint i=y;i!=fay;i=fa[i]) b[a[i]]=1;
for(;deep[rk[up[fay]]]>deep[lca];fay=rk[up[fay]]);
b|=mapp[up[y]][id[fay]];
for(rint i=fa[fay];i!=lca;i=fa[i]) b[a[i]]=1;
}
else for(rint i=y;i!=lca;i=fa[i]) b[a[i]]=1;
printf("%d\n",lastans=b.count()),b.reset();
}
return 0;
}};

int main() {
    // All state and logic live in namespace lin4xu; forward its exit code.
    return lin4xu::main();
}

克鲁斯卡尔重构树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include<cstdio>
#include<queue>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=1e6+10;
// Adjacency-list edge. For the raw input edges (e2), main stores the endpoints in
// nxt/to and the length/altitude in val/val2.
// operator< orders by descending altitude, which is the order the reconstruction needs.
struct Edge{
    int to,nxt,val,val2;
    bool operator < (const Edge&A)const{
        return A.val2<val2;
    }
}e[N<<1],e2[N];
int h[N],idx;
// Pushes directed edge a->b with weight c onto the front of a's list.
void Ins(int a,int b,int c){
    const int id=++idx;
    e[id].to=b;
    e[id].val=c;
    e[id].nxt=h[a];
    h[a]=id;
}
// cnt: number of reconstruction-tree nodes created so far.
// lst[r]: tree node currently standing for DSU root r.
// f: DSU parent array.
int cnt,lst[N],f[N];
// Returns the DSU root of x with full path compression. Written iteratively: unions
// here are unbalanced (f[x]=y), so parent chains can be long enough to make a
// recursive find exhaust the call stack.
int find(int x){
    int root=x;
    while(f[root]!=root)root=f[root];
    while(f[x]!=root){
        const int next=f[x];
        f[x]=root;
        x=next;
    }
    return root;
}
struct Node{
int id;ll val;
bool operator < (const Node&A)const{
return val>A.val;
}
};
priority_queue<Node> q;
bool vis[N];
ll dis[N];
void dij(){
memset(vis,0,sizeof(vis));
memset(dis,0x3f,sizeof(dis));
q.push((Node){1,0});dis[1]=0;
while(!q.empty()){
Node u=q.top();q.pop();
if(vis[u.id])continue;
vis[u.id]=1;
for(int i=h[u.id];i;i=e[i].nxt){
int v=e[i].to;
if(dis[v]>dis[u.id]+e[i].val){
dis[v]=dis[u.id]+e[i].val;
q.push((Node){v,dis[v]});
}
}
}
}
// dfn/Time: preorder numbering. w: altitude of each reconstruction node.
// siz: subtree size. p: binary-lifting ancestors. rk[dfn]: dis[] of the node with that dfn.
int dfn[N],Time,w[N],siz[N],p[N][25];
ll rk[N];
// Preorder DFS over the reconstruction tree from u. Assigns dfn/rk, fills the lifting
// table of u from its already-set parent, and accumulates subtree sizes.
// NOTE(review): recursive; a very deep reconstruction tree may exhaust the stack.
void dfs(int u){
    siz[u]=1;
    dfn[u]=++Time;
    rk[Time]=dis[u];
    for(int j=0;p[u][j]!=0;++j)
        p[u][j+1]=p[p[u][j]][j];
    for(int i=h[u];i;i=e[i].nxt){
        const int child=e[i].to;
        p[child][0]=u;
        dfs(child);
        siz[u]+=siz[child];
    }
}
// Static segment tree over dfn order holding the minimum of rk[] per range.
struct Tree{
    int l,r;
    ll Min;
}tr[N<<2];
#define ls rt<<1
#define rs rt<<1|1
void Build(int rt,int l,int r){
    tr[rt].l=l;
    tr[rt].r=r;
    if(l==r){
        tr[rt].Min=rk[l];
        return;
    }
    const int mid=(l+r)>>1;
    Build(ls,l,mid);
    Build(rs,mid+1,r);
    tr[rt].Min=min(tr[ls].Min,tr[rs].Min);
}
// Minimum of rk over [l,r]. dis[0] serves as +infinity: dij() fills it with 0x3f bytes,
// and vertices are presumably numbered from 1 so it is never relaxed.
ll query(int rt,int l,int r){
    if(l<=tr[rt].l&&tr[rt].r<=r)return tr[rt].Min;
    const int mid=(tr[rt].l+tr[rt].r)>>1;
    ll best=dis[0];
    if(l<=mid)best=min(best,query(ls,l,r));
    if(mid<r)best=min(best,query(rs,l,r));
    return best;
}
// For each of T test cases:
//   1. read a graph whose edges have a length (val) and an altitude (val2);
//   2. run Dijkstra from vertex 1 over lengths;
//   3. build the Kruskal reconstruction tree by descending altitude (internal node
//      weight w = edge altitude);
//   4. answer online queries (v, water level): climb to the highest ancestor whose
//      altitude exceeds the level and print the minimum dis in its subtree.
int main(){
int T;
scanf("%d",&T);
while(T--){
memset(h,0,sizeof(h));
idx=0;
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++){
scanf("%d%d%d%d",&e2[i].nxt,&e2[i].to,&e2[i].val,&e2[i].val2);
Ins(e2[i].nxt,e2[i].to,e2[i].val);
Ins(e2[i].to,e2[i].nxt,e2[i].val);
}
dij();
// Reuse h/e for the reconstruction tree; leaves 1..n are the original vertices.
memset(h,0,sizeof(h));cnt=idx=0;
for(int i=1;i<=n;i++){
f[i]=i;
w[i]=0;
lst[i]=++cnt;
}
sort(e2+1,e2+m+1);
for(int i=1;i<=m;i++){
int x=find(e2[i].nxt),y=find(e2[i].to);
if(x!=y){
// New internal node: its children are the two components' current tree nodes,
// and its dis is "infinity" so it never wins a subtree minimum.
++cnt;w[cnt]=e2[i].val2;
Ins(cnt,lst[x],0);
Ins(cnt,lst[y],0);
f[x]=y;dis[cnt]=dis[0];
lst[y]=cnt;
}
}
// NOTE(review): clears the whole lifting table (about 100 MB) per test case.
memset(p,0,sizeof(p));
Time=0;
// Assumes the graph is connected, so the last created node is the root.
dfs(cnt);
Build(1,1,cnt);
int ask,k,s;
scanf("%d%d%d",&ask,&k,&s);
ll res=0;
while(ask--){
int v,wv;
scanf("%d%d",&v,&wv);
// Online decoding with the previous answer.
v=(v+k*res-1)%n+1;
wv=(wv+k*res)%(s+1);
for(int i=23;~i;i--)
if(p[v][i]&&w[p[v][i]]>wv)v=p[v][i];
printf("%lld\n",res=query(1,dfn[v],dfn[v]+siz[v]-1));
}
}
return 0;
}

扫描线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=4e5+10;
typedef long long ll;
// Input rectangle with opposite corners (x,y) and (xx,yy). main() later replaces y/yy
// with their compressed indices.
struct Node{
    int x,y,xx,yy;
}a[N];
// Sweep event: add v (+1 open / -1 close) to the cover count of compressed y-segments l..r.
struct Add{
    int l,r,v;
    Add(){}
    Add(int a,int b,int c):l(a),r(b),v(c){}
};
vector<Add> vec[N];
// rk[i]: y value of compressed index i (segment i spans [rk[i], rk[i+1])).
// s[i]: x value of compressed index i. tmp/tot: scratch space for compression.
int rk[N],s[N],tmp[N],tot;
// cnt[rt]: number of events fully covering node rt; len[rt]: covered y-length inside rt.
int len[N<<2],cnt[N<<2];

// Adds v to the cover count of segments [cl,cr]; rt covers [l,r]. Recomputes len on the
// way back up. An empty range (cl>cr, from a zero-height rectangle) is ignored: without
// this guard, an empty range at index 1 descends to leaf 1 and recurses forever.
void update(int rt,int l,int r,int cl,int cr,int v){
    if(cl>cr)return;
    if(cl<=l&&r<=cr){
        cnt[rt]+=v;
    }else{
        const int mid=(l+r)>>1;
        if(cl<=mid)update(rt<<1,l,mid,cl,cr,v);
        if(cr>mid)update(rt<<1|1,mid+1,r,cl,cr,v);
    }
    if(cnt[rt])len[rt]=rk[r+1]-rk[l];
    else if(l==r)len[rt]=0;
    else len[rt]=len[rt<<1]+len[rt<<1|1];
}
int main(){
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d%d%d",&a[i].x,&a[i].y,&a[i].xx,&a[i].yy);
tmp[++tot]=a[i].y;tmp[++tot]=a[i].yy;
}
sort(tmp+1,tmp+tot+1);
tot=unique(tmp+1,tmp+tot+1)-tmp-1;
for(int i=1;i<=n;i++){
int x=lower_bound(tmp+1,tmp+tot+1,a[i].y)-tmp;
rk[x]=a[i].y;a[i].y=x;
x=lower_bound(tmp+1,tmp+tot+1,a[i].yy)-tmp;
rk[x]=a[i].yy;a[i].yy=x;
}
int maxy=tot;tot=0;
for(int i=1;i<=n;i++)tmp[++tot]=a[i].x,tmp[++tot]=a[i].xx;
sort(tmp+1,tmp+tot+1);
tot=unique(tmp+1,tmp+tot+1)-tmp-1;
for(int i=1;i<=n;i++){
int x=lower_bound(tmp+1,tmp+tot+1,a[i].x)-tmp;
s[x]=a[i].x;
vec[x].push_back(Add(a[i].y,a[i].yy-1,1));
x=lower_bound(tmp+1,tmp+tot+1,a[i].xx)-tmp;
s[x]=a[i].xx;
vec[x].push_back(Add(a[i].y,a[i].yy-1,-1));
}
ll res=0;
for(int i=1;i<tot;i++){
for(int j=0;j<vec[i].size();j++)
update(1,1,maxy,vec[i][j].l,vec[i][j].r,vec[i][j].v);
res+=1ll*(s[i+1]-s[i])*len[1];
}
printf("%lld\n",res);
return 0;
}

欧拉游览树 ETT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
/*
虽然上文提到过块状链表实现 ETT
在某些情况下可能较简单,但对于此题块状链表复杂度有可能无法通过而且实现较繁琐,所以这份代码采用
FHQ Treap 实现。
*/
#include <bits/stdc++.h>
#define N 1000000
#define int long long
using namespace std;
/*FHQ TREAP*/
// FHQ treap (split/merge) with parent pointers, used as the ETT sequence.
// Per node: f parent, rnd heap priority, ls/rs children, siz subtree size, tag pending add,
// val node value, sum subtree sum, pd node sign (+1/-1), pds subtree sign sum.
// rt is the root of the whole sequence; tot is the number of allocated nodes.
int rt, tot, f[N], rnd[N], ls[N], rs[N], siz[N], tag[N], val[N], sum[N], pd[N],
pds[N];

// Recomputes x's aggregates from its children.
void pushup(int x) {
siz[x] = siz[ls[x]] + siz[rs[x]] + 1;
sum[x] = sum[ls[x]] + sum[rs[x]] + val[x];
pds[x] = pds[ls[x]] + pds[rs[x]] + pd[x];
}

// Sets y as x's right child (c != 0) or left child (c == 0), updates y's parent pointer,
// then refreshes x.
void link(int x, int c, int y) {
if (c)
rs[x] = y;
else
ls[x] = y;
if (y) f[y] = x;
pushup(x);
}

// Allocates a node with value x and sign y; returns its index.
int newNode(int x, int y) {
siz[++tot] = 1;
val[tot] = sum[tot] = x;
pd[tot] = pds[tot] = y;
rnd[tot] = rand();
return tot;
}

// Adds v to every node of x's subtree, weighted by sign:
// nodes with pd=+1 gain v, nodes with pd=-1 lose v.
void setTag(int x, int v) {
tag[x] += v;
sum[x] += v * pds[x];
val[x] += v * pd[x];
}

// Pushes x's pending tag to its children.
void pushdown(int x) {
if (ls[x]) setTag(ls[x], tag[x]);
if (rs[x]) setTag(rs[x], tag[x]);
tag[x] = 0;
}

// Splits the treap rooted at now into the first k nodes (x) and the rest (y).
// Both returned roots end with parent pointer 0.
void split(int now, int k, int &x, int &y) {
f[now] = 0;
if (!now) {
x = y = 0;
return;
}
pushdown(now);
if (siz[ls[now]] + 1 <= k) {
x = now;
split(rs[now], k - siz[ls[now]] - 1, rs[x], y);
link(x, 1, rs[x]);
} else {
y = now;
split(ls[now], k, x, ls[y]);
link(y, 0, ls[y]);
}
}

// Concatenates treaps x and y (all of x before y); returns the new root.
int merge(int x, int y) {
if (!x || !y) return x | y;
if (rnd[x] < rnd[y]) {
pushdown(x);
link(x, 1, merge(rs[x], y));
return x;
} else {
pushdown(y);
link(y, 0, merge(x, ls[y]));
return y;
}
}

// 1-based position of node x in its sequence, found by climbing parent pointers.
// Tags are not pushed here; only siz is read, and tags never change it.
int rnk(int x) {
int c = 1, ans = 0;
while (x) {
if (c) ans += siz[ls[x]] + 1;
c = (rs[f[x]] == x);
x = f[x];
}
return ans;
}

/*ETT*/
// s[x] / e[x]: treap nodes of x's opening / closing bracket in the Euler tour.
int s[N], e[N];

// Adds v to w of every vertex in x's subtree by tagging the bracket range [s[x], e[x]].
void add(int x, int v) {
int a, b, c;
split(rt, rnk(s[x]) - 1, a, b);
// After the first split s[x] is first in b, so rnk(s[x]) is 1 and the length is rnk(e[x]).
split(b, rnk(e[x]) - rnk(s[x]) + 1, b,
c); // here b is the bracket sequence of the subtree being operated on.
setTag(b, v);
rt = merge(merge(a, b), c);
}

// Sum of the prefix up to and including s[x]. Subtrees closed before s[x] cancel out
// (+w then -w), leaving the sum of w over the root-to-x path.
int query(int x) {
int a, b;
split(rt, rnk(s[x]), a, b);
int ans = sum[a];
rt = merge(a, b);
return ans;
}

// Re-parents x's subtree under y: cut out [s[x], e[x]] and reinsert it right after s[y].
void changeFa(int x, int y) {
int a, b, c, d;
split(rt, rnk(s[x]) - 1, a, b);
split(b, rnk(e[x]) - rnk(s[x]) + 1, b, c);
a = merge(
a,
c); // we cannot tell on which side of the cut-out subtree the new parent lies in the bracket sequence, so the two sides are merged first.
split(a, rnk(s[y]), a, d);
rt = merge(merge(a, b), d); // put the moved subtree at the front of the new parent's bracket sequence.
}

/*main function*/
int n, m, w[N];
vector<int> v[N];

// Appends x's subtree bracket sequence to the treap rt: an opening node (value w[x],
// sign +1), the children's sequences in input order, then a closing node (-w[x], -1).
void dfs(int x) {
  s[x] = newNode(w[x], 1);
  rt = merge(rt, s[x]);
  for (int child : v[x]) dfs(child);
  e[x] = newNode(-w[x], -1);
  rt = merge(rt, e[x]);
}

// Reads a rooted tree (parents of 2..n) and vertex weights, builds the Euler tour,
// then handles m operations:
//   'Q d'    — print the root-to-d weight sum;
//   'C x y'  — make y the parent of x;
//   otherwise 'p q' — add q to every weight in p's subtree.
signed main() {
  // Only iostreams are used here, so unsyncing from C stdio is safe and speeds up input.
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cin >> n;
  for (int i = 2; i <= n; i++) {
    int f;
    cin >> f;
    v[f].push_back(i);
  }
  for (int i = 1; i <= n; i++) cin >> w[i];
  dfs(1);
  cin >> m;
  for (int i = 1; i <= m; i++) {
    char c;
    cin >> c;
    if (c == 'Q') {
      int d;
      cin >> d;
      // '\n' instead of endl: same output without a flush per query.
      cout << query(d) << '\n';
    } else if (c == 'C') {
      int x, y;
      cin >> x >> y;
      changeFa(x, y);
    } else {
      int p, q;
      cin >> p >> q;
      add(p, q);
    }
  }
  return 0;
}

Top Tree

Sqrt Tree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
SqrtTreeItem op(const SqrtTreeItem &a, const SqrtTreeItem &b);

// Smallest exponent e >= 0 with (1 << e) >= n, i.e. ceil(log2(n)) for n >= 1
// (and 0 for n <= 1).
inline int log2Up(int n) {
  int bits = 0;
  while (n > (1 << bits)) {
    ++bits;
  }
  return bits;
}

// Sqrt tree: answers op-folds over inclusive ranges [l, r] and supports point updates.
// SqrtTreeItem and op() come from outside this class. op is presumably associative —
// the answers combine partial folds in arbitrary groupings — so confirm this for any
// new SqrtTreeItem.
class SqrtTree {
private:
// n: element count; lg: log2Up(n); indexSz: number of top-level blocks.
// v: the elements, followed by one entry per top-level block (the "index" part).
// clz[x]: bit length of x (clz[x>>1]+1), despite the name.
// layers: log-size of each layer; onLayer: maps a bit length to the layer handling it.
// pref/suf: per-layer prefix/suffix folds inside blocks.
// between: folds over runs of whole blocks, for layers below the top.
int n, lg, indexSz;
vector<SqrtTreeItem> v;
vector<int> clz, layers, onLayer;
vector<vector<SqrtTreeItem> > pref, suf, between;

// Fills pref/suf of `layer` for the block [l, r).
inline void buildBlock(int layer, int l, int r) {
pref[layer][l] = v[l];
for (int i = l + 1; i < r; i++) {
pref[layer][i] = op(pref[layer][i - 1], v[i]);
}
suf[layer][r - 1] = v[r - 1];
for (int i = r - 2; i >= l; i--) {
suf[layer][i] = op(v[i], suf[layer][i + 1]);
}
}

// For the segment [lBound, rBound) of `layer`, stores in between[layer-1] the fold of
// whole blocks i..j for every block pair i <= j.
inline void buildBetween(int layer, int lBound, int rBound, int betweenOffs) {
int bSzLog = (layers[layer] + 1) >> 1;
int bCntLog = layers[layer] >> 1;
int bSz = 1 << bSzLog;
int bCnt = (rBound - lBound + bSz - 1) >> bSzLog;
for (int i = 0; i < bCnt; i++) {
SqrtTreeItem ans;
for (int j = i; j < bCnt; j++) {
SqrtTreeItem add = suf[layer][lBound + (j << bSzLog)];
ans = (i == j) ? add : op(ans, add);
between[layer - 1][betweenOffs + lBound + (i << bCntLog) + j] = ans;
}
}
}

// Top layer: writes each block's total into v[n + i] and builds a nested sqrt tree
// over that index part, starting at layer 1.
inline void buildBetweenZero() {
int bSzLog = (lg + 1) >> 1;
for (int i = 0; i < indexSz; i++) {
v[n + i] = suf[0][i << bSzLog];
}
build(1, n, n + indexSz, (1 << lg) - n);
}

// Refreshes the index entry of top-level block bid and updates the nested structure.
inline void updateBetweenZero(int bid) {
int bSzLog = (lg + 1) >> 1;
v[n + bid] = suf[0][bid << bSzLog];
update(1, n, n + indexSz, (1 << lg) - n, n + bid);
}

// Recursively builds `layer` and the layers below it over [lBound, rBound).
void build(int layer, int lBound, int rBound, int betweenOffs) {
if (layer >= (int)layers.size()) {
return;
}
int bSz = 1 << ((layers[layer] + 1) >> 1);
for (int l = lBound; l < rBound; l += bSz) {
int r = min(l + bSz, rBound);
buildBlock(layer, l, r);
build(layer + 1, l, r, betweenOffs);
}
if (layer == 0) {
buildBetweenZero();
} else {
buildBetween(layer, lBound, rBound, betweenOffs);
}
}

// Rebuilds, on each layer, the block containing position x (v[x] already updated).
void update(int layer, int lBound, int rBound, int betweenOffs, int x) {
if (layer >= (int)layers.size()) {
return;
}
int bSzLog = (layers[layer] + 1) >> 1;
int bSz = 1 << bSzLog;
int blockIdx = (x - lBound) >> bSzLog;
int l = lBound + (blockIdx << bSzLog);
int r = min(l + bSz, rBound);
buildBlock(layer, l, r);
if (layer == 0) {
updateBetweenZero(blockIdx);
} else {
buildBetween(layer, lBound, rBound, betweenOffs);
}
update(layer + 1, l, r, betweenOffs, x);
}

// Fold of v[l..r] (inclusive). The highest differing bit of l-base and r-base picks the
// layer where l and r fall in different blocks. The answer is
// suffix of l's block, op the whole blocks in between, op prefix of r's block.
inline SqrtTreeItem query(int l, int r, int betweenOffs, int base) {
if (l == r) {
return v[l];
}
if (l + 1 == r) {
return op(v[l], v[r]);
}
int layer = onLayer[clz[(l - base) ^ (r - base)]];
int bSzLog = (layers[layer] + 1) >> 1;
int bCntLog = layers[layer] >> 1;
int lBound = (((l - base) >> layers[layer]) << layers[layer]) + base;
int lBlock = ((l - lBound) >> bSzLog) + 1;
int rBlock = ((r - lBound) >> bSzLog) - 1;
SqrtTreeItem ans = suf[layer][l];
if (lBlock <= rBlock) {
SqrtTreeItem add =
(layer == 0) ? (query(n + lBlock, n + rBlock, (1 << lg) - n, n))
: (between[layer - 1][betweenOffs + lBound +
(lBlock << bCntLog) + rBlock]);
ans = op(ans, add);
}
ans = op(ans, pref[layer][r]);
return ans;
}

public:
// Fold of elements l..r (0-based, inclusive).
inline SqrtTreeItem query(int l, int r) { return query(l, r, 0, 0); }

// Sets element x to item and rebuilds the affected blocks.
inline void update(int x, const SqrtTreeItem &item) {
v[x] = item;
update(0, 0, n, 0, x);
}

// Builds the structure over a copy of a.
SqrtTree(const vector<SqrtTreeItem> &a)
: n((int)a.size()), lg(log2Up(n)), v(a), clz(1 << lg), onLayer(lg + 1) {
clz[0] = 0;
for (int i = 1; i < (int)clz.size(); i++) {
clz[i] = clz[i >> 1] + 1;
}
int tlg = lg;
while (tlg > 1) {
onLayer[tlg] = (int)layers.size();
layers.push_back(tlg);
tlg = (tlg + 1) >> 1;
}
for (int i = lg - 1; i >= 0; i--) {
onLayer[i] = max(onLayer[i], onLayer[i + 1]);
}
int betweenLayers = max(0, (int)layers.size() - 1);
int bSzLog = (lg + 1) >> 1;
int bSz = 1 << bSzLog;
indexSz = (n + bSz - 1) >> bSzLog;
v.resize(n + indexSz);
pref.assign(layers.size(), vector<SqrtTreeItem>(n + indexSz));
suf.assign(layers.size(), vector<SqrtTreeItem>(n + indexSz));
between.assign(betweenLayers, vector<SqrtTreeItem>((1 << lg) + bSz));
build(0, 0, n, 0);
}
};

O(n)-O(1) ST表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int inf=1e9;
int n,m,w[2000000],lg2[1500000];
int pos[1500000],lst[1500000],dep[1500000],id,mm,S,ch[1500000][2],stack[1500000];
// O(n)-preprocessing / O(1)-query RMQ (position of the minimum) for +-1 sequences,
// here the Euler-tour depths.
// The sequence is cut into blocks of blo ~ log2(n)/2 elements. Blocks with the same
// length and the same in-block +-1 steps have the same shape, hence the same argmin
// answers, so they share one small sparse table (st[]). A second sparse table
// (st_block) answers RMQ over the per-block minima.
struct LIM_RMQ
{
    int w[1500000],bl[1500000],blo,L[150000],R[150000],pos[10000],val[150000],minn[150000],minpos[150000],t[1000],n_st;
    // Sparse-table entry: value f and the index id where it occurs.
    struct ST_node{int f,id;bool operator <(const ST_node &tmp)const{return f<tmp.f;}};
    // Sparse table for one block shape. With these array sizes blo <= 10,
    // so 4 levels and 12 columns suffice.
    struct STable
    {
        ST_node a[4][12];
        void make(int w[],int n)
        {
            for(int i=1;i<=n;i++)a[0][i]=ST_node{w[i],i};
            for(int i=1;(1<<i)<=n;i++)
                for(int j=1;j+(1<<i)-1<=n;j++)
                    a[i][j]=std::min(a[i-1][j],a[i-1][j+(1<<(i-1))]);
        }
        ST_node query(int l,int r)
        {
            int len=lg2[r-l+1];
            return std::min(a[len][l],a[len][r-(1<<len)+1]);
        }
    }st[10000];
    // Sparse table over the block minima (O(n/log n * log n) = O(n) entries).
    struct STable_block
    {
        ST_node a[20][150000];
        void make(int w[],int n)
        {
            for(int i=1;i<=n;i++)a[0][i]=ST_node{w[i],i};
            for(int i=1;(1<<i)<=n;i++)
                for(int j=1;j+(1<<i)-1<=n;j++)
                    a[i][j]=std::min(a[i-1][j],a[i-1][j+(1<<(i-1))]);
        }
        ST_node query(int l,int r)
        {
            int len=lg2[r-l+1];   // precomputed: computing log2 at query time would not be O(1)
            return std::min(a[len][l],a[len][r-(1<<len)+1]);
        }
    }st_block;
    // Preprocesses a[1..n]; a must be a +-1 sequence (adjacent values differ by exactly 1).
    void make(int a[],int n)
    {
        for(int i=1;i<=n;i++)w[i]=a[i];
        lg2[1]=0;
        for(int i=2;i<=n;i++)lg2[i]=lg2[i>>1]+1;
        blo=std::max(lg2[n]>>1,1);
        for(int i=1;i<=n;i++)bl[i]=(i-1)/blo+1;
        for(int i=1;i<=bl[n];i++)L[i]=(i-1)*blo+1,R[i]=std::min(i*blo,n),minn[i]=inf;
        for(int i=1;i<=bl[n];i++)
        {
            // Shape key: a leading 1 followed by one bit per in-block step (1 for +1).
            // The leading 1 encodes the block length, so a short last block can no longer
            // collide with a full block whose bits happen to match. The step into the
            // block's first element is ignored, since it does not affect the shape.
            int key=1,nn=0;
            for(int j=L[i];j<=R[i];j++)
            {
                t[++nn]=w[j];
                if(j>L[i])key=key<<1|(w[j]-w[j-1]==1?1:0);
                if(w[j]<minn[i])minn[i]=w[j],minpos[i]=j;
            }
            if(!pos[key])st[pos[key]=++n_st].make(t,nn);
            val[i]=pos[key];   // which shared table this block uses
        }
        st_block.make(minn,bl[n]);
    }
    // Argmin inside block id over [l, r]: in-block answer shifted back to a global index.
    ST_node query_block(int id,int l,int r){ST_node t=st[val[id]].query(l-L[id]+1,r-L[id]+1);return ST_node{t.f,t.id+L[id]-1};}
    // Position of a minimum of w over [l, r].
    int query(int l,int r)
    {
        int bll=bl[l],blr=bl[r];
        if(bll==blr)return query_block(bll,l,r).id;
        int ml=query_block(bll,l,R[bll]).id,mr=query_block(blr,L[blr],r).id,mm;
        if(w[ml]<w[mr])mm=ml;else mm=mr;
        if(bll+1<=blr-1)
        {
            int mmid=minpos[st_block.query(bll+1,blr-1).id];
            if(w[mmid]<w[mm])mm=mmid;
        }
        return mm;
    }
}a;
// Euler tour of the Cartesian tree. lst[u] is u's first tour index; pos[k] and dep[k] are
// the node and depth at tour index k. u is emitted again after each child's subtour.
void dfs(int u,int depth)
{
    lst[u]=++id;
    pos[id]=u;
    dep[id]=depth;
    for(int side=0;side<2;++side)
    {
        const int child=ch[u][side];
        if(!child)continue;
        dfs(child,depth+1);
        ++id;
        pos[id]=u;
        dep[id]=depth;
    }
}
// Reads the next unsigned decimal integer from stdin, skipping every non-digit before it.
int getin()
{
    char c;
    do c=getchar(); while(c<'0'||c>'9');
    int x=0;
    for(;c>='0'&&c<='9';c=getchar())x=x*10+c-48;
    return x;
}
int wt[30];
// Writes x (assumed non-negative) in decimal followed by a newline.
// Previously 0 was printed without its newline, gluing it to the next answer.
void putout(int x)
{
    if(!x){puts("0");return;}
    int l=0;
    while(x)wt[++l]=x%10,x/=10;
    while(l)putchar(wt[l--]+48);
    puts("");
}
// Builds the max-Cartesian tree of w[1..n] with a monotonic stack (the earlier of equal
// values stays the ancestor), stores its root in S, then runs the Euler tour from S.
void build_tree()
{
    int top=1;
    S=1;
    stack[1]=1;
    for(int i=2;i<=n;i++)
    {
        int popped=0;
        for(;top>0&&w[i]>w[stack[top]];--top)popped=stack[top];
        if(popped)ch[i][0]=popped;
        if(top>0)ch[stack[top]][1]=i;
        else S=i;
        stack[++top]=i;
    }
    dfs(S,1);
}
// Range-maximum queries via Cartesian tree + Euler tour + +-1 RMQ: the LCA of positions
// l and r (the minimum-depth tour entry between their first occurrences) holds the maximum.
int main()
{
    n=getin();
    m=getin();
    for(int i=1;i<=n;i++)w[i]=getin();
    build_tree();
    a.make(dep,id);
    for(int qi=0;qi<m;qi++)
    {
        int l=lst[getin()];
        int r=lst[getin()];
        if(r<l)swap(l,r);   // first occurrences can come in either order
        putout(w[pos[a.query(l,r)]]);
    }
}

析合树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include<bits/stdc++.h>
const int N=1e5+10;
char buf[1<<20],*p1,*p2;
// Returns the next byte of stdin through a 1 MiB buffer, or EOF (converted to char)
// once input is exhausted.
char gc(){
    if(p1==p2){
        p1=buf;
        p2=buf+fread(buf,1,1<<20,stdin);
        if(p1==p2)return EOF;
    }
    return *p1++;
}
// Reads a possibly negative decimal integer, skipping any other characters before it.
// The old loop (`if(ch=='-')f=-1,ch=gc();`) only advanced after a '-', so any other
// separator made it spin forever. Returns 0 if input ends before a digit.
inline int read(){
    int x=0,f=1;
    char ch=gc();
    while(ch<'0'||ch>'9'){
        if(ch==static_cast<char>(EOF))return 0;
        if(ch=='-')f=-1;
        ch=gc();
    }
    while(ch<='9'&&ch>='0'){x=x*10+ch-'0';ch=gc();}
    return x*f;
}
int n,a[N],b[N];
int stmx[N][20],stmi[N][20];
// Builds range-max/range-min sparse tables over a[1..n]; b[len] becomes floor(log2(len)).
void init(){
    for(int i=1;i<=n;i++){
        stmx[i][0]=a[i];
        stmi[i][0]=a[i];
    }
    for(int j=1;(1<<j)<=n;j++){
        b[1<<j]=j;
        const int half=1<<(j-1);
        for(int i=1;i+(1<<j)-1<=n;i++){
            stmx[i][j]=std::max(stmx[i][j-1],stmx[i+half][j-1]);
            stmi[i][j]=std::min(stmi[i][j-1],stmi[i+half][j-1]);
        }
    }
    for(int i=1;i<=n;i++)
        if(b[i]==0)b[i]=b[i-1];
}
// Maximum of a[l..r].
int qmx(int l,int r){
    const int k=b[r-l+1];
    return std::max(stmx[l][k],stmx[r-(1<<k)+1][k]);
}
// Minimum of a[l..r].
int qmi(int l,int r){
    const int k=b[r-l+1];
    return std::min(stmi[l][k],stmi[r-(1<<k)+1][k]);
}
// Segment tree over positions with range add and subtree minimum (Mi, lazy tag tar).
struct Tree{
    int l,r,Mi,tar;
}T[N<<2];
#define ls rt<<1
#define rs rt<<1|1
void pushup(int rt){
    T[rt].Mi=std::min(T[ls].Mi,T[rs].Mi);
}
// Records each node's range; Mi/tar keep their zero initial values.
void Build(int rt,int l,int r){
    T[rt].l=l;
    T[rt].r=r;
    if(l<r){
        const int mid=(l+r)>>1;
        Build(ls,l,mid);
        Build(rs,mid+1,r);
    }
}
// Moves the pending add of node rt to both children.
void pushdown(int rt){
    const int add=T[rt].tar;
    T[ls].Mi+=add;
    T[ls].tar+=add;
    T[rs].Mi+=add;
    T[rs].tar+=add;
    T[rt].tar=0;
}
// Adds w to every position in [l,r].
void update(int rt,int l,int r,int w){
    if(l<=T[rt].l&&T[rt].r<=r){
        T[rt].Mi+=w;
        T[rt].tar+=w;
        return;
    }
    if(T[rt].tar)pushdown(rt);
    const int mid=(T[rt].l+T[rt].r)>>1;
    if(l<=mid)update(ls,l,r,w);
    if(mid<r)update(rs,l,r,w);
    pushup(rt);
}
// Descends from rt to a leaf, going left whenever the left child's minimum is 0 and right
// otherwise; returns that leaf's position.
int query(int rt){
    while(T[rt].l!=T[rt].r){
        if(T[rt].tar)pushdown(rt);
        rt=(T[ls].Mi==0)?ls:rs;
    }
    return T[rt].l;
}
// stk: stack of pending tree nodes. stkmx/stkmi: monotonic stacks of positions for max/min.
int stk[N],stkmx[N],stkmi[N],tp,tp1,tp2;
// Per tree node: [L, R] is its interval, ishe marks a linear ("合") node, and lim is the
// left end of its newest child. idf[i] is the leaf node of position i.
int lim[N<<1],L[N<<1],R[N<<1],idf[N],cnt;
bool ishe[N<<1];
// True when a[l..r] spans exactly r-l+1 consecutive values (max - min == r - l).
bool pd(int l,int r){
    const int spread=qmx(l,r)-qmi(l,r);
    return spread==r-l;
}
struct Edge{
    int to,nxt;
}e[N<<1];
int h[N<<1],idx;
// Prepends edge a->b.
void Ins(int a,int b){
    ++idx;
    e[idx].to=b;
    e[idx].nxt=h[a];
    h[a]=idx;
}
int p[N<<1][25],dep[N<<1];
// DFS from u: completes u's binary-lifting row from its parent, then sets the parent and
// depth of each child before recursing.
void dfs(int u){
    for(int j=0;p[u][j];++j)p[u][j+1]=p[p[u][j]][j];
    for(int i=h[u];i;i=e[i].nxt){
        const int v=e[i].to;
        if(v==p[u][0])continue;
        p[v][0]=u;
        dep[v]=dep[u]+1;
        dfs(v);
    }
}
// d-th ancestor of u, one binary digit of d at a time.
int get(int u,int d){
    for(int bit=0;d;++bit,d>>=1)
        if(d&1)u=p[u][bit];
    return u;
}
// Lowest common ancestor via binary lifting.
int lca(int x,int y){
    if(dep[x]<dep[y])std::swap(x,y);
    x=get(x,dep[x]-dep[y]);
    if(x==y)return x;
    for(int i=21;i>=0;--i){
        if(p[x][i]==p[y][i])continue;
        x=p[x][i];
        y=p[y][i];
    }
    return p[x][0];
}
// Builds the permutation (析合) tree of a[1..n] incrementally, then answers queries (x, y)
// with the shortest interval of consecutive values containing positions x and y.
// NOTE(review): assumes a is a permutation of 1..n — confirm.
int main(){
n=read();
for(int i=1;i<=n;i++)a[i]=read();
init();
Build(1,1,n);
for(int i=1;i<=n;i++){
// Keep T[l] = max(a[l..i]) - min(a[l..i]) - (i - l). The max/min parts are maintained
// through the monotonic stacks; the -1 per step comes from the update at the loop's end.
while(tp1&&a[stkmx[tp1]]<=a[i])update(1,stkmx[tp1-1]+1,stkmx[tp1],-a[stkmx[tp1]]),tp1--;
while(tp2&&a[stkmi[tp2]]>=a[i])update(1,stkmi[tp2-1]+1,stkmi[tp2],a[stkmi[tp2]]),tp2--;
update(1,stkmx[tp1]+1,i,a[i]);
update(1,stkmi[tp2]+1,i,-a[i]);
stkmx[++tp1]=i;stkmi[++tp2]=i;
// tl: leftmost l where T[l] == 0, i.e. a[l..i] covers consecutive values.
int tl=query(1);idf[i]=++cnt;L[cnt]=R[cnt]=i;
int u=cnt;
// Merge stack nodes that start at or after tl with the new node u.
while(tp&&L[stk[tp]]>=tl){
if(ishe[stk[tp]]&&pd(lim[stk[tp]],i)){
// u extends the linear node on the stack top.
R[stk[tp]]=i;
Ins(stk[tp],u);
u=stk[tp--];
}else if(pd(L[stk[tp]],i)){
// The stack top and u together form a new linear node.
ishe[++cnt]=1;
Ins(cnt,stk[tp]);Ins(cnt,u);
L[cnt]=L[stk[tp]];R[cnt]=i;tp--;
lim[cnt]=L[u];u=cnt;
}else{
// Pop nodes until the range becomes consecutive and wrap them in a prime (析) node.
Ins(++cnt,u);
do{
Ins(cnt,stk[tp--]);
}while(tp&&!pd(L[stk[tp]],i));
L[cnt]=L[stk[tp]];R[cnt]=i;Ins(cnt,stk[tp--]);
u=cnt;
}
}
stk[++tp]=u;
update(1,1,i,-1);
}
dfs(stk[1]);
int m=read();
while(m--){
int x=read(),y=read();
x=idf[x];y=idf[y];
int lc=lca(x,y);
if(ishe[lc]){
// Linear LCA: the answer spans only the children of lc on the x..y path.
x=get(x,dep[x]-dep[lc]-1);y=get(y,dep[y]-dep[lc]-1);
printf("%d %d\n",L[x],R[y]);
}else printf("%d %d\n",L[lc],R[lc]);
}
return 0;
}

网络流

最大流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include<bits/stdc++.h>
using namespace std;
const int N=5e4+10;
// Residual graph. Arcs i and i^1 form a forward/backward pair; idx starts at 1, so the
// first arc gets index 2.
struct Edge{
    int to,nxt,val;
}e[N<<1];
int h[N],idx=1;
void Ins(int a,int b,int c){
    const int k=++idx;
    e[k].to=b;
    e[k].val=c;
    e[k].nxt=h[a];
    h[a]=k;
}
int s,t;
int d[N],now[N],q[N],hh,tt;
// Builds BFS levels d[] from s over arcs with residual capacity (-1 = unreached) and
// resets the current-arc pointers now[]. Returns as soon as t gets a level.
bool bfs(){
    memset(d,-1,sizeof(d));
    hh=0;
    tt=0;
    q[0]=s;
    d[s]=1;
    now[s]=h[s];
    while(hh<=tt){
        const int u=q[hh++];
        for(int i=h[u];i;i=e[i].nxt){
            const int v=e[i].to;
            if(e[i].val==0||d[v]!=-1)continue;
            d[v]=d[u]+1;
            now[v]=h[v];
            if(v==t)return true;
            q[++tt]=v;
        }
    }
    return false;
}
// Pushes up to f units from u toward t along level-increasing arcs, advancing now[u]
// (current arc) as arcs are tried. Returns the amount pushed.
int dinic(int u,int f){
    if(u==t)return f;
    int left=f;
    for(int i=now[u];i&&left;i=e[i].nxt){
        now[u]=i;
        const int v=e[i].to;
        if(e[i].val==0||d[v]!=d[u]+1)continue;
        const int pushed=dinic(v,min(left,e[i].val));
        if(pushed==0)d[v]=-1;   // dead end: drop v from this level graph
        left-=pushed;
        e[i].val-=pushed;
        e[i^1].val+=pushed;
    }
    return f-left;
}
// Reads n, m, source s, sink t and m directed arcs (a, b, capacity c); prints the max flow.
int main(){
    int n,m;
    scanf("%d%d%d%d",&n,&m,&s,&t);
    for(int i=0;i<m;i++){
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        Ins(a,b,c);   // forward arc
        Ins(b,a,0);   // paired residual arc
    }
    long long flow=0;
    while(bfs())
        flow+=dinic(s,0x7fffffff);
    printf("%lld\n",flow);
    return 0;
}

最小割树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=2e3+10;
// e/h: the input network, each undirected edge stored as two arcs (i, i^1).
// e2/h2: the resulting minimum-cut tree.
// NOTE(review): e holds about N-1 (~2000) undirected input edges at most.
struct Edge{
    int to,nxt,val;
}e[N<<1],e2[N<<1];
int h[N],idx=1;
void Ins(int a,int b,int c){
    const int k=++idx;
    e[k].to=b;
    e[k].val=c;
    e[k].nxt=h[a];
    h[a]=k;
}
int h2[N],idx2=1;
void Ins2(int a,int b,int c){
    const int k=++idx2;
    e2[k].to=b;
    e2[k].val=c;
    e2[k].nxt=h2[a];
    h2[a]=k;
}
int a[N],q[N],d[N],hh,tt;
int st,ed,now[N];
// BFS levels from st over arcs with residual capacity (-1 = unreached); returns once ed
// is reached. When it fails, d[v] != -1 marks exactly the vertices reachable from st.
bool bfs(){
    memset(d,-1,sizeof(d));
    hh=0;
    tt=0;
    q[0]=st;
    now[st]=h[st];
    d[st]=1;
    while(hh<=tt){
        const int u=q[hh++];
        for(int i=h[u];i;i=e[i].nxt){
            const int v=e[i].to;
            if(d[v]!=-1||e[i].val==0)continue;
            d[v]=d[u]+1;
            now[v]=h[v];
            if(v==ed)return true;
            q[++tt]=v;
        }
    }
    return false;
}
// Blocking-flow DFS with current-arc optimisation; returns the amount pushed from u.
int dinic(int u,int f){
    if(u==ed)return f;
    int left=f;
    for(int i=now[u];i&&left;i=e[i].nxt){
        const int v=e[i].to;
        now[u]=i;
        if(d[v]!=d[u]+1||e[i].val==0)continue;
        const int pushed=dinic(v,min(e[i].val,left));
        if(pushed==0)d[v]=-1;
        left-=pushed;
        e[i].val-=pushed;
        e[i^1].val+=pushed;
    }
    return f-left;
}
// Max flow between s and t on the original capacities. Each arc pair is reset to half
// its combined residual, which restores the capacity c that main() gave both arcs.
// Adds tree edge s-t weighted by the flow. Afterwards d[v] != -1 marks s's side of the cut.
void get(int s,int t){
    st=s;
    ed=t;
    for(int i=2;i<=idx;i+=2){
        const int cap=(e[i].val+e[i^1].val)>>1;
        e[i].val=cap;
        e[i^1].val=cap;
    }
    int flow=0;
    while(bfs())flow+=dinic(st,1e9);
    Ins2(s,t,flow);
    Ins2(t,s,flow);
}
int tmp1[N],tmp2[N];
// Gomory-Hu style recursion on a[l..r]: cut a[l] from a[r], reorder the interior so s-side
// vertices come first (after a[l]) and t-side vertices last (before a[r]), then recurse on both parts.
void dfs(int l,int r){
    if(l==r)return;
    get(a[l],a[r]);
    int cl=0,cr=0;
    for(int i=l+1;i<r;i++){
        if(d[a[i]]==-1)tmp2[++cr]=a[i];
        else tmp1[++cl]=a[i];
    }
    int w=l;
    for(int i=1;i<=cl;i++)a[++w]=tmp1[i];
    for(int i=1;i<=cr;i++)a[++w]=tmp2[i];
    dfs(l,l+cl);
    dfs(l+cl+1,r);
}
// Binary lifting on the cut tree. Max[u][i] stores the MINIMUM edge weight over the 2^i
// edges above u; the name is kept as in the original.
int dep[N],p[N][25],Max[N][25];
// DFS from u over the cut tree (edge lists h2/e2), filling depths, ancestors and path minima.
// A value of 0 in p[][] doubles as "no further ancestor".
void init(int u){
    for(int i=0;p[u][i];i++){
        const int up=p[u][i];
        p[u][i+1]=p[up][i];
        Max[u][i+1]=min(Max[u][i],Max[up][i]);
    }
    for(int i=h2[u];i;i=e2[i].nxt){
        const int v=e2[i].to;
        if(v==p[u][0])continue;
        dep[v]=dep[u]+1;
        p[v][0]=u;
        Max[v][0]=e2[i].val;
        init(v);
    }
}
// Minimum edge weight on the tree path x..y (the minimum cut); 1e9 when x == y.
int lca(int x,int y){
    int res=1e9;
    if(dep[x]<dep[y])swap(x,y);
    for(int diff=dep[x]-dep[y],i=0;diff;diff>>=1,i++){
        if(diff&1){
            res=min(res,Max[x][i]);
            x=p[x][i];
        }
    }
    if(x==y)return res;
    for(int i=20;i>=0;i--){
        if(p[x][i]==p[y][i])continue;
        res=min(res,min(Max[x][i],Max[y][i]));
        x=p[x][i];
        y=p[y][i];
    }
    return min(res,min(Max[x][0],Max[y][0]));
}
// Reads an undirected graph on vertices 0..n, builds its minimum-cut tree, and answers
// queries with the minimum cut between u and v.
int main(){
    int n,m;
    scanf("%d%d",&n,&m);
    while(m--){
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        Ins(a,b,c);Ins(b,a,c);   // undirected edge: two arcs of capacity c
    }
    for(int i=0;i<=n;i++)
        a[i]=i;
    dfs(0,n);
    // Root the cut tree at vertex 0. init()/lca() use 0 as the "no ancestor" sentinel,
    // which is only consistent when 0 is the root. With root 1, a tree neighbour 0 of
    // vertex 1 is skipped as its "parent", and lifting stops at vertex 0 whenever it is
    // an interior vertex.
    init(0);
    int ask;
    scanf("%d",&ask);
    while(ask--){
        int u,v;
        scanf("%d%d",&u,&v);
        printf("%d\n",lca(u,v));
    }
    return 0;
}

spfa 费用流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<bits/stdc++.h>
using namespace std;
// Vertex bound and arc bound (arcs are stored as forward/reverse pairs, hence M<<1).
const int N=5e3+10;
const int M=5e4+10;
// Residual arc: head, next arc from the same tail, residual capacity, unit cost.
struct Edge{
	int to,nxt,val,cost;
}e[M<<1];
// idx starts at 1 so arc pairs occupy (2,3), (4,5), ... and i^1 is the partner of i.
int h[N],idx=1;
// Append arc a->b with capacity c and unit cost d.
void Ins(int a,int b,int c,int d){
	Edge &arc=e[++idx];
	arc.to=b;
	arc.val=c;
	arc.cost=d;
	arc.nxt=h[a];
	h[a]=idx;
}
int n,m,st,ed,dis[N],pre[N];
deque<int> q;
bool inq[N];
// Cheapest st->ed path in the residual graph, using SPFA with the SLF heuristic:
// a vertex whose distance does not exceed the front's goes to the front.
// pre[v] records the arc used to reach v. Returns true iff ed is reachable.
bool spfa(){
	memset(dis,0x3f,sizeof(dis));
	memset(inq,0,sizeof(inq));
	dis[st]=0;
	q.push_back(st);
	while(!q.empty()){
		const int u=q.front();
		q.pop_front();
		inq[u]=0;
		for(int i=h[u];i;i=e[i].nxt){
			if(!e[i].val)continue;        // saturated arc
			const int v=e[i].to;
			const int cand=dis[u]+e[i].cost;
			if(cand>=dis[v])continue;    // no improvement
			pre[v]=i;
			dis[v]=cand;
			if(inq[v])continue;
			if(q.empty()||dis[v]<=dis[q.front()])q.push_front(v);
			else q.push_back(v);
			inq[v]=1;
		}
	}
	return dis[ed]!=0x3f3f3f3f;
}
// Read the network (n vertices, m arcs, source st, sink ed) and print the
// maximum flow and its minimum cost (successive shortest paths).
int main(){
	scanf("%d%d%d%d",&n,&m,&st,&ed);
	for(int i=1;i<=m;i++){
		int a,b,c,d;
		scanf("%d%d%d%d",&a,&b,&c,&d);
		Ins(a,b,c,d);   // arc a->b: capacity c, unit cost d
		Ins(b,a,0,-d);  // paired residual arc
	}
	int res1=0;        // total flow
	long long res2=0;  // total cost: flow*cost can exceed the int range
	while(spfa()){
		// Bottleneck capacity along the path recorded in pre[].
		int Min=0x3f3f3f3f;
		for(int now=ed;now!=st;now=e[pre[now]^1].to)
			Min=min(Min,e[pre[now]].val);
		for(int now=ed;now!=st;now=e[pre[now]^1].to){
			e[pre[now]].val-=Min;
			e[pre[now]^1].val+=Min;
		}
		res1+=Min;
		res2+=1LL*dis[ed]*Min; // widen before multiplying
	}
	printf("%d %lld\n",res1,res2);
	return 0;
}

dinic 费用流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include<bits/stdc++.h>
using namespace std;
const int N=5e4+10;
const int INF=0x3f3f3f3f;
// Plain int. The `register` storage class this macro used to add was
// deprecated in C++11 and removed in C++17, where it is ill-formed.
#define rint int
// Residual arc: head, next arc from the same tail, residual capacity, unit cost.
struct Edge{
int to,nxt,val,cost;
}e[N<<1];
// idx starts at 1 so each arc i is paired with its reverse arc i^1.
int h[N],idx=1;
// Append arc a->b with capacity c and unit cost d.
void Ins(rint a,rint b,rint c,rint d){
e[++idx].to=b;e[idx].nxt=h[a];h[a]=idx;
e[idx].val=c;e[idx].cost=d;
}
// s/t: source and sink; dis[v]: cost of the cheapest residual path from v to t.
int s,t,dis[N];
queue<int> q;
// vis[] serves as the SPFA in-queue flag and as the on-path flag in dinic();
// SPFA clears each flag when it pops the vertex, so it leaves vis all zero.
bool vis[N];
// SPFA run backwards from t over reverse arcs, so dis[v] is the shortest
// distance from v to t in the residual graph. Returns whether s reaches t.
bool spfa(){
memset(dis,0x3f,sizeof(dis));
dis[t]=0;q.push(t);
while(!q.empty()){
rint u=q.front();q.pop();vis[u]=0;
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
// Arc i is u->v, so i^1 is v->u: relax v through it if it has residual capacity.
if(e[i^1].val&&dis[v]>dis[u]+e[i^1].cost){
dis[v]=dis[u]+e[i^1].cost;
if(!vis[v]){
vis[v]=1;
q.push(v);
}
}
}
}
return dis[s]!=INF;
}
// Multi-path augmentation restricted to arcs on shortest paths (dis[v]+cost==dis[u]).
// Returns the amount of flow (at most f) pushed from u to t.
int dinic(rint u,rint f){
if(u==t)return f;
rint re=f;
vis[u]=1; // u is on the current path; do not re-enter it through zero-cost cycles
for(rint i=h[u];i&&re;i=e[i].nxt){
rint v=e[i].to;
if(!vis[v]&&e[i].val&&dis[v]+e[i].cost==dis[u]){
rint k=dinic(v,min(re,e[i].val));
if(!k)dis[v]=INF; // prune v for the rest of this phase
re-=k;e[i].val-=k;e[i^1].val+=k;
}
}
vis[u]=0;
return f-re;
}
int main(){
rint n,m;
scanf("%d%d%d%d",&n,&m,&s,&t);
for(rint i=1;i<=m;i++){
rint x,y,w,c;
scanf("%d%d%d%d",&x,&y,&w,&c);
Ins(x,y,w,c);Ins(y,x,0,-c);
}
rint res1=0,res2=0;
while(spfa()){
rint now=dinic(s,INF);
res1+=now;
res2+=now*dis[s];
}
printf("%d %d\n",res1,res2);
return 0;
}

有源汇上下界最大流

构造方案时，每条边的实际流量 = 残量网络中求得的该边流量 + 该边的下界。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=5e5+10;
// Formerly `register int` / `register long long`. The `register` storage class
// was removed in C++17 (ill-formed there), so these are now plain types.
#define rint int
#define rll long long
#define ll long long
const int INF=0x3f3f3f3f;
// Residual arc: head, next arc from the same tail, residual capacity.
struct Edge{
int to,nxt;
ll val;
}e[N<<1];
// main() resets idx to 1 so that arcs come in pairs (i, i^1).
int h[N],idx;
// Append arc a->b with capacity c.
void Ins(rint a,rint b,rint c){
e[++idx].to=b;e[idx].nxt=h[a];
h[a]=idx;e[idx].val=c;
}
// st/ed: original source and sink; cnt: vertex counter; s/t: the endpoints the
// max-flow routines currently run between (super source/sink, later st/ed).
int st,ed,cnt,s,t;
// rd[v] / cd[v]: total lower bound on arcs entering / leaving v.
ll rd[N],cd[N];
int q[N],hh,tt,now[N],d[N];
// Dinic level BFS from s; returns 1 as soon as t gets a level.
bool bfs(){
hh=0;tt=-1;q[++tt]=s;
memset(d,-1,sizeof(d));
d[s]=1;now[s]=h[s];
while(hh<=tt){
rint u=q[hh++];
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(d[v]==-1&&e[i].val){
d[v]=d[u]+1;
if(v==t)return 1; // now[t] is never read: dinic() returns at t immediately
now[v]=h[v];
q[++tt]=v;
}
}
}
return 0;
}
// Blocking-flow DFS with current-arc optimisation; pushes at most f from u to t.
ll dinic(rint u,rll f){
if(u==t)return f;
rll re=f;
for(rint i=now[u];i&&re;i=e[i].nxt){
rint v=e[i].to;now[u]=i;
if(d[v]==d[u]+1&&e[i].val){
rll k=dinic(v,min(e[i].val,re));
if(!k)d[v]=-1; // v is saturated for this phase
re-=k;e[i].val-=k;e[i^1].val+=k;
}
}
return f-re;
}
// Maximum st->ed flow with lower bounds, one test case per iteration until EOF.
// Prints -1 when no feasible flow exists, and a blank line after every case.
int main(){
rint n,m;
// Note the input order: m, then n.
while(~scanf("%d%d",&m,&n)){
memset(h,0,sizeof(h));idx=1;cnt=n; // vertices 1..n come first; idx=1 keeps arc pairs at (i, i^1)
memset(rd,0,sizeof(rd));memset(cd,0,sizeof(cd));
st=++cnt,ed=++cnt;
rll res=0;
// Arc i->ed with bounds [g, INF]: capacity INF-g, with the lower bound kept in rd/cd.
for(rint i=1;i<=n;i++){
rint g;
scanf("%d",&g);
Ins(i,ed,INF-g);Ins(ed,i,0);
rd[ed]+=g;cd[i]+=g;
}
// Each of the m groups gets a fresh vertex cnt: arc st->cnt with capacity d,
// then c arcs cnt->x with bounds [l, r] (x is 0-based in the input).
for(rint i=1;i<=m;i++){
rint c,d; // local d shadows the global level array here
scanf("%d%d",&c,&d);
++cnt;Ins(st,cnt,d);Ins(cnt,st,0);
while(c--){
rint x,l,r;
scanf("%d%d%d",&x,&l,&r);
++x;
Ins(cnt,x,r-l);Ins(x,cnt,0);
rd[x]+=l;cd[cnt]+=l;
}
}
// An infinite arc ed->st turns the network into a circulation.
Ins(ed,st,INF);Ins(st,ed,0);
s=++cnt;t=++cnt;
rll sd=0; // total supply that the super source must deliver
// Balance every vertex: an excess of incoming lower bounds gets an arc from s,
// an excess of outgoing lower bounds gets an arc to t.
for(rint i=1;i<=cnt;i++){
if(rd[i]>cd[i])Ins(s,i,rd[i]-cd[i]),Ins(i,s,0),sd+=rd[i]-cd[i];
if(cd[i]>rd[i])Ins(i,t,cd[i]-rd[i]),Ins(t,i,0);
}
rll tot=0;
while(bfs())tot+=dinic(s,INF);
// Feasible iff every arc out of the super source is saturated.
if(tot==sd){
// Find the ed->st arc. Its reverse residual is the flow it carries, i.e. the
// value of the feasible st->ed flow. Then delete the arc.
for(rint i=h[ed];i;i=e[i].nxt){
rint v=e[i].to;
if(v==st){
res+=e[i^1].val;
e[i].val=e[i^1].val=0;
break;
}
}
// Augment further from st to ed on the residual network to reach the maximum.
s=st;t=ed;
while(bfs())res+=dinic(s,INF);
printf("%lld\n",res);
}else puts("-1");
puts("");
}
return 0;
}

无源汇上下界可行流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=5e5+10;
// Formerly `register int` / `register long long`. The `register` storage class
// was removed in C++17 (ill-formed there), so these are now plain types.
#define rint int
#define rll long long
#define ll long long
const int INF=0x3f3f3f3f;
// Residual arc: head, next arc from the same tail, residual capacity.
struct Edge{
int to,nxt;
ll val;
}e[N<<1];
// main() sets idx to 1 so that arcs come in pairs (i, i^1).
int h[N],idx;
// Append arc a->b with capacity c.
void Ins(rint a,rint b,rint c){
e[++idx].to=b;e[idx].nxt=h[a];
h[a]=idx;e[idx].val=c;
}
// st/ed: original source and sink; cnt: vertex counter; s/t: the endpoints the
// max-flow routines currently run between (they are reassigned in main()).
int st,ed,cnt,s,t;
// rd[v] / cd[v]: total lower bound on arcs entering / leaving v.
ll rd[N],cd[N];
int q[N],hh,tt,now[N],d[N];
// Dinic level BFS from s; returns 1 as soon as t gets a level.
bool bfs(){
hh=0;tt=-1;q[++tt]=s;
memset(d,-1,sizeof(d));
d[s]=1;now[s]=h[s];
while(hh<=tt){
rint u=q[hh++];
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(d[v]==-1&&e[i].val){
d[v]=d[u]+1;
if(v==t)return 1; // now[t] is never read: dinic() returns at t immediately
now[v]=h[v];
q[++tt]=v;
}
}
}
return 0;
}
// Blocking-flow DFS with current-arc optimisation; pushes at most f from u to t.
ll dinic(rint u,rll f){
if(u==t)return f;
rll re=f;
for(rint i=now[u];i&&re;i=e[i].nxt){
rint v=e[i].to;now[u]=i;
if(d[v]==d[u]+1&&e[i].val){
rll k=dinic(v,min(e[i].val,re));
if(!k)d[v]=-1; // v is saturated for this phase
re-=k;e[i].val-=k;e[i^1].val+=k;
}
}
return f-re;
}
// Computes the MINIMUM feasible flow from st to ed: first find a feasible
// circulation, then push as much flow as possible back from ed to st.
int main(){
rint n;
scanf("%d",&n);
idx=1;cnt=n;
st=++cnt,ed=++cnt;
rll res=0;
for(rint i=1;i<=n;i++){
// Every vertex may be entered from st and may leave towards ed freely.
Ins(st,i,INF);Ins(i,st,0);
Ins(i,ed,INF);Ins(ed,i,0);
rint c;
scanf("%d",&c);
// c arcs i->x, each with bounds [1, INF] (capacity INF-1 plus lower bound 1).
while(c--){
rint x;
scanf("%d",&x);
Ins(i,x,INF-1);Ins(x,i,0);
cd[i]++;rd[x]++;
}
}
// An infinite arc ed->st turns the network into a circulation.
Ins(ed,st,INF);Ins(st,ed,0);
s=++cnt;t=++cnt;
rll sd=0;
// Balance lower bounds through the super source s and super sink t.
for(rint i=1;i<=cnt;i++){
if(rd[i]>cd[i])Ins(s,i,rd[i]-cd[i]),Ins(i,s,0),sd+=rd[i]-cd[i];
if(cd[i]>rd[i])Ins(i,t,cd[i]-rd[i]),Ins(t,i,0);
}
rll tot=0;
// tot is not compared with sd here: a feasible flow is assumed to exist.
while(bfs())tot+=dinic(s,INF);
// The flow on the ed->st arc (its reverse residual) is the feasible flow value; then delete the arc.
for(rint i=h[ed];i;i=e[i].nxt){
rint v=e[i].to;
if(v==st){
res+=e[i^1].val;
e[i].val=e[i^1].val=0;
break;
}
}
// Cancel as much st->ed flow as possible by augmenting from ed back to st.
s=ed;t=st;
while(bfs())res-=dinic(s,INF);
printf("%lld\n",res);
return 0;
}

有源汇最小费用可行流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<queue>
#include<cstdio>
#include<cassert>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+10;
const int INF=0x3f3f3f3f;
// Formerly `register int` / `register long long`. The `register` storage class
// was removed in C++17 (ill-formed there), so these are now plain types.
#define rint int
#define rll long long
#define ll long long
// Residual arc: head, next arc from the same tail, residual capacity, unit cost.
struct Edge{
int to,nxt,val;
ll cost;
}e[N<<1];
// idx starts at 1 so each arc i is paired with its reverse arc i^1.
int h[N],idx=1;
// Append arc a->b with capacity c and unit cost d.
void Ins(rint a,rint b,rint c,rll d){
e[++idx].to=b;e[idx].nxt=h[a];
h[a]=idx;e[idx].val=c;e[idx].cost=d;
}
// st/ed: current source and sink (main() later switches them to the super
// source/sink); mp is unused in this program; cnt: vertex counter.
int st,ed,mp[N],cnt;
queue<int> q;
// In-queue flags. Never reset explicitly: each pop clears its flag, so all are
// false again once the queue drains.
bool inq[N];
ll dis[N];
int pre[N]; // arc used to reach each vertex on the current shortest path
// SPFA for the cheapest st->ed path in the residual graph; returns whether ed is reachable.
bool spfa(){
memset(dis,0x3f,sizeof(dis));
q.push(st);dis[st]=0;
while(!q.empty()){
int u=q.front();q.pop();inq[u]=0;
for(rint i=h[u];i;i=e[i].nxt){
rint v=e[i].to;
if(e[i].val&&dis[v]>dis[u]+e[i].cost){
dis[v]=dis[u]+e[i].cost;
pre[v]=i;
if(!inq[v]){
q.push(v);
inq[v]=1;
}
}
}
}
// dis[0] still holds the memset "infinity" and serves as the unreachable
// sentinel. NOTE(review): assumes vertex 0 never appears in the graph (main()
// numbers vertices from 1) — confirm the input is 1-based.
return dis[ed]!=dis[0];
}
// Accumulated cost: the mandatory lower-bound costs first, then the fixing flow.
ll res;
// rd/cd: number of lower-bound-1 arcs entering / leaving each vertex.
int rd[N],cd[N];
// Minimum-cost feasible flow from vertex 1: every listed arc must be used at
// least once, and any vertex may end at the artificial sink.
int main(){
rint n;
scanf("%d",&n);cnt=n;
st=1;ed=++cnt; // source is vertex 1; ed=n+1 is an artificial sink
for(rint i=1;i<=n;i++){
Ins(i,ed,INF,0);Ins(ed,i,0,0); // any vertex may finish at the sink for free
rint k;
scanf("%d",&k);
while(k--){
rint y,c;
scanf("%d%d",&y,&c);
// Arc i->y with bounds [1, INF] and cost c; the mandatory unit is paid up front.
Ins(i,y,INF-1,c);Ins(y,i,0,-c);
cd[i]++;rd[y]++;
res+=c;
}
}
Ins(ed,st,INF,0);Ins(st,ed,0,0); // close the circulation ed->st
st=++cnt;
ed=++cnt; // from here on st/ed are the super source and super sink
// Only vertices 1..n can be unbalanced (the artificial sink has no lower-bound arcs).
for(rint i=1;i<=n;i++){
if(rd[i]>cd[i])Ins(st,i,rd[i]-cd[i],0),Ins(i,st,0,0);
if(cd[i]>rd[i])Ins(i,ed,cd[i]-rd[i],0),Ins(ed,i,0,0);
}
// Successive shortest paths between the super source and super sink.
while(spfa()){
rint Min=INF;
for(rint now=ed;now!=st;now=e[pre[now]^1].to)
Min=min(Min,e[pre[now]].val);
for(rint now=ed;now!=st;now=e[pre[now]^1].to)
e[pre[now]].val-=Min,e[pre[now]^1].val+=Min;
res+=Min*dis[ed];
}
printf("%lld\n",res);
return 0;
}

有源汇上下界可行流

只需从 t 向 s 连一条容量为 inf 的边，然后按无源汇上下界可行流求解即可。

有源汇最小流

1. 先不连 t->s 的边，求一次 ss->tt 最大流；
2. 连边 t->s，容量为 inf；
3. 再求一次 ss->tt 最大流；
4. 答案即为边 t->s 的实际流量（即其反向边的残量）。

最大独立集/最小点覆盖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# include <cstdio>
# include <algorithm>
# include <cmath>
# include <cstring>
# include <vector>

using namespace std;

const int N_MAX = 1000, E_MAX = 1000000;

// Not used by the routines below.
struct Edge
{
	int to, next;
};

// n: left vertices (1..n), m: right vertices (1..m), e: number of edges.
int n, m, e;
// g[x]: right-side neighbours of left vertex x.
vector <int> g[N_MAX + 10];

// opp[y]: left vertex matched to right vertex y (0 = unmatched).
int opp[N_MAX + 10];
// vis[x]: left vertex x already tried in the current augmenting search.
bool vis[N_MAX + 10];

// vx / vy: left / right vertices reached by alternating paths that start at
// unmatched left vertices.
bool vx[N_MAX + 10], vy[N_MAX + 10];

// Add edge (left x, right y).
void addEdge(int x, int y)
{
	g[x].emplace_back(y);
}

// Kuhn's augmenting-path search from left vertex x.
bool find(int x)
{
	if (vis[x]) return false;
	vis[x] = true;
	for (int y : g[x])
		if (!opp[y] || find(opp[y])) {
			opp[y] = x;
			return true;
		}
	return false;
}

// Flood alternating paths: any edge out of a left vertex, then the matching edge back.
void mark(int x)
{
	if (vx[x]) return;
	vx[x] = true;
	for (int y : g[x]) {
		if (!opp[y] || vy[y]) continue;
		vy[y] = true;
		mark(opp[y]);
	}
}

// Maximum matching size; opp[] holds the matching afterwards.
int hungary()
{
	memset(opp, 0, sizeof(opp));
	int matched = 0;
	for (int x = 1; x <= n; x++) {
		memset(vis, false, sizeof(vis));
		if (find(x)) matched++;
	}
	return matched;
}

// Size of a maximum independent set (n + m - matching). It also fills vx/vy:
// the set is {left x : vx[x]} together with {right y : !vy[y]} (König).
int maxIndSet()
{
	const int matching = hungary();
	// Reuse vis to flag matched left vertices; vis[0] absorbs unmatched right vertices.
	memset(vis, false, sizeof(vis));
	for (int y = 1; y <= m; y++) vis[opp[y]] = true;
	memset(vx, false, sizeof(vx));
	memset(vy, false, sizeof(vy));
	for (int x = 1; x <= n; x++)
		if (!vis[x]) mark(x);
	return n + m - matching;
}

int main()
{
scanf("%d%d%d", &n, &m, &e);
for (int i = 1; i <= e; i++) {
int x, y;
scanf("%d%d", &x, &y);
addEdge(x, y);
}
printf("%d\n", maxIndSet());
for (int i = 1; i <= n; i++)
if (vx[i]) printf("%d ", i);
puts("");
for (int i = 1; i <= m; i++)
if (!vy[i]) printf("%d ", i);
puts("");
return 0;
}

最小割相关

求最小边的最小割

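设总边数为 $E$，跑最大流之前把所有边的容量都乘以 $E+1$，然后再 $+1$。
得到的结果应该是 $\text{mincut}\times(E+1)+\text{割边数量}$（割边数量不超过 $E$，不会影响 $\text{mincut}$ 那一部分，这个比较显然吧）。因此 $\text{mincut}=\lfloor res/(E+1)\rfloor$，割边数量 $=res \bmod (E+1)$。
由于割边数越少，跑出来的结果越小，所以会自动选出割边数量最少的最小割（但割边数相同时不能保证字典序）。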

输出任意一种最小割的方案

跑过一次最大流之后,在残量网络上,s和t之间不连通了
进行一次dfs/bfs,求出从s出发能到达的点集S,和不能到达的点集T
所有从S跨越到T的满流边(残留网络为0)构成了一组最小割

判断一条边是否满流

运行一次最大流算法,得到一个残量网络
取残量网络上的一条满流边(u, v),判断这条边是否一定满流
对残量网络运行Tarjan算法,求出所有SCC
当u和v不属于同一个SCC的时候,这条边一定满流
否则,我们可以在SCC中找到一个包含这条边的反向边的环,沿着环增广一次,仍然不破坏流量平衡,但是这条边已经不满流了

判断某一条边是否可能为最小割中的一条

所有一定满流的边都可能为最小割

判断某条边是否一定出现在最小割中

首先还是对残量网络求SCC
考虑一条满流边(u, v),判断她是否一定出现在最小割中
当u和s属于同一个SCC,并且v和t属于同一个SCC的时候,这条边一定出现在最小割中

多项式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
// Largest supported transform length, and the NTT-friendly prime modulus.
const int N=1<<18;
const int p=998244353;
const int inv2=p+1>>1; // modular inverse of 2
// Formerly `register int` / `register long long`. The `register` storage class
// was removed in C++17 (ill-formed there), so these are now plain types.
#define rint int
#define rll long long
// Modular add/sub for operands already in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
// Modular product, computed in 64 bits.
inline int Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by binary exponentiation.
inline int fpow(rint a,rint b){
rint res=1;
for(;b;b>>=1,a=Mul(a,a))
if(b&1)res=Mul(res,a);
return res;
}
// w: twiddle table filled by main(); w[mid+j] is the j-th power of a primitive
// (2*mid)-th root of unity. r: bit-reversal permutation for the size pd.
int w[N],r[N],pd;
// A: 64-bit work buffer for lazy modular reduction; t: butterfly temporary.
unsigned long long A[N],t;
// In-place NTT of a (resized to n, a power of two). typ==1: forward transform;
// any other typ: inverse transform including the division by n.
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd^n){ // rebuild the bit-reversal table only when the size changes
for(rint i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
pd=n;
}
for(rint i=0;i<n;i++)A[i]=a[r[i]];
// Butterflies without reducing A: each level adds less than p to every entry,
// so before level L the entries are below (L+1)*p. Hence w*A < 18*p*p < 2^64
// for n <= N = 2^18.
for(rint mid=1;mid<n;mid<<=1)
for(rint R=mid<<1,i=0;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=w[mid+j]*A[i+j+mid]%p,A[i+j+mid]=A[i+j]+p-t,A[i+j]+=t;
for(rint i=0;i<n;i++)a[i]=A[i]%p;
if(typ==1)return;
rint inv=fpow(n,p-2);
// Inverse transform = forward transform with indices 1..n-1 reversed, scaled by 1/n.
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)a[i]=Mul(a[i],inv);
}
// Power-series inverse: f = g^{-1} mod x^n (g[0] must be invertible).
// Newton doubling: f <- f*(2 - g*f), starting from the inverse mod x^{ceil(n/2)}.
void dfs_inv(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=fpow(g[0],p-2);return;}
dfs_inv(f,g,n+1>>1);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));
ntt(Max,f,-1);f.resize(n);
}
// Power-series square root mod x^n via f <- (f^2 + g) / (2f).
// The base case sets f[0]=1, so this assumes g[0]==1.
void dfs_sqrt(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=1;return;}
dfs_sqrt(f,g,n+1>>1);
vector<int> a;
dfs_inv(a,f,n); // a = f^{-1} mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,a,1);ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(Add(Mul(f[i],f[i]),g[i]),Mul(inv2,a[i]));
ntt(Max,f,-1);f.resize(n);
}
// Natural log mod x^n: f = integral of g'/g. f[0] is set to 0, which is
// correct when g[0]==1.
void Ln(vector<int> &f,vector<int> g,rint n){
g.resize(n);
dfs_inv(f,g,n); // f = 1/g
for(rint i=1;i<n;i++)g[i-1]=Mul(g[i],i); // g = g'
g[n-1]=0;
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);
f.resize(n);
for(rint i=n-1;i;i--)f[i]=Mul(f[i-1],fpow(i,p-2)); // integrate
f[0]=0;
}
// Exponential mod x^n. Newton step: f <- f*(1 - ln f + g), starting from
// exp(g) mod x^{ceil(n/2)}. The base case f[0]=1 matches g[0]==0.
void Exp(vector<int> &f,vector<int> g,int n){
g.resize(n);
if(n==1){f.resize(1);f[0]=1;return;}
Exp(f,g,n+1>>1);
vector<int> a;
Ln(a,f,n);
rint Max=1;
while(Max<n+n)Max<<=1;
g.resize(Max);a.resize(Max);
// a = g - ln f, then a[0]++ (runs once, after the loop) adds the constant 1.
for(rint i=0;i<Max;i++)a[i]=Del(g[i],a[i]);a[0]++;
ntt(Max,a,1);ntt(Max,f,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],a[i]);
ntt(Max,f,-1);f.resize(n);
}
// Multipoint evaluation (used by Fin below), written in the transposed form.
// NOTE(review): the reading of Mult/dfs2 as transposed multiplications is
// inferred from their structure — confirm before modifying.
struct Eval{
// vec[x]: subproduct-tree node; member A is unused; res: output values; tmp: the points.
vector<int> vec[N<<2],A,res,tmp;
// Pointwise product of a and b (both already in NTT form with a.size() entries),
// then a forward ntt, truncated to n entries.
vector<int> Mult(vector<int> a,vector<int> b,rint n){
rint Max=a.size();
for(int i=0;i<Max;i++)
a[i]=Mul(a[i],b[i]);
ntt(Max,a,1);a.resize(n);
return a;
}
// vec[x] = prod_{i=l..r} (1 - tmp[i] x) in coefficient form. The children
// vec[L], vec[R] are left in NTT form (length Max) for dfs2().
void dfs1(rint x,rint l,rint r){
if(l==r){
vec[x].resize(2);
vec[x][0]=1;
vec[x][1]=p-tmp[l];
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1;
dfs1(L,l,mid);dfs1(R,mid+1,r);
rint Max=1,n=r-l+2;
while(Max<n)Max<<=1;vec[x].resize(Max);
ntt(Max,vec[L],1);ntt(Max,vec[R],1);
for(rint i=0;i<Max;i++)
vec[x][i]=Mul(vec[L][i],vec[R][i]);
ntt(Max,vec[x],-1);vec[x].resize(n);
}
// Push v down the tree; at leaf l the result for point tmp[l] is v[0].
void dfs2(rint x,rint l,rint r,vector<int> v){
rint n=r-l+2;
if(l==r){
res[l]=v[0];
return;
}
rint Max=1,mid=l+r>>1;
while(Max<n)Max<<=1;
ntt(Max,v,-1);
// Each child receives v multiplied by its sibling's subproduct.
dfs2(x<<1,l,mid,Mult(v,vec[x<<1|1],mid-l+1));
dfs2(x<<1|1,mid+1,r,Mult(v,vec[x<<1],r-mid));
}
// Evaluate A at the points a[1..m] (a[0] is ignored). res[i] is presumably
// A(a[i]); that is how Fin::query uses it.
vector<int> query(vector<int> A,vector<int> a){
rint n=A.size(),m=a.size()-1;res.resize(m+1);tmp=a;
dfs1(1,1,m);
dfs_inv(vec[1],vec[1],n); // root subproduct inverted mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,vec[1],1);
ntt(Max,A,-1);
dfs2(1,1,m,Mult(A,vec[1],n));
return res;
}
}T;
// Fast Lagrange interpolation through the points (x[i], y[i]), i = 1..n.
struct Fin{
// vec: subproduct tree of (X - x_i); vec2: partial interpolants; tmp/tmp2: points / values.
vector<int> vec[N<<2],vec2[N<<2],tmp,tmp2;
// vec[x] = prod_{i=l..r} (X - tmp[i]) in coefficient form. The children are
// left in NTT form (length Max) for dfs2().
void dfs1(rint x,rint l,rint r){
if(l==r){
vec[x].resize(2);
vec[x][0]=p-tmp[l];
vec[x][1]=1;
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1,n=r-l+2;
dfs1(L,l,mid);dfs1(R,mid+1,r);
rint Max=1;
while(Max<n)Max<<=1;
ntt(Max,vec[L],1);ntt(Max,vec[R],1);
vec[x].resize(Max);
for(rint i=0;i<Max;i++)
vec[x][i]=Mul(vec[L][i],vec[R][i]);
ntt(Max,vec[x],-1);
vec[x].resize(n);
}
// vec2[x] = sum_i c_i * prod_{j != i} (X - x_j) over the node's range, where
// the leaf coefficient is c_l = tmp2[l] / tmp[l] (tmp holds P'(x_l) by then).
// Children combine as left*vec[R] + right*vec[L].
void dfs2(rint x,rint l,rint r){
if(l==r){
vec2[x].resize(1);
vec2[x][0]=Mul(tmp2[l],fpow(tmp[l],p-2));
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1,n=r-l+2;
dfs2(x<<1,l,mid);dfs2(x<<1|1,mid+1,r);
rint Max=1;
while(Max<n)Max<<=1;
vec2[x].resize(Max);
ntt(Max,vec2[L],1);ntt(Max,vec2[R],1);
for(rint i=0;i<Max;i++)
vec2[x][i]=Add(Mul(vec2[L][i],vec[R][i]),Mul(vec2[R][i],vec[L][i]));
ntt(Max,vec2[x],-1);vec2[x].resize(n-1);
}
// x, y: 1-indexed points and values (index 0 unused). Returns the n
// coefficients of the interpolating polynomial.
vector<int> query(vector<int> x,vector<int> y){
int n=x.size()-1;tmp=x;tmp2=y;
dfs1(1,1,n);
// Differentiate P = prod (X - x_i) in place.
for(rint i=1;i<=n;i++)
vec[1][i-1]=Mul(vec[1][i],i);
vec[1][n]=0;
tmp=T.query(vec[1],x); // values of P' at the points
dfs2(1,1,n);
return vec2[1];
}
}F;
int main(){
	// Twiddle table for ntt(): w[N/2+j] = root^j, with root a primitive N-th
	// root of unity; w[i] = w[2i] below N/2.
	const int root=fpow(3,(p-1)/N);
	w[N/2]=1;
	w[N/2+1]=root;
	for(int i=N/2+2;i<N;++i)w[i]=Mul(w[i-1],root);
	for(int i=N/2-1;i>=0;--i)w[i]=w[2*i];
	int cnt;
	scanf("%d",&cnt);
	// Points and values are 1-indexed; slot 0 stays 0.
	vector<int> xs(cnt+1),ys(cnt+1);
	for(int i=1;i<=cnt;++i)scanf("%d%d",&xs[i],&ys[i]);
	const vector<int> coef=F.query(xs,ys);
	for(int c:coef)printf("%d ",c);
	return 0;
}

多项式求逆

给定多项式 $F$，求一个 $G$ 使得 $F*G\equiv1 \pmod {x^n}$。

多项式的模意义是指次数大于等于 $n$ 的项都当成 $0$。考虑倍增，设有多项式 $A$ 满足 $F*A\equiv 1 \pmod {x^{\lceil{\frac{n}2}\rceil}}$。取上取整是为了保证 $2\lceil\frac{n}2\rceil\ge n$，这样平方后次数大于等于 $n$ 的项都可以当成 $0$。

显然有 $F*G\equiv1 \pmod {x^{\lceil{\frac{n}2}\rceil}}$，根据模的定义可知。

所以有 $F*(G-A)\equiv 0 \pmod{x^{\lceil{\frac{n}2}\rceil}}$。由于 $F$ 的常数项不为 $0$（否则不可逆），可得 $G-A\equiv 0 \pmod{x^{\lceil{\frac{n}2}\rceil}}$，于是 $(G-A)^2\equiv0 \pmod{x^n}$。继续推一发，得到 $G^2-2AG+A^2\equiv 0 \pmod{x^n}$；两边同时乘 $F$（利用 $FG\equiv1$）得到 $G-2A+A^2F\equiv 0 \pmod{x^n}$，所以推出 $G\equiv 2A-A^2F \pmod{x^n}$。只有常数项的时候直接求逆元，倍增分治即可求出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void dfs_inv(rint n,vector<int> &f,vector<int> g){
g.resize(n);
if(n==1){
f.resize(1);
f[0]=fpow(g[0],p-2);
return ;
}
dfs_inv(n+1>>1,f,g);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));
ntt(Max,f,-1);
f.resize(n);
}

多项式开根

求一个多项式 $F$ 满足 $F^2=G$。若满足条件的 $F$ 有两个（互为相反数），一般取正的那个。

仍然使用倍增，设 $F^2\equiv G\pmod {x^{2n}}$，假设已经求出了 $A^2\equiv G\pmod {x^n}$，那么有 $F^2-A^2\equiv0 \pmod {x^n}$。

易得

$(F+A)(F-A)\equiv0 \pmod {x^n}$

$(F-A)\equiv0 \pmod {x^n}$（取与 $A$ 常数项相同的那个根）

$(F-A)^2\equiv0 \pmod {x^{2n}}$

$F^2+A^2-2AF\equiv0 \pmod {x^{2n}}$

$G+A^2-2AF\equiv0 \pmod {x^{2n}}$

$F\equiv \frac{G+A^2}{2A}\pmod {x^{2n}}$

在只有常数项的时候可以直接求出。

若常数项为 $1$（或 $0$），开根较为显然；但是有些毒瘤出题人并不喜欢简单，所以模意义下的开根（二次剩余）就比较重要。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Power-series square root: f = sqrt(g) mod x^n via f <- (f^2 + g) / (2f).
// The base case sets f[0]=1, so this assumes g[0]==1; otherwise f[0] would
// need a modular square root of g[0].
// NOTE(review): calls dfs_inv(result, input, length), the argument order of the
// full template. The standalone dfs_inv snippet takes (length, result, input).
void dfs_sqr(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){
f.resize(1);f[0]=1;
return;
}
dfs_sqr(f,g,n+1>>1);
vector<int> a;
dfs_inv(a,f,n); // a = f^{-1} mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);ntt(Max,a,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(Add(Mul(f[i],f[i]),g[i]),Mul(inv2,a[i]));
ntt(Max,f,-1);f.resize(n);
}

二次剩余

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
// Formerly `register int` / `register long long`. The `register` storage class
// was removed in C++17 (ill-formed there), so these are now plain types.
#define rint int
#define rll long long
// I: the value w^2 of the extension element w used by Cipolla's method;
// n: the number whose square root is wanted; p: the modulus.
ll I,n,p;
inline ll Add(rll x,rll y){return (x+y)%p;}
inline ll Del(rll x,rll y){return ((x-y)%p+p)%p;}
inline ll Mul(rll x,rll y){x*=y;return x>=p?x%p:x;}
// Element x + y*w of F_p[w], where w*w = I.
struct cp{
ll x,y;
cp operator * (const cp&A)const{return (cp){Add(Mul(x,A.x),Mul(Mul(I,y),A.y)),Add(Mul(x,A.y),Mul(y,A.x))};}
};
// a^b mod p.
ll fpow(rll a,rll b){
rll res=1;
for(;b;b>>=1,a=Mul(a,a))
if(b&1)res=Mul(res,a);
return res;
}
// a^b in F_p[w].
cp fpow(cp a,rll b){
cp res=(cp){1,0};
for(;b;b>>=1,a=a*a)
if(b&1)res=res*a;
return res;
}
// Euler's criterion: true iff x is a quadratic residue modulo the odd prime p (x != 0).
inline bool check(rll x){
return fpow(x,(p-1)>>1)==1;
}
int main(){
int T;
scanf("%d",&T);
while(T--){
scanf("%lld%lld",&n,&p);
if(n==0)printf("0\n");
else if(!check(n))printf("Hola!\n");
else {
rll a;I=0;
while(I==0||check(I))a=rand()%p,I=Del(Mul(a,a),n);
rll res1=fpow((cp){a,1},p+1>>1).x;
rll res2=p-res1;
if(res1==res2)printf("%lld ",res1);
else printf("%lld %lld\n",min(res1,res2),max(res1,res2));
}
}
return 0;
}

多项式除法

给定多项式 $F,G$，求两个多项式 $A,B$ 满足 $F=A*G+B$。

其中 $F,G,A$ 的次数分别为 $n,m,n-m$，要求 $B$ 的次数小于 $m$。

定义一种变换 $F_R(x)=x^nF(\frac{1}x)$，容易发现这种变换就是将 $F$ 的系数翻转（reverse）了一下。

$F=A*G+B$

$F(\frac{1}x)=A(\frac{1}x)G(\frac{1}x)+B(\frac{1}x)$

$x^nF(\frac{1}x)=x^nA(\frac{1}x)G(\frac{1}x)+x^nB(\frac{1}x)$

$x^nF(\frac{1}x)=x^{n-m}A(\frac{1}x)\cdot x^mG(\frac{1}x)+x^{n-m+1}\cdot x^{m-1}B(\frac{1}x)$

$F_R(x)=A_R(x)G_R(x)+x^{n-m+1}B_R(x)$（这里 $B_R$ 按 $m-1$ 次翻转）

模一下 $x^{n-m+1}$ 就能得到

$F_R(x)\equiv A_R(x)G_R(x) \pmod {x^{n-m+1}}$

$A_R(x)\equiv F_R(x)G_R^{-1}(x) \pmod {x^{n-m+1}}$

多项式求逆即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=3e5+10;
const int p=998244353;
using ll = long long;
// a^b mod p by binary exponentiation.
ll fpow(ll a,ll b){
	ll res=1;
	while(b){
		if(b&1)res=res*a%p;
		a=a*a%p;
		b>>=1;
	}
	return res;
}
// tmp: original coefficients; F/G: reversed inputs; A/B: quotient / inverse / remainder scratch.
ll tmp[2][N],F[N],G[N],A[N],B[N];
// Bit-reversal permutation, filled by the caller for the current length.
int r[N];
// In-place NTT of length n using the table r. typ==-1 uses inverse roots.
// No 1/n scaling is applied; callers scale afterwards.
void ntt(int n,ll *a,int typ){
	for(int i=0;i<n;++i)
		if(i<r[i])std::swap(a[i],a[r[i]]);
	for(int half=1;half<n;half*=2){
		const int len=half*2;
		ll step=fpow(3,(p-1)/len);
		if(typ==-1)step=fpow(step,p-2);
		for(int base=0;base<n;base+=len){
			ll cur=1;
			for(int k=0;k<half;++k){
				ll &lo=a[base+k];
				ll &hi=a[base+k+half];
				const ll u=lo,v=cur*hi%p;
				lo=(u+v)%p;
				hi=(u-v+p)%p;
				cur=cur*step%p;
			}
		}
	}
}
// Inverse of the power series G modulo x^n, written into B (global arrays).
// Newton doubling: B <- 2B - G*B^2, starting from the inverse mod x^{ceil(n/2)}.
void dfs(int n){
if(n==1){B[0]=fpow(G[0],p-2);return;}
dfs(n+1>>1);
int Max=1,l=0;
while(Max<n+n)Max<<=1,l++;
// Bit-reversal table for length Max; A receives G truncated to n terms.
for(int i=0;i<Max;i++){
r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
A[i]=i<n?G[i]:0;
}
ntt(Max,A,1);ntt(Max,B,1);
for(int i=0;i<Max;i++)
B[i]=(2ll*B[i]%p-A[i]*B[i]%p*B[i]%p+p)%p;
ntt(Max,B,-1);
// ntt() does not divide by the length: scale by 1/Max here, and clear x^n and
// above so the next level starts from a clean B.
ll inv=fpow(Max,p-2);
for(int i=0;i<Max;i++)
B[i]=i<n?B[i]*inv%p:0;
}
int main(){
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%lld",&tmp[0][i]),F[n-i]=tmp[0][i];
for(int i=0;i<=m;i++)scanf("%lld",&tmp[1][i]),G[m-i]=tmp[1][i];
for(int i=n-m+2;i<=m;i++)G[i]=0;
for(int i=n-m+2;i<=n;i++)F[i]=0;
dfs(n-m+1);
int Max=1,l=0;
while(Max<=2*n-2*m+4)Max<<=1,l++;
for(int i=0;i<Max;i++)r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
ntt(Max,F,1);ntt(Max,B,1);
for(int i=0;i<Max;i++)
F[i]=F[i]*B[i]%p;
ntt(Max,F,-1);
ll inv=fpow(Max,p-2);
for(int i=0;i<Max;i++)
A[i]=(i<=n-m)?F[i]*inv%p:0;
reverse(A,A+n-m+1);
memcpy(B,A,sizeof(A));
memcpy(F,tmp[0],sizeof(tmp[0]));
memcpy(G,tmp[1],sizeof(tmp[1]));
Max=1,l=0;
while(Max<=n)Max<<=1,l++;
for(int i=0;i<Max;i++)r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
ntt(Max,B,1);ntt(Max,G,1);
for(int i=0;i<Max;i++)
B[i]=B[i]*G[i]%p;
ntt(Max,B,-1);
inv=fpow(Max,p-2);
for(int i=0;i<Max;i++)
B[i]=B[i]*inv%p;
for(int i=0;i<=m;i++)
B[i]=(F[i]-B[i]+p)%p;
for(int i=0;i<=n-m;i++)
printf("%lld ",A[i]);
puts("");
for(int i=0;i<m;i++)
printf("%lld ",B[i]);
return 0;
}

多项式求导和积分

求导：$[x^{i-1}]F'=[x^i]F\times i$

积分：积分是求导的逆运算，对一个多项式积分的结果，表示的是哪个多项式求导后得到这个多项式。

$[x^i]\int F'=\frac{[x^{i-1}]F'}i$

多项式ln

给定多项式 $F$，求一个多项式 $G$ 满足 $G=\ln F$。

两边求导得到 $G'=\ln'(F)\,F'$。

由 $\ln$ 的导数可以知道 $\ln'(F)=\frac{1}F$。

所以 $G'=\frac{F'}F$。

于是 $G=\int\frac{F'}F$（要求 $[x^0]F=1$，积分常数取 $0$）。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include<bits/stdc++.h>
using namespace std;
// Formerly `register int` / `register long long`. The `register` storage class
// was removed in C++17 (ill-formed there), so these are now plain types.
#define rint int
#define ll long long
#define rll long long
// Largest supported transform length, and the NTT-friendly prime modulus.
const int N=1<<18;
const int p=998244353;
// Modular add/sub for operands already in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
// Modular product, computed in 64 bits.
inline ll Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by binary exponentiation.
ll fpow(rll a,rll b){
rll res=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)res=res*a%p;
return res;
}
// r: bit-reversal table for the size pd. w: twiddle table filled in main(),
// where w[mid+j] is the j-th power of a primitive (2*mid)-th root of unity.
// t: butterfly temporary, also reused for 1/n.
int r[N],w[N],t,pd;
// In-place NTT of a (resized to n, a power of two). typ==-1: inverse transform
// including the division by n; any other typ: forward transform.
void ntt(rint n,vector<int> &a,rint typ){
if(pd!=n){ // rebuild the bit-reversal table only when the size changes
for(rint i=0;i<n;i++)
r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
pd=n;
}
a.resize(n);
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint i=0;i<n;i+=mid<<1)
for(rint j=0;j<mid;j++)
t=Mul(w[mid+j],a[i+j+mid]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(~typ)return ; // ~typ is 0 only for typ==-1
// Inverse transform = forward transform with indices 1..n-1 reversed, scaled by 1/n.
t=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)
a[i]=Mul(a[i],t);
}
// Power-series inverse: f = g^{-1} mod x^n (g[0] must be invertible).
// Newton doubling: f <- f*(2 - g*f), starting from the inverse mod x^{ceil(n/2)}.
void dfs_inv(rint n,vector<int> &f,vector<int> g){
g.resize(n);
if(n==1){
f.resize(1);
f[0]=fpow(g[0],p-2);
return ;
}
dfs_inv(n+1>>1,f,g);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));
ntt(Max,f,-1);
f.resize(n);
}
// Natural log mod x^n: f = integral of g'/g. f[0] is set to 0, which is
// correct when g[0]==1.
void Ln(rint n,vector<int> &f,vector<int> g){
g.resize(n);
f=g;
dfs_inv(n,g,g); // g becomes 1/g; the by-value argument is copied before g changes
for(rint i=1;i<n;i++)
f[i-1]=Mul(f[i],i); // f becomes g'
f[n-1]=0;
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);
for(rint i=n-1;i;i--)
f[i]=Mul(f[i-1],fpow(i,p-2)); // integrate
f[0]=0;
f.resize(n);
}
int main(){
w[N/2]=1;w[N/2+1]=t=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],t);
for(rint i=N/2-1;i;i--)w[i]=w[i<<1];
rint n;
scanf("%d",&n);
vector<int> f;
f.resize(n);
for(rint i=0;i<n;i++)
scanf("%d",&f[i]);
Ln(n,f,f);
for(rint i=0;i<n;i++)
printf("%d ",f[i]);
return 0;
}

多项式exp

设 $G(x)\equiv e^{F(x)}\pmod {x^n}$，$G_1(x)\equiv e^{F(x)} \pmod {x^{\lceil\frac{n}2\rceil}}$，则

$G\equiv G_1(1-\ln G_1+F) \pmod {x^n}$

证明：对 $H(G)=\ln G-F$ 使用牛顿迭代，$H'(G)=\frac1G$，于是 $G\equiv G_1-\frac{\ln G_1-F}{1/G_1}=G_1(1-\ln G_1+F)\pmod{x^n}$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
// Largest supported transform length, and the NTT-friendly prime modulus.
const int N=1<<18;
const int p=998244353;
// Formerly `register int` / `register long long`. The `register` storage class
// was removed in C++17 (ill-formed there), so these are now plain types.
#define rint int
#define rll long long
// Modular add/sub for operands already in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
// Modular product, computed in 64 bits.
inline int Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by binary exponentiation.
inline int fpow(rint a,rint b){
rint res=1;
for(;b;b>>=1,a=Mul(a,a))
if(b&1)res=Mul(res,a);
return res;
}
// w: twiddle table filled in main(), where w[mid+j] is the j-th power of a
// primitive (2*mid)-th root of unity. r: bit-reversal table for the size pd.
// t: butterfly temporary.
int w[N],r[N],pd,t;
// In-place NTT of a (resized to n, a power of two). typ==1: forward transform;
// any other typ: inverse transform including the division by n.
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd^n){ // rebuild the bit-reversal table only when the size changes
for(rint i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
pd=n;
}
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint R=mid<<1,i=0;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=Mul(w[mid+j],a[i+j+mid]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(typ==1)return;
// Inverse transform = forward transform with indices 1..n-1 reversed, scaled by 1/n.
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)a[i]=Mul(a[i],inv);
}
// Power-series inverse: f = g^{-1} mod x^n (g[0] must be invertible).
// Newton doubling: f <- f*(2 - g*f), starting from the inverse mod x^{ceil(n/2)}.
void dfs_inv(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=fpow(g[0],p-2);return;}
dfs_inv(f,g,n+1>>1);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));
ntt(Max,f,-1);f.resize(n);
}
// Natural log mod x^n: f = integral of g'/g. f[0] is set to 0, which is
// correct when g[0]==1.
void Ln(vector<int> &f,vector<int> g,rint n){
g.resize(n);
dfs_inv(f,g,n); // f = 1/g
for(rint i=1;i<n;i++)g[i-1]=Mul(g[i],i); // g = g'
g[n-1]=0;
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);
f.resize(n);
for(rint i=n-1;i;i--)f[i]=Mul(f[i-1],fpow(i,p-2)); // integrate
f[0]=0;
}
// Exponential mod x^n. Newton step: f <- f*(1 - ln f + g), starting from
// exp(g) mod x^{ceil(n/2)}. The base case f[0]=1 matches g[0]==0.
void Exp(vector<int> &f,vector<int> g,int n){
g.resize(n);
if(n==1){f.resize(1);f[0]=1;return;}
Exp(f,g,n+1>>1);
vector<int> a;
Ln(a,f,n);
rint Max=1;
while(Max<n+n)Max<<=1;
g.resize(Max);a.resize(Max);
// a = g - ln f, then a[0]++ (runs once, after the loop) adds the constant 1.
for(rint i=0;i<Max;i++)a[i]=Del(g[i],a[i]);a[0]++;
ntt(Max,a,1);ntt(Max,f,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],a[i]);
ntt(Max,f,-1);f.resize(n);
}
int main(){
	// Twiddle table for ntt(): w[N/2+j] = root^j, with root a primitive N-th
	// root of unity; w[i] = w[2i] below N/2.
	const int root=fpow(3,(p-1)/N);
	w[N/2]=1;
	w[N/2+1]=root;
	for(int i=N/2+2;i<N;++i)w[i]=Mul(w[i-1],root);
	for(int i=N/2-1;i>=0;--i)w[i]=w[2*i];
	int n;
	scanf("%d",&n);
	vector<int> f(n);
	for(int &c:f)scanf("%d",&c);
	Exp(f,f,n);
	for(int c:f)printf("%d ",c);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<algorithm>
#include<cstring>
#include<cstdio>
#define ll long long
#define ull unsigned long long
// clr/cpy: zero-fill / copy n ints.
#define clr(f,n) memset(f,0,sizeof(int)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
using namespace std;
// Primitive root, NTT-friendly modulus, and array bound.
const int _G=3,mod=998244353,Maxn=1050000;
// Read the next non-negative decimal integer from stdin, skipping any other characters.
inline int read(){
	char c=0;
	do c=getchar(); while(c<'0'||c>'9');
	int value=0;
	for(;c>='0'&&c<='9';c=getchar())value=value*10+(c-'0');
	return value;
}
// a^t mod mod; with the default exponent it returns the modular inverse of a.
ll powM(ll a,int t=mod-2){
	ll ans=1;
	for(;t;t>>=1,a=a*a%mod)
		if(t&1)ans=ans*a%mod;
	return ans;
}
const int invG=powM(_G);
// Bit-reversal table tr for the length tf (recomputed only when the length changes).
int tr[Maxn<<1],tf;
void tpre(int n){
	if(tf==n)return;
	tf=n;
	const int highBit=n>>1;
	for(int i=0;i<n;++i)
		tr[i]=(tr[i>>1]>>1)|((i&1)?highBit:0);
}
// Length-n NTT of g in place (op=true: forward; op=false: inverse including
// the 1/n factor), computed in a 64-bit buffer with lazy reduction.
void NTT(int *g,bool op,int n)
{
tpre(n);
static ull f[Maxn<<1],w[Maxn<<1];w[0]=1;
for (int i=0;i<n;i++)f[i]=g[tr[i]];
for(int l=1;l<n;l<<=1){
ull tG=powM(op?_G:invG,(mod-1)/(l+l)); // primitive (2l)-th root of unity, or its inverse
for (int i=1;i<l;i++)w[i]=w[i-1]*tG%mod;
for(int k=0;k<n;k+=l+l)
for(int p=0;p<l;p++){ // k is a multiple of 2l and p<l, so k|l|p == k+l+p
int tt=w[p]*f[k|l|p]%mod;
f[k|l|p]=f[k|p]+mod-tt;
f[k|p]+=tt;
}
// Entries grow by less than mod per level: before the level with l=2^k they
// are below (k+1)*mod. One full reduction after l=2^17 keeps w*f below 2^64.
if (l==(1<<17))
for (int i=0;i<n;i++)f[i]%=mod;
}if (!op){
ull invn=powM(n);
for(int i=0;i<n;++i)
g[i]=f[i]%mod*invn%mod;
}else for(int i=0;i<n;++i)g[i]=f[i]%mod;
}
// Pointwise product f[i] = f[i]*g[i] mod mod for i < n.
void px(int *f,int *g,int n)
{for(int i=0;i<n;++i)f[i]=1ll*f[i]*g[i]%mod;}
// F: input coefficients pre-multiplied by their index (F[i] = i*f_i, see main);
// G: the result exp(f); s1/s2: NTT scratch.
int F[Maxn],G[Maxn],s1[Maxn],s2[Maxn];
// Online (CDQ divide & conquer) convolution for exp, based on G' = G*f', i.e.
// i*G[i] = sum_{j<i} G[j]*F[i-j]. Works on [l, r) where r-l is a power of two.
void cdq(int l,int r)
{
if (l+1==r){
// Every contribution from j<l has arrived: divide by l to obtain G[l].
if (l>0)G[l]=powM(l)*G[l]%mod;
return ;
}int mid=(l+r)>>1,n=r-l;
cdq(l,mid);
// Convolve the finished half G[l,mid) with F[0,n) cyclically over length n.
// Wrap-around only lands below index n/2, and those entries are not used.
cpy(s1,G+l,n/2);clr(s1+(n/2),n/2);NTT(s1,1,n);
cpy(s2,F,n);NTT(s2,1,n);
px(s1,s2,n);NTT(s1,0,n);
// Add the contributions to the right half G[mid,r).
for (int i=n/2;i<n;i++)
G[l+i]=(G[l+i]+s1[i])%mod;
cdq(mid,r);
}
int n,m;
int main()
{
	m=read();
	G[0]=1; // exp(f) has constant term 1 (f_0 is ignored: it is multiplied by 0 below)
	for(int i=0;i<m;++i){
		const long long coef=read();
		F[i]=coef*i%mod; // F = x*f'(x): coefficient i is i*f_i
	}
	n=1;
	while(n<m)n*=2;
	cdq(0,n);
	for(int i=0;i<m;++i)printf("%d ",G[i]);
	return 0;
}

泰勒展开

对于一个函数 $F$，$F(x)$ 的值可以近似看作在 $x_0$ 处泰勒展开的值，即

$F(x)=\sum_{i=0}^{+\infty} \frac{F^{(i)}(x_0)(x-x_0)^i}{i!}$

这个式子看上去到正无穷没有办法求，但是由于它递减得很快，所以基本上只需要求前几项即可。

牛顿迭代

上边的式子虽然易懂但是比较繁琐,因为对于不同的东西需要用到不同的式子,考虑更加通用的办法——牛顿迭代。

考虑如果求出了 $F(x) \equiv 0 \pmod {x^{\lceil\frac{n}2\rceil}}$ 的一个解 $x_0$，如何求出另一个解 $x$ 满足 $F(x) \equiv 0 \pmod {x^{n}}$。

将 $F(x)$ 在 $x_0$ 处泰勒展开，容易得到

$F(x)=F(x_0)+F'(x_0)(x-x_0)+\cdots$

注意到后边写了省略号，并不是我懒得写了，而是因为它们没有用了：$x$ 与 $x_0$ 在模 $x^{\lceil\frac n2\rceil}$ 的意义下相同（即前一半项是一样的），所以 $x-x_0$ 的最低次项次数至少为 $\lceil\frac{n}2\rceil$，于是 $(x-x_0)^k$ 在 $k\ge 2$ 时模 $x^n$ 全是 $0$。

于是有

$F(x_0)+F'(x_0)(x-x_0) \equiv 0 \pmod {x^n}$

$F'(x_0)\,x\equiv F'(x_0)\,x_0-F(x_0) \pmod {x^n}$

$x\equiv x_0-\frac{F(x_0)}{F'(x_0)} \pmod {x^n}$

这样推式子就简单得多了。更具体地，只需要构造一个函数 $F$，使所求的多项式恰好是 $F(x)\equiv0 \pmod {x^n}$ 的解即可。

例如多项式乘法逆，构造函数 $F(x)=\frac{1}x-G$，其中 $G$ 为给定的多项式。

那么我们要求的实际上就是 $F(x)\equiv0 \pmod {x^n}$ 的一个解。

直接代入牛顿迭代，有

$x\equiv x_0 - \frac{\frac{1}{x_0}-G}{-\frac{1}{x_0^2}} \pmod {x^n}$

$x\equiv 2x_0-Gx_0^2 \pmod {x^n}$

多项式开根，构造函数 $F(x)=x^2-G$，其中 $G$ 为给定多项式。

代入牛顿迭代，得到

$x\equiv x_0-\frac{x_0^2-G}{2x_0} \pmod {x^n}$

$x\equiv \frac{x_0^2+G}{2x_0} \pmod {x^n}$

简单易懂

任意模数多项式乘法

总有一些毒瘤出题人喜欢把模数搞得奇怪，让你没办法做 NTT。这时候可以用多个模数分别 NTT 再 CRT 合并，但是常数过于大，就不大好玩。

考虑一个古老的算法（其实也不是很古老，只是一段时间没怎么用过）：FFT。

因为 FFT 是不用取模的，所以可以直接乘，最后一起取模。不过模数很大容易炸精度，所以要拆系数：不妨令 $[x^i]F=a_0\times T +b_0$（代码中 $T=2^{15}$），另一个多项式同理拆成 $a_1\times T+b_1$，然后分别求出 $a_0a_1,\ a_0b_1+a_1b_0,\ b_0b_1$，最后合并即可。这样做大概需要七次 FFT，还是太慢了。

于是使用魔法,使得能够快速求出要求的几项。

设 $P=a+bi,\ Q=a-bi$，显然它们的系数是共轭的，有一个结论：

记 $\hat P,\hat Q$ 为 $P,Q$ 的长度为 $n$ 的 DFT，则 $\hat Q_0=\overline{\hat P_0}$，其余的 $\hat Q_i=\overline{\hat P_{n-i}}$，即 $\hat Q_i=\overline{\hat P_{(n-1)\&(n-i)}}$。所以只需要求出 $P$ 的 DFT，就有了 $Q$ 的 DFT，这样就把两次 FFT 合并成了一次。此时再暴力求出剩下的几项只要五次，考虑继续优化。

令 $A=a_0a_1+b_0b_1i,\ B=a_0b_1+a_1b_0$（在点值意义下相乘），这样只需要做两次 IDFT 即可，所以就可以写出四次的 MTT 了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of every coefficient / transform array and the 15-bit mask used to split
// coefficients into high and low halves.
const int N=4e5+10;
const int M=(1<<15)-1;
typedef double db;
const db pi=acos(-1.0);

// Minimal complex number over double: addition, subtraction, multiplication and
// complex conjugation (unary ~).
struct complex{
	double x,y;   // real part, imaginary part
	complex operator + (const complex&rhs)const{
		complex out;
		out.x=x+rhs.x;
		out.y=y+rhs.y;
		return out;
	}
	complex operator - (const complex&rhs)const{
		complex out;
		out.x=x-rhs.x;
		out.y=y-rhs.y;
		return out;
	}
	complex operator * (const complex&rhs)const{
		complex out;
		out.x=x*rhs.x-y*rhs.y;
		out.y=x*rhs.y+y*rhs.x;
		return out;
	}
	friend complex operator ~ (const complex&v){
		complex out;
		out.x=v.x;
		out.y=-v.y;
		return out;
	}
}A[N],B[N],C[N],D[N],a0[N],a1[N],b0[N],b1[N],w[N],tmp;
int r[N];   // bit-reversal permutation for the current transform length
// In-place iterative radix-2 FFT of a[0..n).
// Preconditions (established in main): n is a power of two, r[] holds the bit-reversal
// permutation for n, and w[k] = exp(2*pi*i*k/n). The global `tmp` is used as scratch.
// typ != 0: forward transform. typ == 0: inverse transform, computed as a forward pass
// followed by reversing a[1..n-1] and dividing every entry by n.
void fft(int n,complex *a,int typ=1){
for(int i=0;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
// Butterflies of size 2*mid; t0 = n/(2*mid) is the twiddle stride, so w[t] walks the
// powers of a primitive (2*mid)-th root of unity.
for(int mid=1,t0=n>>1;mid<n;mid<<=1,t0>>=1)
for(int i=0,R=mid<<1;i<n;i+=R)
for(int j=0,t=0;j<mid;j++,t+=t0)
tmp=w[t]*a[i+j+mid],a[i+j+mid]=a[i+j]-tmp,a[i+j]=a[i+j]+tmp;
if(typ)return;
reverse(a+1,a+n);
for(int i=0;i<n;i++)a[i]=a[i]*(complex){1.0/n,0.0};
}
// Arbitrary-modulus polynomial multiplication ("MTT", four FFTs).
// Input: n m p, then the n+1 coefficients of F and the m+1 coefficients of G.
// Output: the n+m+1 coefficients of F*G, each reduced modulo p.
// Every coefficient is split as x = hi*2^15 + lo, and the hi/lo parts of one polynomial
// are packed into the real/imaginary parts of a single complex sequence.
int main(){
	int n,m,p;
	scanf("%d%d%d",&n,&m,&p);
	for(int i=0;i<=n;i++){
		int x;
		scanf("%d",&x);
		a0[i].x=x>>15;   // high part
		a1[i].x=x&M;     // low 15 bits
	}
	for(int i=0;i<=m;i++){
		int x;
		scanf("%d",&x);
		b0[i].x=x>>15;
		b1[i].x=x&M;
	}
	for(int i=0;i<=n;i++)A[i].x=a0[i].x,A[i].y=a1[i].x;   // A = hi + i*lo of F
	for(int i=0;i<=m;i++)B[i].x=b0[i].x,B[i].y=b1[i].x;   // B = hi + i*lo of G
	int Max=1;
	while(Max<=n+m)Max<<=1;   // transform length exceeds deg(F*G)
	for(int i=0;i<Max;i++){
		// Bit-reversal permutation. Writing the top bit as Max>>1 instead of
		// 1<<(log2(Max)-1) stays well-defined when Max==1 (n==m==0), where the old
		// form shifted by -1.
		r[i]=(r[i>>1]>>1)|((i&1)?Max>>1:0);
		w[i]=complex{cos(2*i*pi/Max),sin(2*i*pi/Max)};
	}
	fft(Max,A);fft(Max,B);
	// Separate the two real sequences packed in each transform, using
	// DFT(conj z)[k] = conj(DFT(z)[(Max-k) mod Max]); (Max-1)&(Max-i) is that index.
	for(int i=0;i<Max;i++){
		tmp=~A[(Max-1)&(Max-i)];
		a0[i]=(A[i]+tmp)*complex{0.5,0.0};   // DFT of F's high parts
		a1[i]=(tmp-A[i])*complex{0.0,0.5};   // DFT of F's low parts
		tmp=~B[(Max-1)&(Max-i)];
		b0[i]=(B[i]+tmp)*complex{0.5,0.0};
		b1[i]=(tmp-B[i])*complex{0.0,0.5};
	}
	// Real part of A: hi*hi, imaginary part of A: lo*lo; B: the two cross products.
	for(int i=0;i<Max;i++){
		A[i]=a0[i]*b0[i]+a1[i]*b1[i]*complex{0.0,1.0};
		B[i]=a1[i]*b0[i]+a0[i]*b1[i];
	}
	fft(Max,A,0);fft(Max,B,0);
	for(int i=0;i<=n+m;i++){
		long long w1=(long long)(A[i].x+0.5)%p;   // sum of hi*hi
		long long w2=(long long)(B[i].x+0.5)%p;   // sum of hi*lo + lo*hi
		long long w3=(long long)(A[i].y+0.5)%p;   // sum of lo*lo
		printf("%lld ",((w1<<30)+(w2<<15)+w3)%p);
	}
	return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <algorithm>
#include <cstdio>
#include <cstring>
// Output modulus, read at runtime in main().
int mod;
namespace Math {
// base^p modulo `mod` by square-and-multiply (p >= 0).
inline int pw(int base, int p, const int mod) {
	long long acc = 1, cur = base;
	while (p) {
		if (p & 1) acc = acc * cur % mod;
		cur = cur * cur % mod;
		p >>= 1;
	}
	return static_cast<int> (acc);
}
// Modular inverse of x via Fermat's little theorem (mod must be prime).
inline int inv(int x, const int mod) {
	const int exponent = mod - 2;
	return pw(x, exponent, mod);
}
}

// The three NTT primes (primitive root G = 3) and the CRT constants used by Int::get():
// inv_1 = mod1^{-1} mod mod2, inv_2 = (mod1*mod2)^{-1} mod mod3.
const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
const long long mod_1_2 = static_cast<long long> (mod1) * mod2;
const int inv_1 = Math::inv(mod1, mod2), inv_2 = Math::inv(mod_1_2 % mod3, mod3);
// A number stored as its residues modulo the three NTT primes mod1, mod2, mod3.
// Invariant: 0 <= A < mod1, 0 <= B < mod2, 0 <= C < mod3. operator+/- and reduce()
// rely on it, and get() assumes it when recombining.
struct Int {
	int A, B, C;
	explicit inline Int() { }
	// Residues of num modulo each prime. Reducing here keeps the invariant for inputs
	// at or above a prime (main passes x % mod, and mod may be ~1e9 > mod3); the
	// old constructor copied num into all three components unreduced.
	explicit inline Int(int num) : A(norm(num, mod1)), B(norm(num, mod2)), C(norm(num, mod3)) { }
	// Components are taken as given; callers must pass already-reduced residues.
	explicit inline Int(int a, int b, int c) : A(a), B(b), C(c) { }
	// num mod m, mapped into [0, m).
	static inline int norm(int num, int m) {
		const int v = num % m;
		return v < 0 ? v + m : v;
	}
	// Adds the prime to each negative component: maps (-mod_i, mod_i) onto [0, mod_i).
	static inline Int reduce(const Int &x) {
		return Int(x.A + (x.A >> 31 & mod1), x.B + (x.B >> 31 & mod2), x.C + (x.C >> 31 & mod3));
	}
	inline friend Int operator + (const Int &lhs, const Int &rhs) {
		return reduce(Int(lhs.A + rhs.A - mod1, lhs.B + rhs.B - mod2, lhs.C + rhs.C - mod3));
	}
	inline friend Int operator - (const Int &lhs, const Int &rhs) {
		return reduce(Int(lhs.A - rhs.A, lhs.B - rhs.B, lhs.C - rhs.C));
	}
	inline friend Int operator * (const Int &lhs, const Int &rhs) {
		return Int(static_cast<long long> (lhs.A) * rhs.A % mod1, static_cast<long long> (lhs.B) * rhs.B % mod2, static_cast<long long> (lhs.C) * rhs.C % mod3);
	}
	// Garner/CRT recombination: the value below mod1*mod2*mod3 with these residues,
	// returned modulo the global `mod`.
	inline int get() const {
		// x = value modulo mod1*mod2
		long long x = static_cast<long long> (B - A + mod2) % mod2 * inv_1 % mod2 * mod1 + A;
		return (static_cast<long long> (C - x % mod3 + mod3) % mod3 * inv_2 % mod3 * (mod_1_2 % mod) % mod + x) % mod;
	}
} ;

#define maxn 131072

// Three-prime NTT over Int: every component is transformed modulo its own prime.
// (The `register` specifiers of the original were dropped: they are ill-formed since
// C++17 and never affected behaviour.)
namespace Poly {
#define N (maxn << 1)
int lim, s, rev[N];
Int Wn[N | 1];
// Sets lim to the smallest power of two >= n (s = log2(lim) - 1), builds the
// bit-reversal table rev[], and Wn[k] = t^k for k = 0..lim, where t = G^((mod_i-1)/lim)
// in each component. Requires n <= N.
inline void init(int n) {
	s = -1, lim = 1;
	while (lim < n) lim <<= 1, ++s;
	for (int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
	const Int t(Math::pw(G, (mod1 - 1) / lim, mod1), Math::pw(G, (mod2 - 1) / lim, mod2), Math::pw(G, (mod3 - 1) / lim, mod3));
	*Wn = Int(1);
	for (Int *i = Wn; i != Wn + lim; ++i) *(i + 1) = *i * t;
}
// In-place transform of A[0..lim). op != 0: forward. op == 0: inverse, using
// Wn[lim - k] = t^{-k} as twiddles and then scaling by 1/lim.
inline void NTT(Int *A, const int op = 1) {
	for (int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
	for (int mid = 1; mid < lim; mid <<= 1) {
		const int t = lim / mid >> 1;   // twiddle stride for butterflies of size 2*mid
		for (int i = 0; i < lim; i += mid << 1) {
			for (int j = 0; j < mid; ++j) {
				const Int W = op ? Wn[t * j] : Wn[lim - t * j];
				const Int X = A[i + j], Y = A[i + j + mid] * W;
				A[i + j] = X + Y, A[i + j + mid] = X - Y;
			}
		}
	}
	if (!op) {
		const Int ilim(Math::inv(lim, mod1), Math::inv(lim, mod2), Math::inv(lim, mod3));
		for (Int *i = A; i != A + lim; ++i) *i = (*i) * ilim;
	}
}
#undef N
}

int n, m;
Int A[maxn << 1], B[maxn << 1];
// Reads n, m and the output modulus, then the n+1 and m+1 coefficients of two
// polynomials, and prints their product with each coefficient reduced modulo `mod`.
// The convolution is done modulo each of the three primes and recombined by Int::get().
int main() {
scanf("%d%d%d", &n, &m, &mod); ++n, ++m;   // n, m become coefficient counts
for (int i = 0, x; i < n; ++i) scanf("%d", &x), A[i] = Int(x % mod);
for (int i = 0, x; i < m; ++i) scanf("%d", &x), B[i] = Int(x % mod);
Poly::init(n + m);   // transform length >= n + m
Poly::NTT(A), Poly::NTT(B);
for (int i = 0; i < Poly::lim; ++i) A[i] = A[i] * B[i];
Poly::NTT(A, 0);     // inverse transform
for (int i = 0; i < n + m - 1; ++i) {
printf("%d", A[i].get());
putchar(i == n + m - 2 ? '\n' : ' ');
}
return 0;
}

多项式快速幂

给定多项式 $G$,求一个 $F$,使得 $F=G^k$。

比较好推:$\ln F=k\ln G$,所以 $F=e^{k\ln G}$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=4e5+10;        // array capacity
const int p=998244353;     // NTT prime (3 is used as its primitive root)
typedef long long ll;
int cnt[N];
ll A[N],B[N],inv[N],w[25][N];   // w[k][j]: j-th power of a primitive 2^(k+1)-th root of unity
// (x + y) mod p for x, y in [0, p).
ll Add(ll x,ll y){
	const ll s=x+y;
	if(s>=p)return s-p;
	return s;
}
// (x - y) mod p for x, y in [0, p).
ll Del(ll x,ll y){
	const ll d=x-y;
	if(d<0)return d+p;
	return d;
}
// x * y mod p; the % is skipped when the product is already below p.
ll Mul(ll x,ll y){
	const ll prod=x*y;
	if(prod<p)return prod;
	return prod%p;
}
// a^b mod p by square-and-multiply.
ll fpow(ll a,ll b){
	ll res=1;
	while(b){
		if(b&1)res=Mul(res,a);
		a=Mul(a,a);
		b>>=1;
	}
	return res;
}
int r[N];   // bit-reversal permutation for the current transform length
ll t;       // butterfly scratch
// In-place NTT of a[0..n) modulo p (n a power of two).
// Uses the global bit-reversal table r[] (the caller fills it for this n) and the twiddle
// rows w[t0][j] = j-th power of a primitive 2^(t0+1)-th root of unity (built in main).
// typ == 1: forward; any other typ: inverse (forward pass, reverse a[1..n-1], scale by 1/n).
void ntt(int n,ll *a,int typ){
for(int i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1,t0=0;mid<n;t0++,mid<<=1)   // t0 = log2(mid) selects the twiddle row
for(int i=0,R=mid<<1;i<n;i+=R)
for(int j=0;j<mid;j++){
t=Mul(a[i+j+mid],w[t0][j]);   // t: global scratch
a[i+j+mid]=Del(a[i+j],t);
a[i+j]=Add(a[i+j],t);
}
if(typ==1)return;
ll inv=fpow(n,p-2);
reverse(a+1,a+n);
for(int i=0;i<n;i++)a[i]=Mul(a[i],inv);
}
// Newton-iteration helpers modulo x^n. A and B are scratch buffers; both functions
// also rebuild the global bit-reversal table r[] for their transform length.
struct Poly{
	ll A[N],B[N];
	// a = b^{-1} mod x^n (requires b[0] != 0).
	// a[] must be zero beyond the prefix produced by the recursive call; on return
	// a[n..Max) is zeroed again.
	void dfs_inv(ll *a,ll *b,int n){
		if(n==1){a[0]=fpow(b[0],p-2);return;}
		dfs_inv(a,b,n+1>>1);
		int Max=1,l=0;
		while(Max<n+n)Max<<=1,l++;
		for(int i=0;i<Max;i++){
			r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
			A[i]=i<n?b[i]:0;
		}
		ntt(Max,A,1);ntt(Max,a,1);
		for(int i=0;i<Max;i++)a[i]=Del(Mul(2,a[i]),Mul(Mul(a[i],a[i]),A[i]));   // a <- 2a - a^2 b
		ntt(Max,a,-1);
		for(int i=n;i<Max;i++)a[i]=0;
	}
	// a = ln(b) mod x^n, computed as the integral of b'/b (presumably b[0] == 1; a[0]
	// is set to 0). Reads only b[0..n); overwrites a[0..Max) for the transform length
	// Max and resets B[0..Max) to zero for the next call.
	void Ln(ll *a,ll *b,int n){
		dfs_inv(B,b,n);   // B = 1/b mod x^n
		int Max=1,l=0;
		while(Max<n+n)Max<<=1,l++;
		// A = b' mod x^n, zero-padded to Max. The previous version memcpy'd the whole of
		// b (N entries) and differentiated in place, which read past the n useful
		// coefficients and let any nonzero b[n..Max) leak into the product.
		for(int i=0;i<Max;i++)A[i]=0;
		for(int i=1;i<n;i++)A[i-1]=Mul(i,b[i]);
		for(int i=0;i<Max;i++)
			r[i]=(r[i>>1]>>1)|((i&1)<<l-1),a[i]=0;
		ntt(Max,A,1);ntt(Max,B,1);
		for(int i=0;i<Max;i++)
			A[i]=Mul(A[i],B[i]),B[i]=0;
		ntt(Max,A,-1);
		for(int i=n-1;i;i--)
			A[i]=Mul(A[i-1],fpow(i,p-2));   // integrate
		A[0]=0;
		for(int i=0;i<n;i++)
			a[i]=A[i];
	}
}T;
ll C[N];   // scratch for Exp
// a = exp(b) mod x^n by Newton iteration a <- a * (1 - ln a + b).
// The base case sets a[0] = 1, so presumably b[0] == 0. Expects a[] zero beyond the
// recursively computed prefix and leaves a[n..Max) zeroed. T.Ln overwrites C and r[].
void Exp(ll *a,ll *b,int n){
if(n==1){a[0]=1;return;}
Exp(a,b,n+1>>1);
T.Ln(C,a,n);   // C = ln(a) mod x^n
int Max=1,l=0;
while(Max<n+n)Max<<=1,l++;
for(int i=0;i<Max;i++){
r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
if(i<n)C[i]=Del(b[i],C[i]);   // C = b - ln(a)
}
C[0]++;   // C = 1 - ln(a) + b
ntt(Max,a,1);ntt(Max,C,1);
for(int i=0;i<Max;i++)a[i]=Mul(a[i],C[i]);
ntt(Max,a,-1);
for(int i=n;i<Max;i++)a[i]=0;
}
ll D[N];   // result: D = A^k mod x^n
// Reads n, an exponent k given as a (possibly huge) decimal string, and n coefficients
// of A, then prints A^k mod x^n computed as exp(k * ln A).
// k only multiplies ln A, so it is reduced modulo p digit by digit while reading.
// NOTE(review): ln A needs A[0] == 1 — presumably guaranteed by the input format.
int main(){
int n;
scanf("%d",&n);
char ch=getchar();int k=0;
while(ch<'0'||ch>'9')ch=getchar();
while(ch<='9'&&ch>='0')k=Add(ch-'0',Mul(k,10)),ch=getchar();   // k = k mod p
for(int i=0;i<n;i++)scanf("%lld",&A[i]);
// Twiddle rows: w[i][j] = g^j for g a primitive 2^(i+1)-th root of unity, for every
// length up to 2n. (This loop's own k shadows the exponent k.)
for(int i=0,k=1;k<=n+n;k<<=1,i++){
w[i][0]=1;w[i][1]=fpow(3,(p-1)/(k*2));
for(int j=2;j<k;j++)
w[i][j]=Mul(w[i][j-1],w[i][1]);
}
T.Ln(B,A,n);
for(int i=0;i<n;i++)B[i]=Mul(B[i],k);
Exp(D,B,n);
for(int i=0;i<n;i++)
printf("%lld ",D[i]);
return 0;
}

多项式复合

给定两个多项式 $F,G$,求一个多项式 $H(x)\equiv F(G(x)) \pmod {x^n}$。

考虑暴力做:依次考虑 $G^i$ 的贡献,与 $F$ 的系数相乘后加入到答案里,时间复杂度 $O(n^2\log n)$,常数略大。

事实上没有必要把他们都乘出来

$$H=\sum_{i=0}^{n-1}[x^i]F\cdot G^i$$

$$H=\sum_{i=0}^{L}\sum_{j=0}^{L-1}[x^{iL+j}]F\cdot G^{iL+j}$$

$$H=\sum_{i=0}^{L}G^{iL}\sum_{j=0}^{L-1}[x^{iL+j}]F\cdot G^{j}$$

L=nL=\sqrt n,这样只需要预处理出GiL,GiG^{iL},G^{i}即可,时间复杂度O(n2+nnlogn)O(n^2+n\sqrt nlogn),实际上 n2n^2 的部分常数比较小所以跑的飞快。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<cstdio>
#include<cmath>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1<<18;        // twiddle / bit-reversal capacity: transform lengths must be <= N
const int p=998244353;    // NTT prime, primitive root 3
const int inv2=p+1>>1;    // 1/2 mod p
// Shorthands used throughout this program. They used to expand to `register int` /
// `register long long`; `register` was removed in C++17 (clang rejects it, GCC warns)
// and never changed semantics, so the macros now name the plain types.
#define rint int
#define rll long long
// Modular helpers; arguments are expected in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline int Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by square-and-multiply.
inline int fpow(rint a,rint b){
	rint res=1;
	for(;b;b>>=1,a=Mul(a,a))
		if(b&1)res=Mul(res,a);
	return res;
}
int w[N],r[N],pd,t;   // twiddles, bit-reversal table, length r[] was built for, scratch
// In-place NTT of a after resizing it to n (n a power of two, n <= N).
// w[mid + j] must hold the j-th power of a primitive (2*mid)-th root of unity (built in
// main); r[] is rebuilt only when n differs from the previous call (cached in pd).
// typ == 1: forward; otherwise inverse (forward pass, reverse a[1..n-1], scale by 1/n).
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd^n){
for(rint i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
pd=n;
}
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint R=mid<<1,i=0;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=Mul(w[mid+j],a[i+j+mid]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(typ==1)return;
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)a[i]=Mul(a[i],inv);
}
// Product of two coefficient vectors via NTT. The result has the transform length
// (smallest power of two >= a.size() + b.size()), with the entries past the true
// product degree equal to zero.
vector<int> Mul(vector<int> a,vector<int> b){
	const int total=(int)a.size()+(int)b.size();
	int len=1;
	while(len<total)len<<=1;
	ntt(len,a,1);
	ntt(len,b,1);
	for(int k=0;k<len;++k)
		a[k]=Mul(a[k],b[k]);
	ntt(len,a,-1);
	return a;
}
vector<int> vec[210],vec2[210],ans;   // vec[j] = G^j, vec2[i] = G^(i*sqr), answer (sized for sqr <= 209)
// Polynomial composition H(x) = F(G(x)) mod x^(n+1) by baby-step/giant-step:
//   H = sum_i G^(i*L) * sum_j [x^(i*L+j)]F * G^j   with L = sqr.
// Input: n m, then the n+1 coefficients of F and the m+1 coefficients of G.
int main(){
rint n,m;
// n is borrowed as a temporary for the root of unity before it is read.
w[N>>1]=1;w[N/2+1]=n=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],n);
for(rint i=N/2-1;~i;i--)w[i]=w[i<<1];   // w[mid+j] = j-th power of a primitive 2*mid-th root
scanf("%d%d",&n,&m);++n;++m;   // coefficient counts
vector<int> f,g;ans.resize(n);
f.resize(n);g.resize(m);
for(rint i=0;i<n;i++)scanf("%d",&f[i]);
for(rint i=0;i<m;i++)scanf("%d",&g[i]);
int sqr=sqrt(n)+1;   // sqr*sqr > n, so i*sqr+j reaches every index below n
vec[0].resize(n);vec[0][0]=1;
for(int i=1;i<=sqr;i++){
vec[i]=Mul(vec[i-1],g);   // G^i mod x^n
vec[i].resize(n);
}
vec2[0].resize(n);vec2[0][0]=1;
for(int i=1;i<=sqr;i++){
vec2[i]=Mul(vec2[i-1],vec[sqr]);   // G^(i*sqr) mod x^n
vec2[i].resize(n);
}
for(rint i=0;i<sqr;i++){
vector<int> now;
now.clear();now.resize(n);
// now = sum_j f[i*sqr+j] * G^j  (the O(n^2) part overall)
for(rint j=0;j<sqr;j++){
if(i*sqr+j>=n)break;
rint val=f[i*sqr+j];
for(rint k=0;k<n;k++){
now[k]=Add(now[k],Mul(val,vec[j][k]));
}
}
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,now,1);ntt(Max,vec2[i],1);
for(rint j=0;j<Max;j++)now[j]=Mul(vec2[i][j],now[j]);
ntt(Max,now,-1);now.resize(n);
for(rint i=0;i<n;i++)
ans[i]=Add(ans[i],now[i]);
}
for(rint i=0;i<n;i++)
printf("%d ",ans[i]);
return 0;
}

多项式复合逆

给定多项式 $G$,求一个多项式 $F$,使得 $G(F(x))=x$。大概考虑一下这个东西有什么意义:比如有一个函数满足 $H(G(x))=P(x)$,那么代入复合逆就能够得到 $H(x)=P(F(x))$,只需要多项式复合就能够求出 $H$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include<cmath>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
// Twiddle / bit-reversal capacity. Mul() and dfs_inv() run NTTs of length >= 2n, so N
// must be at least 2n rounded up to a power of two. The previous 1<<15 silently limited
// n to 16384 (larger inputs indexed w[]/r[] out of bounds and needed a root of unity of
// order above N). 1<<17 leaves headroom and still divides p-1.
const int N=1<<17;
const int p=998244353;   // NTT prime, primitive root 3
// Shorthands used throughout this program. `register` was removed in C++17 (clang
// rejects it, GCC warns) and never changed semantics, so the macros name plain types.
#define rint int
#define rll long long
// Modular helpers; arguments are expected in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline int Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by square-and-multiply.
int fpow(rint a,rint b){
	rint res=1;
	for(;b;b>>=1,a=Mul(a,a))
		if(b&1)res=Mul(res,a);
	return res;
}
int r[N],w[N],t,pd;   // bit-reversal table, twiddles, scratch, length r[] was built for
// In-place NTT of a after resizing it to n (n a power of two, n <= N).
// w[mid + j] = j-th power of a primitive (2*mid)-th root of unity (built in main);
// r[] is rebuilt only when n changes (cached in pd).
// typ == -1: inverse (reverse a[1..n-1], scale by 1/n); any other typ: forward only.
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd^n){
for(rint i=0;i<n;i++)
r[i]=(r[i>>1]>>1)|((i&1)?(n>>1):0);
pd=n;
}
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint i=0,R=mid<<1;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=Mul(a[i+j+mid],w[mid+j]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(~typ)return ;   // ~typ is zero only for typ == -1
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)
a[i]=Mul(a[i],inv);
}
vector<int> vec[N],vec2[N];   // vec[j] = (x/G)^j, vec2[i] = (x/G)^(i*sqr) (filled in main)
// f = g^{-1} mod x^n by Newton iteration (requires g[0] != 0).
// g is taken by value, so f may be the same vector as the caller's argument.
void dfs_inv(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=fpow(g[0],p-2);return;}
dfs_inv(f,g,n+1>>1);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));   // f <- f (2 - f g)
ntt(Max,f,-1);f.resize(n);
}
// Product of two coefficient vectors via NTT. The result has the transform length
// (smallest power of two >= a.size() + b.size()), with the entries past the true
// product degree equal to zero.
vector<int> Mul(vector<int> a,vector<int> b){
	const int total=(int)a.size()+(int)b.size();
	int len=1;
	while(len<total)len<<=1;
	ntt(len,a,1);
	ntt(len,b,1);
	for(int k=0;k<len;++k)
		a[k]=Mul(a[k],b[k]);
	ntt(len,a,-1);
	return a;
}
// Compositional inverse: given G with G(0) = 0, prints the first n coefficients of F with
// G(F(x)) = x, using Lagrange inversion [x^k]F = (1/k) [x^(k-1)] (x/G)^k.
// Input: n, then n coefficients of G (the constant term is skipped).
int main(){
w[N/2]=1;w[N/2+1]=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],w[N/2+1]);
for(rint i=N/2-1;i;i--)w[i]=w[i<<1];   // w[mid+j] = j-th power of a primitive 2*mid-th root
rint n;
vector<int> a;
scanf("%d",&n);
a.resize(n);
rint sqr=sqrt(n)+1;   // block size; the loops below reach every exponent below n
scanf("%*d");--n;     // skip G's constant term; a holds the coefficients of G/x
for(rint i=0;i<n;i++)
scanf("%d",&a[i]);
dfs_inv(a,a,n);   // a = x/G mod x^n
// (disabled debug dump of a)
//for(rint v:a)printf("xx%d\n",v);
vec[0].resize(n);vec[0][0]=1;
vec2[0].resize(n);vec2[0][0]=1;
for(rint i=1;i<=sqr;i++){
vec[i]=Mul(vec[i-1],a);   // (x/G)^i
vec[i].resize(n);
}
for(rint i=1;i<=sqr;i++){
vec2[i]=Mul(vec2[i-1],vec[sqr]);   // (x/G)^(i*sqr)
vec2[i].resize(n);
}
printf("0 ");   // F(0) = 0
for(rint i=0;i<=sqr;i++){
for(rint j=1;j<=sqr;j++){
rint x=i*sqr+j-1;   // output index is x+1 = i*sqr+j
if(x>=n)return 0;
rint res=0;
// [x^x] of (x/G)^(i*sqr) * (x/G)^j, a single coefficient in O(x)
for(rint k=0;k<=x;k++)
res=Add(res,Mul(vec[j][k],vec2[i][x-k]));
printf("%d ",Mul(res,fpow(x+1,p-2)));
}
}
return 0;
}

拉格朗日反演

如果有 $F(G(x))=x$,即 $F,G$ 互为复合逆,那么一定同时有 $G(F(x))=x$,可以记 $G(x)=F^{-1}(x)$,$F(x)=G^{-1}(x)$。

在这种情况下,有这样的式子:

$$[x^n]F(x)=\frac{1}{n}[x^{n-1}]\left(\frac{x}{G(x)}\right)^n$$

$$[x^n]H(F(x))=\frac{1}{n}[x^{n-1}]H'(x)\left(\frac{x}{G(x)}\right)^n$$

可以 $O(n\log n)$ 地求出一项系数。

考虑

$$F=\sum_{i=1}^{n}\frac{x^i}{i}[x^{i-1}]\left(\frac{x}{G(x)}\right)^i$$

$$F=\sum_{i=0}^{L}\sum_{j=1}^{L}\frac{x^{iL+j}}{iL+j}[x^{iL+j-1}]\left(\frac{x}{G(x)}\right)^{iL+j}$$

$$F=\sum_{i=0}^{L}\sum_{j=1}^{L}\frac{x^{iL+j}}{iL+j}[x^{iL+j-1}]\left(\frac{x}{G(x)}\right)^{iL}\left(\frac{x}{G(x)}\right)^{j}$$

L=nL=\sqrt n,然后预处理一下后边的式子,可以 O(n)O(n) 求出两个多项式乘积的某一项,时间复杂度同上。

有标号连通图计数

考虑 $i$ 个点的有标号简单无向图的个数,显然是 $2^{\binom{i}{2}}$,设其 $EGF$ 为 $F$。

一个无向图是由若干连通图拼在一起的。设连通图的 $EGF$ 为 $G$,枚举连通块个数,有

$$F=\sum_{i=0}^{+\infty}\frac{G^i}{i!}$$

除掉阶乘是因为连通块之间没有区别。

所以 $F=e^G$,$G=\ln F$。

边双连通图计数

求大小为 $n$ 的有标号边双连通图的数量。

由于边双连通图没有一个固定的点,所以不大好做。考虑将其转化为有根的边双,其实就是加入了一个特殊点方便计数而已。

容易发现,设无根边双的 $EGF$ 为 $P$,有根边双的为 $D$,那么 $D_i=iP_i$,原因是每个点都可以成为根。

考虑有根边双怎么求。一张有根连通图一定是由若干个边双连通分量构成的。考虑根所在的边双连通分量,除去这个分量后,原图一定会分为若干个有根连通图。枚举根所在分量的大小和划分成的有根连通图的个数,设有根连通图的 $EGF$ 为 $F$,有

$$F=\sum_{i=1}^{+\infty}\frac{D_ix^i\sum_{j=0}^{+\infty}\frac{F^ji^j}{j!}}{i!}$$

$$\sum_{j=0}^{+\infty}\frac{F^ji^j}{j!}=\sum_{j=0}^{+\infty}\frac{(iF)^j}{j!}=e^{iF}$$

所以

$$F=D(xe^F)$$

$F$ 显然比较好求,考虑 $D$ 怎么求。令 $H=xe^F$,有 $F(x)=D(H(x))$,两边同时代入 $H$ 的复合逆,得 $F(H^{-1}(x))=D(x)$。

所以只需要求出 $F(H^{-1})$,由拉格朗日反演

$$[x^i]F(H^{-1})=\frac{1}{i}[x^{i-1}]F'\left(\frac{x}{H}\right)^i$$

代入 $H=xe^F$,有

$$[x^i]F(H^{-1})=\frac{1}{i}[x^{i-1}]F'e^{-iF}$$

于是可以 $O(n\log n)$ 求出 $D$ 的一项。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1<<18;        // twiddle / buffer capacity: transform lengths must be <= N
const int p=998244353;    // NTT prime, primitive root 3
// Shorthands used throughout this program. `register` was removed in C++17 (clang
// rejects it, GCC warns) and never changed semantics, so the macros name plain types.
#define rint int
#define rll long long
#define ll long long
// Modular helpers; arguments are expected in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline ll Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by square-and-multiply.
inline ll fpow(rll a,rll b){
	rll res=1;
	for(;b;b>>=1,a=Mul(a,a))
		if(b&1)res=Mul(res,a);
	return res;
}
int w[N],r[N],t,pd,fac[N],inv[N];   // twiddles, bit-reversal, scratch, cached length, i!, inverses (1/i! after main's setup)
unsigned long long A[N];            // lazy-reduction work buffer for ntt
vector<int> G;                      // EGF coefficients (see main)
// In-place NTT of a after resizing it to n (n a power of two, n <= N).
// Butterflies run on unsigned 64-bit values in A[] and only the product side is reduced;
// entries may grow by about p per level and are reduced once at the end.
// NOTE(review): A*w stays below 2^64 only while the number of levels is small (about 18,
// i.e. n <= 2^18 = N) — re-check this bound before raising N.
// typ == 1: forward; otherwise inverse (reverse a[1..n-1], scale by 1/n).
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd!=n){
pd=n;
for(rint i=0;i<n;i++)
r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
}
for(rint i=0;i<n;i++)A[i]=a[r[i]];   // bit-reversed copy
for(rint mid=1;mid<n;mid<<=1)
for(rint i=0,R=mid<<1;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=A[i+j+mid]*w[mid+j]%p,A[i+j+mid]=A[i+j]-t+p,A[i+j]+=t;   // +p keeps the difference non-negative
for(rint i=0;i<n;i++)a[i]=A[i]%p;
if(typ==1)return ;
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)a[i]=Mul(inv,a[i]);
}
// f = g^{-1} mod x^n by Newton iteration (requires g[0] != 0).
// g is taken by value, so f may be the same vector as the caller's argument.
void dfs_inv(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){
f.resize(1);f[0]=fpow(g[0],p-2);
return ;
}
dfs_inv(f,g,n+1>>1);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));   // f <- f (2 - f g)
ntt(Max,f,-1);f.resize(n);
}
// f = ln(g) mod x^n, computed as the integral of g'/g (presumably g[0] == 1; f[0] is
// set to 0). g is taken by value, so f may be the same vector as the caller's argument.
void Ln(vector<int> &f,vector<int> g,rint n){
f.resize(n);g.resize(n);
for(rint i=1;i<n;i++)
f[i-1]=Mul(i,g[i]);   // f = g'
f[n-1]=0;
dfs_inv(g,g,n);   // g = 1/g mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);
for(rint i=n-1;i;i--)
f[i]=Mul(f[i-1],fpow(i,p-2));   // integrate
f[0]=0;f.resize(n);
}
// f = exp(g) mod x^n by Newton iteration f <- f (1 - ln f + g). The base case sets
// f[0] = 1, so presumably g[0] == 0. g is taken by value, so f may alias the argument.
void Exp(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){
f.resize(1);f[0]=1;return ;
}
Exp(f,g,n+1>>1);
vector<int> c=f;
Ln(c,c,n);   // c = ln f mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
g.resize(Max);c.resize(Max);
for(rint i=0;i<Max;i++)
g[i]=Del(g[i],c[i]);   // g - ln f
g[0]++;   // 1 - ln f + g
ntt(Max,g,1);ntt(Max,f,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);f.resize(n);
}
vector<int> D;   // D = F', derivative of the rooted-connected-graph EGF (filled in main)
// Number of labeled edge-biconnected graphs on n vertices (1 <= n <= D.size()):
//   (n-1)! * (1/n) * [x^(n-1)] F'(x) * e^{-nF(x)},
// where the global G holds F and D holds F' (Lagrange inversion).
int solve(rint n){
	// Only F'[0..n-1] can reach x^(n-1). Truncating keeps the length-Max cyclic
	// convolution below free of wrap-around; the old version convolved all of D and
	// relied on the wrapped terms happening to miss index n-1.
	vector<int> f(D.begin(),D.begin()+n);
	vector<int> g=G;
	for(rint i=0;i<n;i++)
		g[i]=Del(0,Mul(g[i],n));   // g = -nF kept in [0, p); `p - 0` used to store p
	Exp(g,g,n);                    // g = e^{-nF} mod x^n
	rint Max=1;
	while(Max<n+n)Max<<=1;
	ntt(Max,f,1);ntt(Max,g,1);
	for(rint i=0;i<Max;i++)
		f[i]=Mul(f[i],g[i]);
	ntt(Max,f,-1);
	return Mul(fac[n-1],Mul(f[n-1],fpow(n,p-2)));
}
// Precomputes the EGFs used by solve() and answers five queries (one n per line).
int main(){
// w[N/2 + j] = omega^j for a primitive N-th root omega; copying down (w[i] = w[2i])
// makes w[mid + j] the j-th power of a primitive 2*mid-th root.
w[N/2]=1;w[N/2+1]=t=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],t);
for(rint i=N/2-1;i;i--)w[i]=w[i<<1];
fac[0]=fac[1]=inv[0]=inv[1]=1;
for(rint i=2;i<N;i++){
fac[i]=Mul(fac[i-1],i);
inv[i]=p-Mul(p/i,inv[p%i]);   // 1/i by the linear recurrence
}
for(rint i=2;i<N;i++)inv[i]=Mul(inv[i-1],inv[i]);   // prefix products: inv[i] = 1/i!
G.resize(1e5+1);
// G = EGF of all labeled simple graphs: 2^{C(i,2)} / i! (exponent reduced mod p-1)
for(rint i=0;i<=1e5;i++)
G[i]=Mul(fpow(2,1ll*(i-1)*i/2%(p-1)),inv[i]);
Ln(G,G,1e5+1);   // G = EGF of connected graphs
D.resize(1e5+1);
for(rint i=0;i<=1e5;i++)
G[i]=Mul(i,G[i]);   // G = F, EGF of rooted connected graphs
for(rint i=1;i<=1e5;i++){
D[i-1]=Mul(i,G[i]);   // D = F'
}
for(rint i=1;i<=5;i++){
rint x;
scanf("%d",&x);
printf("%d\n",solve(x));
}
return 0;
}

多项式多点求值

给定一个 $n$ 次多项式 $F$ 和 $m$ 个值 $a_1,\dots,a_m$,求出 $F(a_1),F(a_2),\dots,F(a_m)$。

构造两个多项式 $P=\prod_{i=1}^{n/2}(x-a_i)$,$Q=\prod_{i=n/2+1}^{n}(x-a_i)$。如果令 $F=P\cdot A+B$(即 $B=F\bmod P$),那么对于任意 $i\in[1,\frac{n}{2}]$,都有 $F(a_i)=B(a_i)$,原因是代入后 $P$ 为 $0$;右半部分对 $Q$ 同理递归。所以就有了理论上是 $O(n\log^2 n)$、实际上常数很大的多项式取模写法。

实际上取模的做法貌似已经成了时代的眼泪,我们现在已经有了更快的转置做法。

考虑一种减法卷积(转置乘法),定义 $c=\mathrm{Mult}(a,b)$ 满足 $c_i=\sum_{j=0}^{+\infty}a_jb_{j-i}$,即 $c_i=\sum_{j}a_{i+j}b_j$。

它有一个性质:$\mathrm{Mult}(\mathrm{Mult}(a,b),c)=\mathrm{Mult}(a,b*c)$,可以利用这个性质进行多点求值。

F(aj)F(a_j) 展开,得到

$$F(a_j)=\sum_{i=0}^{n}[x^i]F\cdot a_j^i$$

观察发现,如果构造一个多项式 $G_j=\sum_{i=0}^{n}a_j^ix^i$,那么 $F(a_j)$ 实际上是 $\mathrm{Mult}(F,G_j)$ 的常数项,证明很简单,代入即可。

$G_j$ 是一个 $n$ 次的多项式,暴力做的话是 $O(nm\log n)$,还不如直接代入。考虑优化:将 $G_j$ 写成封闭形式 $G_j=\frac{1}{1-a_jx}$,这样就可以 $O(1)$ 地写出每一个 $G_j$ 的分母,但这还是没啥用,因为乘起来还是要 $O(n\log n)$。最后我们要的是每一个 $\mathrm{Mult}(F,G_j)$ 的常数项,考虑求出 $\prod_{j=1}^mG_j$,可以用分治乘法和求逆求出。直接令 $F=\mathrm{Mult}(F,\prod_{j=1}^mG_j)$,再进行一次分治,如果进入左区间就把右区间的贡献去掉。具体地,设 $T=\prod_{j=\frac{n}{2}+1}^{n}\frac{1}{G_j}$,只需要令 $F=\mathrm{Mult}(F,T)$,进入右区间时同理。原因是暴力做是对于右区间的每一个 $j$ 依次做 $\mathrm{Mult}(\mathrm{Mult}(F,\frac{1}{G_j}),\frac{1}{G_{j'}})\dots$,由 $\mathrm{Mult}(\mathrm{Mult}(a,b),c)=\mathrm{Mult}(a,b*c)$ 知,可以先把右区间卷积起来,再一起做减法卷积。

然后就有了小常数的多点求值。

多项式快速插值

给定 $n+1$ 个点值,要求一个 $n$ 次多项式。

考虑拉格朗日插值法,有

$$F=\sum_{i=0}^ny_i\prod_{j\ne i}\frac{x-x_j}{x_i-x_j}$$

这个式子为什么是对的?其实很简单:代入任意 $x_i$ 都成立,然后就没了。

但是直接做是 $O(n^2)$ 的,将其进行简单变形:

$$F=\sum_{i=0}^n\frac{y_i}{\prod_{j\ne i}(x_i-x_j)}\prod_{j\ne i}(x-x_j)$$

考虑带有 $y_i$ 的那一项,容易发现只需要求出分母 $\prod_{j\ne i}(x_i-x_j)$ 即可。这有点难求,于是构造一个多项式 $G=\prod_{i=0}^n(x-x_i)$,只需要求出 $\frac{G(x)}{x-x_i}$ 在 $x=x_i$ 处的值。这个东西看起来好求,但是实际上不是很好玩,因为分子分母在 $x=x_i$ 的时候都是 $0$。

这时候就要用到洛必达法则了。

众所周知,两个无穷小之比或两个无穷大之比的极限可能存在,也可能不存在。因此,求这类极限时往往需要适当的变形,转化成可利用极限运算法则或重要极限的形式进行计算。洛必达法则便是应用于这类极限计算的通用方法 。

洛必达法则:对于 $\frac{f(x)}{g(x)}$,若 $\lim\limits_{x\to a}f(x)=0,\ \lim\limits_{x\to a}g(x)=0$,且右侧极限存在,那么 $\lim\limits_{x\to a}\frac{f(x)}{g(x)}=\lim\limits_{x\to a}\frac{f'(x)}{g'(x)}$。

也就是说,当 $x\to x_i$ 时,$\frac{G(x)}{x-x_i}\to\frac{G'(x_i)}{(x-x_i)'}=G'(x_i)$。

于是有,

$$F=\sum_{i=0}^n\frac{y_i}{G'(x_i)}\prod_{j\ne i}(x-x_j)$$

使用多项式多点求值即可求出所有的 $G'(x_i)$,不妨令 $y_i\leftarrow\frac{y_i}{G'(x_i)}$。

那么实际上要求的是

$$F=\sum_{i=0}^ny_i\prod_{j\ne i}(x-x_j)$$

$$F_{l,r}=\sum_{i=l}^ry_i\prod_{j=l,j\ne i}^{r}(x-x_j)$$

直接暴力拆开,

$$F_{l,r}=\sum_{i=l}^{mid}y_i\prod_{j=l,j\ne i}^{r}(x-x_j)+\sum_{i=mid+1}^ry_i\prod_{j=l,j\ne i}^{r}(x-x_j)$$

$$F_{l,r}=\sum_{i=l}^{mid}y_i\prod_{j=l,j\ne i}^{mid}(x-x_j)\prod_{j=mid+1}^{r}(x-x_j)+\sum_{i=mid+1}^ry_i\prod_{j=mid+1,j\ne i}^{r}(x-x_j)\prod_{j=l}^{mid}(x-x_j)$$

$$F_{l,r}=F_{l,mid}\prod_{j=mid+1}^{r}(x-x_j)+F_{mid+1,r}\prod_{j=l}^{mid}(x-x_j)$$

自底向上乘即可。

下降幂多项式乘法

给定下降幂多项式 $F,G$,求 $F*G$(仍用下降幂表示)。

首先考虑一个下降幂单项式 $x^{\underline n}$ 的点值 $EGF$:

$$\sum_{i=n}^{+\infty}\frac{i^{\underline n}}{i!}x^i$$

$$=\sum_{i=n}^{+\infty}\frac{1}{(i-n)!}x^i$$

$$=\sum_{i=0}^{+\infty}\frac{1}{i!}x^{i+n}$$

即 $e^xx^{n}$。那么只需要将 $F,G$ 的系数分别乘上 $e^x$,即可转化为点值的 $EGF$,之后直接把点值相乘即可。注意得到的是点值的 $EGF$,所以相乘的时候要乘上一个阶乘,最后再乘上 $e^{-x}$ 转回下降幂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1<<19;        // twiddle / table capacity: transform lengths must be <= N
const int p=998244353;    // NTT prime, primitive root 3
const int inv2=p+1>>1;    // 1/2 mod p
// Shorthands used throughout this program. `register` was removed in C++17 (clang
// rejects it, GCC warns) and never changed semantics, so the macros name plain types.
#define rint int
#define rll long long
// Modular helpers; arguments are expected in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline int Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by square-and-multiply.
inline int fpow(rint a,rint b){
	rint res=1;
	for(;b;b>>=1,a=Mul(a,a))
		if(b&1)res=Mul(res,a);
	return res;
}
int w[N],r[N],pd,t,inv[N],fac[N];   // twiddles, bit-reversal, cached length, scratch, 1/i! (after main's setup), i!
// In-place NTT of a after resizing it to n (n a power of two, n <= N).
// w[mid + j] must hold the j-th power of a primitive (2*mid)-th root of unity (built in
// main); r[] is rebuilt only when n differs from the previous call (cached in pd).
// typ == 1: forward; otherwise inverse (forward pass, reverse a[1..n-1], scale by 1/n).
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd^n){
for(rint i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
pd=n;
}
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint R=mid<<1,i=0;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=Mul(w[mid+j],a[i+j+mid]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(typ==1)return;
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)a[i]=Mul(a[i],inv);
}
// Falling-factorial polynomial multiplication. Input: n m, then the n+1 and m+1
// coefficients of F and G in the falling-factorial basis; prints the n+m+1 coefficients
// of F*G in the same basis.
// Method: multiplying a coefficient sequence by e^x gives the EGF of its point values at
// 0,1,2,...; the point values are multiplied (times i! to stay an EGF) and the result is
// multiplied by e^{-x} to return to the falling-factorial basis.
int main(){
rint n,m;
w[N>>1]=1;w[N/2+1]=n=fpow(3,(p-1)/N);   // n borrowed as a temporary
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],n);
for(rint i=N/2-1;~i;i--)w[i]=w[i<<1];   // w[mid+j] = j-th power of a primitive 2*mid-th root
vector<int> a,b,e,inve;
scanf("%d%d",&n,&m);++n;++m;   // coefficient counts
a.resize(n);b.resize(m);
for(rint i=0;i<n;i++)
scanf("%d",&a[i]);
for(rint i=0;i<m;i++)
scanf("%d",&b[i]);
e.resize(n+m);inve.resize(n+m);
inv[0]=inv[1]=fac[0]=fac[1]=1;
for(rint i=2;i<N;i++){
inv[i]=p-1ll*p/i*inv[p%i]%p;   // 1/i
fac[i]=Mul(fac[i-1],i);
}
for(rint i=2;i<N;i++)
inv[i]=Mul(inv[i],inv[i-1]);   // now 1/i!
for(rint i=0;i<n+m;i++)
e[i]=inv[i],inve[i]=(i&1)?p-inv[i]:inv[i];   // e = e^x, inve = e^{-x}, truncated
rint Max=1;
while(Max<n*2+m*2)Max<<=1;
ntt(Max,a,1);ntt(Max,b,1);
ntt(Max,e,1);ntt(Max,inve,1);
for(rint i=0;i<Max;i++)
a[i]=Mul(a[i],e[i]),b[i]=Mul(b[i],e[i]);
ntt(Max,a,-1);ntt(Max,b,-1);   // a[i] = F(i)/i!, b[i] = G(i)/i!
for(rint i=0;i<Max;i++)a[i]=Mul(fac[i],Mul(a[i],b[i]));   // (F*G)(i)/i!
for(rint i=n+m-1;i<Max;i++)a[i]=0;   // only n+m-1 point values are needed
ntt(Max,a,1);
for(rint i=0;i<Max;i++)a[i]=Mul(a[i],inve[i]);
ntt(Max,a,-1);
for(rint i=0;i<n+m-1;i++)
printf("%d ",a[i]);
return 0;
}

下降幂多项式转普通多项式

给定一个下降幂多项式,求一个与其相等的普通多项式。

直接将下降幂转化成点值的形式然后快速插值即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1<<19;        // twiddle / table capacity: transform lengths must be <= N
const int p=998244353;    // NTT prime, primitive root 3
const int inv2=p+1>>1;    // 1/2 mod p
// Shorthands used throughout this program. `register` was removed in C++17 (clang
// rejects it, GCC warns) and never changed semantics, so the macros name plain types.
#define rint int
#define rll long long
// Modular helpers; arguments are expected in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline int Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by square-and-multiply.
inline int fpow(rint a,rint b){
	rint res=1;
	for(;b;b>>=1,a=Mul(a,a))
		if(b&1)res=Mul(res,a);
	return res;
}
int w[N],r[N],pd,t;   // twiddles, bit-reversal table, length r[] was built for, scratch
// In-place NTT of a after resizing it to n (n a power of two, n <= N).
// w[mid + j] must hold the j-th power of a primitive (2*mid)-th root of unity (built in
// main); r[] is rebuilt only when n differs from the previous call (cached in pd).
// typ == 1: forward; otherwise inverse (forward pass, reverse a[1..n-1], scale by 1/n).
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd^n){
for(rint i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
pd=n;
}
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint R=mid<<1,i=0;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=Mul(w[mid+j],a[i+j+mid]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(typ==1)return;
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)a[i]=Mul(a[i],inv);
}
// f = g^{-1} mod x^n by Newton iteration f <- f (2 - f g) (requires g[0] != 0).
// g is taken by value, so f may be the same vector as the caller's argument.
void dfs_inv(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=fpow(g[0],p-2);return;}
dfs_inv(f,g,n+1>>1);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));
ntt(Max,f,-1);f.resize(n);
}
// f = sqrt(g) mod x^n by Newton iteration f <- (f^2 + g) / (2f). The base case fixes
// f[0] = 1, so presumably g[0] == 1. Not called by this program's main().
void dfs_sqrt(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=1;return;}
dfs_sqrt(f,g,n+1>>1);
vector<int> a;
dfs_inv(a,f,n);   // a = 1/f mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,a,1);ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(Add(Mul(f[i],f[i]),g[i]),Mul(inv2,a[i]));
ntt(Max,f,-1);f.resize(n);
}
// f = ln(g) mod x^n, computed as the integral of g'/g (presumably g[0] == 1; f[0] is set
// to 0). g is taken by value.
void Ln(vector<int> &f,vector<int> g,rint n){
g.resize(n);
dfs_inv(f,g,n);   // f = 1/g mod x^n
for(rint i=1;i<n;i++)g[i-1]=Mul(g[i],i);   // g = g'
g[n-1]=0;
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);
f.resize(n);
for(rint i=n-1;i;i--)f[i]=Mul(f[i-1],fpow(i,p-2));   // integrate
f[0]=0;
}
// f = exp(g) mod x^n by Newton iteration f <- f (1 - ln f + g). The base case sets
// f[0] = 1, so presumably g[0] == 0. g is taken by value.
void Exp(vector<int> &f,vector<int> g,int n){
g.resize(n);
if(n==1){f.resize(1);f[0]=1;return;}
Exp(f,g,n+1>>1);
vector<int> a;
Ln(a,f,n);   // a = ln f mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
g.resize(Max);a.resize(Max);
// The loop body is only a[i]=...; a[0]++ runs once afterwards, giving a = 1 - ln f + g.
for(rint i=0;i<Max;i++)a[i]=Del(g[i],a[i]);a[0]++;
ntt(Max,a,1);ntt(Max,f,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],a[i]);
ntt(Max,f,-1);f.resize(n);
}
// Multipoint evaluation by the transposition method (the "subtractive" product Mult from
// the text): query(A, a) evaluates A at the points a[1..m] into res[1..m]
// (Fin::dfs2 uses res[i] as the value at a[i]).
struct Eval{
vector<int> vec[N<<2],A,res,tmp;   // segment-tree node polynomials, (member A unused), results, points
// Pointwise-multiplies a by b (b needs at least a.size() entries), applies a forward NTT
// of length a.size() and keeps the first n entries. With a inverse-transformed and b
// forward-transformed, this is the cyclic transposed product c_i = sum_j a_{i+j} b_j.
vector<int> Mult(vector<int> a,vector<int> b,rint n){
rint Max=a.size();
for(int i=0;i<Max;i++)
a[i]=Mul(a[i],b[i]);
ntt(Max,a,1);a.resize(n);
return a;
}
// vec[x] = prod_{i=l..r} (1 - tmp[i] x), r-l+2 coefficients. The children's vectors are
// left forward-transformed at this node's length, which dfs2 relies on.
void dfs1(rint x,rint l,rint r){
if(l==r){
vec[x].resize(2);
vec[x][0]=1;
vec[x][1]=p-tmp[l];
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1;
dfs1(L,l,mid);dfs1(R,mid+1,r);
rint Max=1,n=r-l+2;
while(Max<n)Max<<=1;vec[x].resize(Max);
ntt(Max,vec[L],1);ntt(Max,vec[R],1);
for(rint i=0;i<Max;i++)
vec[x][i]=Mul(vec[L][i],vec[R][i]);
ntt(Max,vec[x],-1);vec[x].resize(n);
}
// v: the transposed product for [l, r]. Each child receives v transposed-multiplied by
// its sibling's product, truncated to the child's size; at a leaf, v[0] is the answer.
void dfs2(rint x,rint l,rint r,vector<int> v){
rint n=r-l+2;
if(l==r){
res[l]=v[0];
return;
}
rint Max=1,mid=l+r>>1;
while(Max<n)Max<<=1;
ntt(Max,v,-1);
dfs2(x<<1,l,mid,Mult(v,vec[x<<1|1],mid-l+1));
dfs2(x<<1|1,mid+1,r,Mult(v,vec[x<<1],r-mid));
}
// A: n coefficients; a: points a[1..m] (a[0] is ignored). The descent starts from the
// transposed product of A with 1/prod(1 - a_i x) mod x^n.
vector<int> query(vector<int> A,vector<int> a){
rint n=A.size(),m=a.size()-1;res.resize(m+1);tmp=a;
dfs1(1,1,m);
dfs_inv(vec[1],vec[1],n);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,vec[1],1);
ntt(Max,A,-1);
dfs2(1,1,m,Mult(A,vec[1],n));
return res;
}
}T;
// Fast interpolation: query(x, y) returns the coefficients of the polynomial of degree
// < n through the points (x[i], y[i-1]) for i = 1..n (x[0] is ignored).
struct Fin{
vector<int> vec[N<<2],vec2[N<<2],tmp,tmp2;   // node products, node partial results, points (later G'(points)), values
// vec[x] = prod_{i=l..r} (x - tmp[i]); the children are left forward-transformed at this
// node's length, for reuse by dfs2.
void dfs1(rint x,rint l,rint r){
if(l==r){
vec[x].resize(2);
vec[x][0]=p-tmp[l];
vec[x][1]=1;
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1,n=r-l+2;
dfs1(L,l,mid);dfs1(R,mid+1,r);
rint Max=1;
while(Max<n)Max<<=1;
ntt(Max,vec[L],1);ntt(Max,vec[R],1);
vec[x].resize(Max);
for(rint i=0;i<Max;i++)
vec[x][i]=Mul(vec[L][i],vec[R][i]);
ntt(Max,vec[x],-1);
vec[x].resize(n);
}
// vec2[x] = sum_{i=l..r} y_i / G'(x_i) * prod_{j in [l,r], j != i} (x - x_j),
// combined as vec2[L]*vec[R] + vec2[R]*vec[L] in the transformed domain.
void dfs2(rint x,rint l,rint r){
if(l==r){
vec2[x].resize(1);
vec2[x][0]=Mul(tmp2[l-1],fpow(tmp[l],p-2));   // y / G'(x_l)
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1,n=r-l+2;
dfs2(x<<1,l,mid);dfs2(x<<1|1,mid+1,r);
rint Max=1;
while(Max<n)Max<<=1;
vec2[x].resize(Max);
ntt(Max,vec2[L],1);ntt(Max,vec2[R],1);
for(rint i=0;i<Max;i++)
vec2[x][i]=Add(Mul(vec2[L][i],vec[R][i]),Mul(vec2[R][i],vec[L][i]));
ntt(Max,vec2[x],-1);vec2[x].resize(n-1);
}
vector<int> query(vector<int> x,vector<int> y){
int n=x.size()-1;tmp=x;tmp2=y;
dfs1(1,1,n);   // vec[1] = G = prod (x - x_i)
for(rint i=1;i<=n;i++)
vec[1][i-1]=Mul(vec[1][i],i);   // differentiate in place: vec[1] = G'
vec[1][n]=0;
tmp=T.query(vec[1],x);   // tmp[i] = G'(x_i)
dfs2(1,1,n);
return vec2[1];
}
}F;
int inv[N],fac[N];   // 1/i (then 1/i!) and i!
// Falling-factorial -> ordinary basis. Reads n and n coefficients in the falling-factorial
// basis and prints the n ordinary coefficients of the same polynomial: multiplying by e^x
// gives the point values at 0..n-1 (as an EGF), which are then interpolated.
int main(){
rint n,m;   // m is unused; n is borrowed as a temporary below
w[N>>1]=1;w[N/2+1]=n=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],n);
for(rint i=N/2-1;~i;i--)w[i]=w[i<<1];   // w[mid+j] = j-th power of a primitive 2*mid-th root
scanf("%d",&n);
vector<int> f,g;
f.resize(n);g.resize(n);
for(rint i=0;i<n;i++)scanf("%d",&f[i]);
inv[0]=inv[1]=fac[1]=fac[0]=1;
for(rint i=2;i<=n;i++)inv[i]=p-1ll*p/i*inv[p%i]%p;   // 1/i
for(rint i=2;i<=n;i++){
inv[i]=Mul(inv[i],inv[i-1]);   // 1/i!
fac[i]=Mul(fac[i-1],i);
}
for(rint i=0;i<n;i++)g[i]=inv[i];   // g = e^x truncated
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);f.resize(n);
for(rint i=0;i<n;i++)f[i]=Mul(f[i],fac[i]);   // f[i] = value at i
g.resize(n+1);
for(rint i=1;i<=n;i++)g[i]=i-1;   // interpolation points 0..n-1 (1-indexed; g[0] unused)
f=F.query(g,f);
for(rint i=0;i<n;i++)
printf("%d ",f[i]);
return 0;
}

普通多项式转下降幂多项式

给定一个普通多项式,求一个与其相等的下降幂多项式。

使用多点求值求出 $0,1,\dots,n-1$ 处的点值,得到下降幂的点值 $EGF$,然后乘上 $e^{-x}$ 即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1<<18;        // twiddle / table capacity: transform lengths must be <= N
const int p=998244353;    // NTT prime, primitive root 3
const int inv2=p+1>>1;    // 1/2 mod p
// Shorthands used throughout this program. `register` was removed in C++17 (clang
// rejects it, GCC warns) and never changed semantics, so the macros name plain types.
#define rint int
#define rll long long
// Modular helpers; arguments are expected in [0, p).
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
inline int Mul(rll x,rint y){x*=y;return x>=p?x%p:x;}
// a^b mod p by square-and-multiply.
inline int fpow(rint a,rint b){
	rint res=1;
	for(;b;b>>=1,a=Mul(a,a))
		if(b&1)res=Mul(res,a);
	return res;
}
int w[N],r[N],pd,t;   // twiddles, bit-reversal table, length r[] was built for, scratch
// In-place NTT of a after resizing it to n (n a power of two, n <= N).
// w[mid + j] must hold the j-th power of a primitive (2*mid)-th root of unity;
// r[] is rebuilt only when n differs from the previous call (cached in pd).
// typ == 1: forward; otherwise inverse (forward pass, reverse a[1..n-1], scale by 1/n).
void ntt(rint n,vector<int> &a,rint typ){
a.resize(n);
if(pd^n){
for(rint i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
pd=n;
}
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint R=mid<<1,i=0;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=Mul(w[mid+j],a[i+j+mid]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(typ==1)return;
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)a[i]=Mul(a[i],inv);
}
// f = g^{-1} mod x^n by Newton iteration f <- f (2 - f g) (requires g[0] != 0).
// g is taken by value, so f may be the same vector as the caller's argument.
void dfs_inv(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=fpow(g[0],p-2);return;}
dfs_inv(f,g,n+1>>1);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],Del(2,Mul(f[i],g[i])));
ntt(Max,f,-1);f.resize(n);
}
// f = sqrt(g) mod x^n by Newton iteration f <- (f^2 + g) / (2f). The base case fixes
// f[0] = 1, so presumably g[0] == 1.
void dfs_sqrt(vector<int> &f,vector<int> g,rint n){
g.resize(n);
if(n==1){f.resize(1);f[0]=1;return;}
dfs_sqrt(f,g,n+1>>1);
vector<int> a;
dfs_inv(a,f,n);   // a = 1/f mod x^n
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,a,1);ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(Add(Mul(f[i],f[i]),g[i]),Mul(inv2,a[i]));
ntt(Max,f,-1);f.resize(n);
}
// Power-series logarithm: f = ln(g) mod x^n, computed as the integral of g'/g.
// Meaningful when g[0] == 1; the constant term of the result is set to 0.
void Ln(vector<int> &f,vector<int> g,rint n){
g.resize(n);
// f = g^{-1} mod x^n
dfs_inv(f,g,n);
// g <- g' (shift coefficients down, scaling by the exponent)
for(rint i=1;i<n;i++)g[i-1]=Mul(g[i],i);
g[n-1]=0;
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);
f.resize(n);
// integrate: f[i] = f[i-1] / i
for(rint i=n-1;i;i--)f[i]=Mul(f[i-1],fpow(i,p-2));
f[0]=0;
}
// Power-series exponential by Newton iteration: f = exp(g) mod x^n
// (meaningful when g[0] == 0). Step: f <- f * (1 - ln f + g).
void Exp(vector<int> &f,vector<int> g,int n){
g.resize(n);
if(n==1){f.resize(1);f[0]=1;return;}
Exp(f,g,n+1>>1);
vector<int> a;
Ln(a,f,n);
rint Max=1;
while(Max<n+n)Max<<=1;
g.resize(Max);a.resize(Max);
// a = g - ln f + 1
for(rint i=0;i<Max;i++)a[i]=Del(g[i],a[i]);a[0]++;
ntt(Max,a,1);ntt(Max,f,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],a[i]);
ntt(Max,f,-1);f.resize(n);
}
// Multipoint evaluation: query(A, a) evaluates the polynomial with coefficients
// A[0..A.size()) at the points a[1..m] (a[0] is ignored). It returns res with
// res[i] = A(a[i]) for i = 1..m; res[0] is not written.
// NOTE(review): the tree stores prod (1 - a_i x), and the descent uses the series
// inverse plus "transposed" products (Mult). This looks like the
// transposition-principle algorithm; confirm before changing it.
struct Eval{
// vec[x]: product over node x's segment of (1 - tmp[i] x). dfs1 leaves the
// children's vectors in NTT form. tmp: copy of the query points. Member A is unused.
vector<int> vec[N<<2],A,res,tmp;
// Pointwise product of a and b over a.size() entries, then a forward ntt;
// returns the first n values.
vector<int> Mult(vector<int> a,vector<int> b,rint n){
rint Max=a.size();
for(int i=0;i<Max;i++)
a[i]=Mul(a[i],b[i]);
ntt(Max,a,1);a.resize(n);
return a;
}
// Builds vec[x] for segment [l, r]: a leaf is 1 - tmp[l] x, an inner node is the
// NTT product of its children. The children stay transformed at this node's length.
void dfs1(rint x,rint l,rint r){
if(l==r){
vec[x].resize(2);
vec[x][0]=1;
vec[x][1]=p-tmp[l];
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1;
dfs1(L,l,mid);dfs1(R,mid+1,r);
rint Max=1,n=r-l+2;
while(Max<n)Max<<=1;vec[x].resize(Max);
ntt(Max,vec[L],1);ntt(Max,vec[R],1);
for(rint i=0;i<Max;i++)
vec[x][i]=Mul(vec[L][i],vec[R][i]);
ntt(Max,vec[x],-1);vec[x].resize(n);
}
// Pushes v down the tree. Each child receives v combined with its sibling's
// (already transformed) product; at a leaf, v[0] is the value at tmp[l].
void dfs2(rint x,rint l,rint r,vector<int> v){
rint n=r-l+2;
if(l==r){
res[l]=v[0];
return;
}
rint Max=1,mid=l+r>>1;
while(Max<n)Max<<=1;
ntt(Max,v,-1);
dfs2(x<<1,l,mid,Mult(v,vec[x<<1|1],mid-l+1));
dfs2(x<<1|1,mid+1,r,Mult(v,vec[x<<1],r-mid));
}
// Entry point; see the comment on the struct.
vector<int> query(vector<int> A,vector<int> a){
rint n=A.size(),m=a.size()-1;res.resize(m+1);tmp=a;
dfs1(1,1,m);
// vec[1] <- 1 / prod(1 - a_i x) mod x^n
dfs_inv(vec[1],vec[1],n);
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,vec[1],1);
ntt(Max,A,-1);
dfs2(1,1,m,Mult(A,vec[1],n));
return res;
}
}T;
// Fast interpolation: query(x, y) returns the coefficients (degree < n) of the
// polynomial through the points (x[i], y[i]) for i = 1..n; index 0 of both is ignored.
// Lagrange form: sum_i y_i / M'(x_i) * M(x) / (x - x_i), where M = prod (x - x_i).
struct Fin{
// vec[x]: prod (x - x_i) over node x's segment. vec2[x]: the partial Lagrange sum
// for that segment. tmp: the points, later overwritten with M'(x_i). tmp2: the values y_i.
vector<int> vec[N<<2],vec2[N<<2],tmp,tmp2;
// Product tree of (x - tmp[i]). As in Eval::dfs1, the children are left in NTT
// form at their parent's transform length; dfs2 reuses those transforms.
void dfs1(rint x,rint l,rint r){
if(l==r){
vec[x].resize(2);
vec[x][0]=p-tmp[l];
vec[x][1]=1;
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1,n=r-l+2;
dfs1(L,l,mid);dfs1(R,mid+1,r);
rint Max=1;
while(Max<n)Max<<=1;
ntt(Max,vec[L],1);ntt(Max,vec[R],1);
vec[x].resize(Max);
for(rint i=0;i<Max;i++)
vec[x][i]=Mul(vec[L][i],vec[R][i]);
ntt(Max,vec[x],-1);
vec[x].resize(n);
}
// Combines the partial sums: vec2[x] = vec2[L]*vec[R] + vec2[R]*vec[L].
// A leaf holds y_l / M'(x_l).
void dfs2(rint x,rint l,rint r){
if(l==r){
vec2[x].resize(1);
vec2[x][0]=Mul(tmp2[l],fpow(tmp[l],p-2));
return;
}
rint mid=l+r>>1,L=x<<1,R=x<<1|1,n=r-l+2;
dfs2(x<<1,l,mid);dfs2(x<<1|1,mid+1,r);
rint Max=1;
while(Max<n)Max<<=1;
vec2[x].resize(Max);
ntt(Max,vec2[L],1);ntt(Max,vec2[R],1);
for(rint i=0;i<Max;i++)
vec2[x][i]=Add(Mul(vec2[L][i],vec[R][i]),Mul(vec2[R][i],vec[L][i]));
ntt(Max,vec2[x],-1);vec2[x].resize(n-1);
}
vector<int> query(vector<int> x,vector<int> y){
int n=x.size()-1;tmp=x;tmp2=y;
dfs1(1,1,n);
// vec[1] <- M'(x) (differentiate M in place)
for(rint i=1;i<=n;i++)
vec[1][i-1]=Mul(vec[1][i],i);
vec[1][n]=0;
// tmp[i] <- M'(x_i) via multipoint evaluation
tmp=T.query(vec[1],x);
dfs2(1,1,n);
return vec2[1];
}
}F;
// inv[i]: first the modular inverse of i, then (after the prefix product) 1/i!.
int inv[N];
// Reads n and the coefficients f[0..n) of a polynomial F. It evaluates F at
// 0..n-1, divides by i!, and convolves with e^{-x} = sum (-1)^i x^i / i!.
// That presumably yields F's coefficients in the falling-factorial basis — TODO confirm.
int main(){
rint n,m;
// root table for ntt(): w[N/2 + j] = g^j with g = 3^((p-1)/N), and w[i] = w[2i]
// below N/2. n is only a temporary here.
w[N>>1]=1;w[N/2+1]=n=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],n);
for(rint i=N/2-1;~i;i--)w[i]=w[i<<1];
scanf("%d",&n);
vector<int> f,g;
f.resize(n);g.resize(n+1);
for(rint i=0;i<n;i++)scanf("%d",&f[i]);
// g[1..n] = evaluation points 0..n-1; inv[i] = i^{-1} by the linear recurrence
inv[0]=inv[1]=1;
for(rint i=1;i<=n;i++){
g[i]=i-1;
if(i!=1)inv[i]=p-1ll*p/i*inv[p%i]%p;
}
// prefix products turn inv[i] into inverse factorials
for(rint i=2;i<=n;i++)inv[i]=Mul(inv[i-1],inv[i]);
f=T.query(f,g);
// f[i] = F(i) / i!
for(rint i=0;i<f.size()-1;i++)f[i]=Mul(f[i+1],inv[i]);
f.resize(n);
g.resize(n);
// g = e^{-x}
for(rint i=0;i<n;i++)g[i]=(i&1)?p-inv[i]:inv[i];
rint Max=1;
while(Max<n+n)Max<<=1;
ntt(Max,f,1);ntt(Max,g,1);
for(rint i=0;i<Max;i++)f[i]=Mul(f[i],g[i]);
ntt(Max,f,-1);f.resize(n);
for(rint i=0;i<n;i++)
printf("%d ",f[i]);
return 0;
}

Chirp Z-Transform

给定一个 $n$ 项多项式 $P(x)$ 以及 $c, m$,请计算 $P(c^0),P(c^1),\dots,P(c^{m-1})$。所有答案都对 $998244353$ 取模。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define db double
#define ll long long
// `register` is ill-formed since C++17 (clang rejects it), so the shorthands are
// now plain types; behaviour is unchanged.
#define rint int
#define rll long long
#define ull unsigned long long
const int N=1<<22;
const int p=998244353;
const int sqr=1<<15;
/// Reads a (possibly negative) decimal integer from stdin.
/// Stops at EOF instead of spinning forever on truncated input
/// (the character is kept in an int so EOF stays distinguishable).
inline int read(){
	rint x=0,f=1;
	int ch=getchar();
	while(ch!=EOF&&(ch<'0'||ch>'9')){if(ch=='-')f=-1;ch=getchar();}
	while(ch<='9'&&ch>='0'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/// (x + y) mod p for reduced operands.
inline int Add(rint x,rint y){x+=y;return x>=p?x-p:x;}
/// (x - y) mod p for reduced operands.
inline int Del(rint x,rint y){x-=y;return x>=0?x:x+p;}
/// x * y mod p (64-bit product). Returns long long.
inline ll Mul(rll x,rll y){x*=y;return x>=p?x%p:x;}
/// a^b mod p by binary exponentiation.
inline ll fpow(rll a,rll b){
	rll res=1;
	for(;b;b>>=1,a=Mul(a,a))
		if(b&1)res=Mul(res,a);
	return res;
}
// w[]: root table filled in main (w[mid+j] serves the stage of half-length mid).
// r[]: bit-reversal table, which the caller must fill for the transform length
// before calling ntt(). t: butterfly scratch. pd: unused in this program.
int w[N],r[N],t,pd;
// In-place iterative NTT of a, resized to n. typ != -1: forward transform.
// typ == -1: inverse (forward pass, reverse a[1..n), scale by n^{-1}).
void ntt(rint n,vector<int>&a,rint typ){
a.resize(n);
for(rint i=0;i<n;i++)
if(i<r[i])swap(a[i],a[r[i]]);
for(rint mid=1;mid<n;mid<<=1)
for(rint i=0,R=mid<<1;i<n;i+=R)
for(rint j=0;j<mid;j++)
t=Mul(a[i+j+mid],w[mid+j]),a[i+j+mid]=Del(a[i+j],t),a[i+j]=Add(a[i+j],t);
if(~typ)return ;
rint inv=fpow(n,p-2);
reverse(a.begin()+1,a.end());
for(rint i=0;i<n;i++)
a[i]=Mul(a[i],inv);
}
int pw1[sqr+1],pw2[sqr+1];
// c^k from two tables filled in main: pw1[i] = c^i and pw2[i] = (c^sqr)^i,
// so c^k = pw1[k % sqr] * pw2[k / sqr].
inline int pw(rint k){
	const int lowPart=k%sqr;
	const int highPart=k/sqr;
	return Mul(pw1[lowPart],pw2[highPart]);
}
int main(){
w[N/2]=1;w[N/2+1]=t=fpow(3,(p-1)/N);
for(rint i=N/2+2;i<N;i++)w[i]=Mul(w[i-1],t);
for(rint i=N/2-1;i;i--)w[i]=w[i<<1];
rint n=read(),c=read(),m=read();
pw1[0]=pw2[0]=1;
for(rint i=1;i<=sqr;i++)pw1[i]=Mul(pw1[i-1],c);
for(rint i=1;i<=sqr;i++)pw2[i]=Mul(pw2[i-1],pw1[sqr]);
vector<int> f,g;f.resize(n),g.resize(n+m);
for(rint i=0;i<n;i++){
f[i]=read();
f[i]=Mul(f[i],pw(p-1-1ll*(i-1)*i/2%(p-1)));
}
for(rint i=0;i<n+m;i++)
g[i]=pw(1ll*i*(i-1)/2%(p-1));
rint Max=1;
while(Max<n+m)Max<<=1;
for(rint i=0;i<Max;i++)
r[i]=(r[i>>1]>>1)|((i&1)?Max>>1:0);
ntt(Max,g,-1);ntt(Max,f,1);
for(rint i=0;i<Max;i++)
f[i]=Mul(f[i],g[i]);
ntt(Max,f,1);
for(rint i=0;i<m;i++)
printf("%d ",Mul(pw(p-1-1ll*(i-1)*i/2%(p-1)),f[i]));
return 0;
}

常系数齐次线性递推

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=65536+10;
typedef long long ll;
const ll mod=998244353;
// n: index of the wanted term; k: order of the recurrence.
int n;
int k;
// rv[d]: bit-reversal table for length 2^(d+1); rt[0][j] / rt[1][j]: 2^j-th roots
// of unity for the forward / inverse transform.
int rv[20][N];
ll rt[20][20];
// Len: smallest power of two greater than k (set in main).
int Len;
// tr1, tr2: scratch buffers of poly_inv.
ll tr1[N];
ll tr2[N];
// st: initial terms; xs: recurrence coefficients.
long long st[N];
long long xs[N];
// sg: characteristic polynomial (NTT form after setup); a: current square power of x;
// res: accumulated power of x; irg: inverse of the reversed characteristic polynomial;
// q, rf: scratch for poly_mod; ret: untransformed copy of sg.
ll sg[N];
ll a[N];
ll res[N];
ll irg[N];
ll q[N];
ll rf[N];
// DL = log2(Len) - 1 once main has computed Len.
int DL=-1;
ll ans=0;
ll ret[N];
/// Modular exponentiation: a^p mod `mod`.
inline ll po(ll a,ll p){
	ll r=1;
	while(p){
		if(p&1)r=r*a%mod;
		a=a*a%mod;
		p>>=1;
	}
	return r;
}
// Iterative NTT on a[0..len). o == 0: forward (roots rt[0]); o == 1: inverse
// (roots rt[1], then scaled by len^{-1}). d selects the bit-reversal table rv[d],
// which main builds for length 2^(d+1), so len must equal 2^(d+1).
inline void ntt(ll* a,int o,int len,int d)
{
for(int i=0;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]);
// j = log2 of the current span (k<<1); rt[o][j] is a primitive 2^j-th root
for(int k=1,j=1;k<len;k<<=1,j++)
for(int s=0;s<len;s+=(k<<1))
for(int i=s,w=1;i<s+k;i++,w=w*rt[o][j]%mod)
{ll a0=a[i];ll a1=a[i+k]*w%mod;a[i]=(a0+a1)%mod,a[i+k]=(a0+mod-a1)%mod;}
if(o==1){ll inv=po(len,mod-2);for(int i=0;i<len;i++)(a[i]*=inv)%=mod;}
}
// Power-series inverse: on return b[0..len) = a^{-1} mod x^len (len a power of two).
// Newton doubling. b[k..2k) is cleared after each step, so only the valid low
// part is carried into the next one. Uses tr1/tr2 as scratch.
inline void poly_inv(ll* a,ll* b,int len)
{
b[0]=po(a[0],mod-2);
for(int k=1,j=0;k<=len;k<<=1,j++)
{
for(int i=0;i<k;i++)tr1[i]=a[i];for(int i=0;i<k;i++)tr2[i]=b[i];
ntt(tr1,0,k<<1,j);ntt(tr2,0,k<<1,j);
// b = tr2 * (2 - tr1*tr2) in the transform domain
for(int i=0;i<(k<<1);i++)b[i]=tr2[i]*(2+mod-tr1[i]*tr2[i]%mod)%mod;
ntt(b,1,k<<1,j);for(int i=k;i<(k<<1);i++)b[i]=0;
}
}
/// Reduces `a` (degree < 2k) in place modulo the characteristic polynomial.
/// Uses the precomputed NTT images sg (= G) and irg (= inverse of reversed G):
/// Q = rev(rev(A) * rev(G)^{-1}), then A -= Q*G. Afterwards only a[0..k) may be non-zero.
inline void poly_mod(ll* a)
{
	// Locate the highest non-zero coefficient. The scan is bounded at index 0:
	// the polynomial can be identically zero (e.g. x^n mod x^k when every
	// recurrence coefficient is 0), and the old `while(a[--mi]==0);` then read a[-1].
	int mi=(k<<1)-1;
	while(mi>=0&&a[mi]==0)--mi;
	if(mi<k)return;
	// rf = reversed A, truncated to the mi-k+1 coefficients of the quotient
	for(int i=0;i<(Len<<1);i++)rf[i]=0;
	for(int i=0;i<=mi;i++)rf[i]=a[i];
	reverse(rf,rf+mi+1);
	for(int i=mi-k+1;i<=mi;i++)rf[i]=0;
	ntt(rf,0,Len<<1,DL+1);
	// q = rev(Q) = rf * irg, truncated, then reversed back into Q
	for(int i=0;i<(Len<<1);i++)q[i]=(rf[i]*irg[i])%mod;
	ntt(q,1,(Len<<1),DL+1);
	for(int i=mi-k+1;i<=(Len<<1);i++)q[i]=0;
	reverse(q,q+mi-k+1);
	// q = Q * G
	ntt(q,0,(Len<<1),DL+1);
	for(int i=0;i<(Len<<1);i++)(q[i]*=sg[i])%=mod;
	ntt(q,1,(Len<<1),DL+1);
	// remainder = A - Q*G (only the low k coefficients survive)
	for(int i=0;i<k;i++)(a[i]+=mod-q[i])%=mod;
	for(int i=k;i<=mi;i++)a[i]=0;
}
// Reads n, k, the recurrence coefficients xs[1..k] and the initial terms st[0..k).
// Prints the n-th term: sum_i [x^n mod G](i) * st[i], with G the characteristic polynomial.
int main()
{
// rv[i]: bit-reversal permutation for length 2^(i+1)
for(int i=0;i<=15;i++)
for(int j=0;j<(1<<(i+1));j++)rv[i][j]=(rv[i][j>>1]>>1)|((j&1)<<i);
// rt[0][j]: 2^j-th root of unity from generator 3; rt[1][j]: same from 3^{-1} = 332748118
for(int t=2,j=1;j<=18;t<<=1,j++)rt[0][j]=po(3,(mod-1)/t);
for(int t=2,j=1;j<=18;t<<=1,j++)rt[1][j]=po(332748118,(mod-1)/t);
scanf("%d%d",&n,&k);
// Len: smallest power of two greater than k; DL ends at log2(Len) - 1
for(Len=1;Len<=k;Len<<=1,DL++);
// read coefficients and initial terms, lifting negative inputs into [0, mod)
for(int i=1;i<=k;i++){scanf("%lld",&xs[i]);xs[i]=xs[i]<0?xs[i]+mod:xs[i];}
for(int i=0;i<k;i++){scanf("%lld",&st[i]);st[i]=st[i]<0?st[i]+mod:st[i];}
// sg = x^k - sum xs[i] x^{k-i}; irg = inverse of reversed sg mod x^Len
for(int i=1;i<=k;i++)sg[k-i]=mod-xs[i];sg[k]=1;for(int i=0;i<=k;i++)ret[i]=sg[i];
for(int i=0;i<=k;i++)rf[i]=sg[i];reverse(rf,rf+k+1);poly_inv(rf,irg,Len);
// keep sg and irg in NTT form for poly_mod; start with a = x and res = 1
for(int i=0;i<=k;i++)rf[i]=0;ntt(sg,0,Len<<1,DL+1);ntt(irg,0,Len<<1,DL+1);a[1]=1;res[0]=1;
// res = x^n mod sg by binary exponentiation
while(n)
{
if(n&1)
{
ntt(res,0,Len<<1,DL+1);ntt(a,0,Len<<1,DL+1);
for(int i=0;i<(Len<<1);i++)(res[i]*=a[i])%=mod;
ntt(res,1,Len<<1,DL+1);ntt(a,1,Len<<1,DL+1);poly_mod(res);
}ntt(a,0,Len<<1,DL+1);for(int i=0;i<(Len<<1);i++)(a[i]*=a[i])%=mod;
ntt(a,1,Len<<1,DL+1);poly_mod(a);n>>=1;
}for(int i=0;i<k;i++)(ans+=res[i]*st[i])%=mod;printf("%lld",ans);return 0;
}

常系数非齐次线性递推

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#pragma GCC optimize ("unroll-loops")
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
// Program-wide constants. N, p and ll are macros, so they replace every
// identifier with that spelling in this program.
#define N 131077
#define p 998244353
#define ll long long
// `register` is ill-formed since C++17 (clang rejects it), so `reg` now expands
// to nothing and `reg int i` simply declares an int.
#define reg
// Modular + and - for operands already in [0, p). The arguments are not
// parenthesized, so pass simple expressions only.
#define add(x,y) (x+y>=p?x+y-p:x+y)
#define dec(x,y) (x<y?x-y+p:x-y)
using namespace std;

/// Reads a non-negative decimal integer into x, skipping every non-digit
/// character before it (a minus sign is not recognised).
inline void read(int &x){
	char ch=getchar();
	for(;ch<'0'||ch>'9';ch=getchar());
	for(x=0;ch>='0'&&ch<='9';ch=getchar())
		x=x*10+(ch-'0');
}

/// Writes the decimal digits of x (meant for x >= 0), without a separator.
void print(int x){
	char digits[12];
	int used=0;
	while(x>9){
		digits[used++]=x%10+'0';
		x/=10;
	}
	putchar(x%10+'0');
	while(used)putchar(digits[--used]);
}

/// a^t mod p by square-and-multiply.
inline int power(int a,int t){
	int res=1;
	for(;t;t>>=1){
		if(t&1)res=(ll)res*a%p;
		a=(ll)a*a%p;
	}
	return res;
}

// rt: NTT root table; rev: bit reversal for the full length 2^siz;
// inv/fac/ifac: inverses, factorials and inverse factorials up to lim (set in init).
int rt[N],rev[N],inv[N],fac[N],ifac[N];
// log2 of the largest transform length prepared by init()
int siz;

// Binomial coefficient C(x, y) mod p; 0 when x < y.
inline int binom(int x,int y){
if(x<y) return 0;
return (ll)fac[x]*ifac[y]%p*ifac[x-y]%p;
}

// Prepares tables for transforms of length up to lim = smallest power of two > n:
// bit reversal, roots (rt[lim/2 + j] = w^j with w = 3^((p-1)/lim), and rt[i] = rt[2i]
// below lim/2), and inv/fac/ifac for 0..lim.
// NOTE(review): siz is accumulated, so this is meant to be called only once.
void init(int n){
int w,lim = 1;
while(lim<=n) lim <<= 1,++siz;
for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
w = power(3,(p-1)>>siz);
fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = rt[lim>>1] = 1;
for(reg int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
for(reg int i=2;i<=lim;++i) inv[i] = (ll)(p-p/i)*inv[p%i]%p;
for(reg int i=2;i<=lim;++i) ifac[i] = fac[i] = (ll)fac[i-1]*i%p;
ifac[lim] = power(fac[lim],p-2);
for(reg int i=lim-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
}

/// Smallest power of two strictly greater than n, i.e. a transform length that
/// holds coefficients 0..n.
/// Returns 1 for n <= 0: __builtin_clz(0) is undefined. divide() and inverse()
/// reach n == 0 when dividend and divisor have the same degree.
inline int getlen(int n){
	if(n<=0) return 1;
	return 1<<(32-__builtin_clz(n));
}

// In-place NTT of f[0..lim). type==1: forward; type==-1: inverse (reverse f[1..lim)
// first, then scale by lim^{-1}). The bit-reversal table built for length 2^siz
// is reused for a smaller lim by shifting it right. The permuted copy goes
// through a static buffer.
inline void NTT(int *f,int type,int lim){
if(type==-1) reverse(f+1,f+lim);
reg int x,shift = siz-__builtin_ctz(lim);
static int a[N];
for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
for(reg int mid=1;mid!=lim;mid<<=1){
for(reg int j=0;j!=lim;j+=(mid<<1)){
for(reg int k=0;k!=mid;++k){
x = (ll)a[j|k|mid]*rt[mid|k]%p;
a[j|k|mid] = dec(a[j|k],x);
a[j|k] = add(a[j|k],x);
}
}
}
memcpy(f,a,lim<<2);
if(type==1) return;
x = inv[lim];
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%p;
}

// Power-series inverse: R[0..n] = f^{-1} mod x^{n+1} (f[0] must be invertible).
// Newton iteration over the precisions ..., n/4, n/2, n taken from stack s,
// each step g <- 2g - g^2 f. R is written only at the end, so it may alias f.
void inverse(const int *f,int *R,int n){
static int g[N],h[N],q[N];
memset(g,0,getlen(n<<1)+2<<2);
int lim = 1,top = 0;
int s[30];
// precision targets n, n>>1, ..., 1
while(n){
s[++top] = n;
n >>= 1;
}
g[0] = power(f[0],p-2);
while(top--){
n = s[top+1];
while(lim<=(n<<1)) lim <<= 1;
// q = previous g, h = f mod x^{n+1}
memcpy(q,g,(n+1)<<2);
memcpy(h,f,(n+1)<<2);
memset(h+n+1,0,(lim-n)<<2);
NTT(g,1,lim),NTT(h,1,lim);
for(reg int i=0;i!=lim;++i) g[i] = (ll)g[i]*g[i]%p*h[i]%p;
NTT(g,-1,lim);
for(reg int i=0;i<=n;++i) g[i] = dec(add(q[i],q[i]),g[i]);
memset(g+n+1,0,(lim-n)<<2);
}
memcpy(R,g,(n+1)<<2);
}

// Polynomial division: R[0..n-m] = quotient of f (degree n) by g (degree m), n >= m,
// computed by reversal: rev(Q) = rev(f) * rev(g)^{-1} mod x^{n-m+1}.
void divide(const int *f,const int *g,int n,int m,int *R){
static int A[N],B[N];
memcpy(A,f,(n+1)<<2);
memcpy(B,g,(m+1)<<2);
reverse(A,A+n+1);
reverse(B,B+m+1);
int tt = n-m,lim = getlen((n-m)<<1);
// keep only the n-m+1 low terms of both reversed polynomials
for(reg int i=tt+1;i!=lim;++i) A[i] = 0;
for(reg int i=min(m,tt)+1;i!=lim;++i) B[i] = 0;
inverse(B,B,tt);
NTT(A,1,lim),NTT(B,1,lim);
for(reg int i=0;i!=lim;++i) A[i] = (ll)A[i]*B[i]%p;
NTT(A,-1,lim);
reverse(A,A+tt+1);
memcpy(R,A,(tt+1)<<2);
}

// R = f mod g, with deg f = n and deg g = m. R[0..m) receives the remainder and
// R[m..lim) is cleared. If n < m, f is just copied (n+1 terms).
// R may alias f: f is copied into B (and into divide's buffer) before R is written.
void mod(const int *f,const int *g,int n,int m,int *R){
if(n<m){
memcpy(R,f,(n+1)<<2);
return;
}
static int A[N],B[N];
memcpy(B,f,(n+1)<<2);
int lim = getlen(n);
// R = quotient, then R = quotient * g
divide(f,g,n,m,R);
for(int i=0;i<=m;++i) A[i] = g[i];
for(int i=m+1;i!=lim;++i) A[i] = 0;
for(int i=n-m+1;i!=lim;++i) R[i] = 0;
NTT(A,1,lim),NTT(R,1,lim);
for(reg int i=0;i!=lim;++i) R[i] = (ll)A[i]*R[i]%p;
NTT(R,-1,lim);
// remainder = f - quotient * g
for(reg int i=0;i!=m;++i) R[i] = dec(B[i],R[i]);
for(int i=m;i!=lim;++i) R[i] = 0;
}

#define mid ((l+r)>>1)
#define ls (u<<1)
#define rs (u<<1|1)
// Segment width (r - l) at or below which work is done naively instead of by NTT.
int bflim;
// P[u]: coefficients of prod_{i in [l, r]} (x - a_i) for tree node u; len[u] is its degree.
int *P[N],len[N];

/// Builds the product tree used by the multipoint evaluation. Node u covering
/// [l, r] gets P[u] = prod (x - a[i]) with len[u] = r - l + 1. Wide nodes multiply
/// their children by NTT, narrow ones (r - l <= bflim) by schoolbook multiplication.
/// NOTE(review): P[] is allocated with new[] and never freed (the program runs once).
void prepare(int l,int r,int u,const int *a){
	if(l==r){
		len[u] = 1;
		P[u] = new int[2];
		P[u][0] = p-a[l],P[u][1] = 1;
		return;
	}
	prepare(l,mid,ls,a);
	prepare(mid+1,r,rs,a);
	len[u] = r-l+1;
	int lim = getlen(len[u]);
	P[u] = new int[len[u]+1];
	int F[lim+1],G[lim+1];
	memcpy(F,P[ls],(len[ls]+1)<<2);
	memcpy(G,P[rs],(len[rs]+1)<<2);
	if(r-l>bflim){
		// Zero-pad from just past each child's top coefficient through index lim.
		// F and G hold lim+1 ints, so the count is lim-len; the former lim-len+1
		// wrote one int past the end of these stack arrays.
		memset(F+len[ls]+1,0,(lim-len[ls])<<2);
		memset(G+len[rs]+1,0,(lim-len[rs])<<2);
		NTT(F,1,lim),NTT(G,1,lim);
		for(reg int i=0;i!=lim;++i) F[i] = (ll)F[i]*G[i]%p;
		NTT(F,-1,lim);
		memcpy(P[u],F,(len[u]+1)<<2);
	}else{
		memset(P[u],0,(len[u]+1)<<2);
		for(reg int i=0;i<=len[ls];++i)
			for(reg int j=0;j<=len[rs];++j)
				P[u][i+j] = (P[u][i+j]+(ll)F[i]*G[j])%p;
	}
}

/// Multipoint evaluation: writes R[j] = F(a[j]) for j in [l, r]. F has degree n
/// and u is the product-tree node for [l, r] (built by prepare). Before recursing,
/// F is reduced modulo each child's product polynomial. Segments with
/// r - l <= bflim are evaluated directly by Horner's rule, unrolled 16 coefficients at a time.
void solve(const int *F,int l,int r,const int *a,int u,int n,int *R){
	if(r-l<=bflim){ // small segment: direct evaluation
		ll pw[17];
		int res,x;
		ll s1,s2,s3,s4;
		pw[0] = 1;
		for(reg int j=l;j<=r;++j){
			res = F[n],x = a[j];
			reg int i = 1;
			for(;i<=16;++i) pw[i] = pw[i-1]*x%p;
			i = n-1;
			// Each term below is a product of two values < p (< 2^30), so each
			// four-term partial sum stays below 2^63. The last term used to be an
			// int*int product (F[i-14]*x) and overflowed; it is now widened to 64 bits.
			while(i>=15){
				s1 = res*pw[16]+F[i]*pw[15]+F[i-1]*pw[14]+F[i-2]*pw[13];
				s2 = F[i-3]*pw[12]+F[i-4]*pw[11]+F[i-5]*pw[10]+F[i-6]*pw[9];
				s3 = F[i-7]*pw[8]+F[i-8]*pw[7]+F[i-9]*pw[6]+F[i-10]*pw[5];
				s4 = F[i-11]*pw[4]+F[i-12]*pw[3]+F[i-13]*pw[2]+(ll)F[i-14]*x;
				res = ((F[i-15]+s1+s2)%p+s3+s4)%p;
				i -= 16;
			}
			// remaining (n mod 16) low coefficients, one Horner step each
			i = (n&15)-1;
			for(;~i;--i) res = ((ll)res*x+F[i])%p;
			R[j] = res;
		}
		return;
	}
	// NOTE(review): variable-length array (GCC extension).
	int G[getlen(n<<1)+1];
	memset(G,0,sizeof(G));
	mod(F,P[ls],n,len[ls],G);
	solve(G,l,mid,a,ls,len[ls]-1,R);
	memset(G,0,sizeof(G));
	mod(F,P[rs],n,len[rs],G);
	solve(G,mid+1,r,a,rs,len[rs]-1,R);
}

// Evaluates F (degree n) at the points a[1..m], storing F(a[j]) in R[j].
// Segments of width up to log2(m) fall back to direct evaluation.
void evaluation(const int *F,int *a,int n,int m,int *R){
bflim = log2(m);
prepare(1,m,1,a);
solve(F,1,m,a,1,n,R);
}

// the tree-navigation macros are only needed by prepare/solve
#undef ls
#undef rs
#undef mid

// R = F * G truncated to degree len, where F has degree n and G degree m.
// Both inputs are copied into static buffers first, so R may alias F or G.
// NOTE(review): the clearing memsets reach index lim+2, so R (and the static A/B)
// need lim+3 slots.
void multiply(const int *F,const int *G,int n,int m,int len,int *R){
static int A[N],B[N];
memcpy(A,F,(n+1)<<2);
memcpy(B,G,(m+1)<<2);
int lim = getlen(n+m);
memset(A+n+1,0,(lim-n+2)<<2);
memset(B+m+1,0,(lim-m+2)<<2);
NTT(A,1,lim),NTT(B,1,lim);
for(reg int i=0;i!=lim;++i) R[i] = (ll)A[i]*B[i]%p;
NTT(R,-1,lim);
memset(R+len+1,0,(lim-len+2)<<2);
}

// Polynomial fast power modulo G(x): R[0..k) = x^t mod G, where G has degree k.
// f holds the current square power of x (starting at x, degree n), g the
// accumulated product (starting at 1, degree m).
// NOTE(review): f and g are two N-int arrays on the stack (about 1 MB).
void mod_power(const int *G,int k,int t,int *R){
int f[N],g[N];
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
int n = 1,m = 0;
f[1] = g[0] = 1;
while(t){
if(t&1){
multiply(f,g,n,m,n+m,g);
mod(g,G,n+m,k,g);
m = min(n+m,k-1);
}
// f <- f^2 mod G
int lim = getlen(n<<1);
NTT(f,1,lim);
for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*f[i]%p;
NTT(f,-1,lim);
mod(f,G,n<<1,k,f);
n = min(n<<1,k-1);
t >>= 1;
}
memcpy(R,g,k<<2);
}

int n,m,k,ans,T;
int F[N],G[N],B[N],A[N],a[N],f[N],d[N];

// Reads n, m, k, the initial terms a[0..k), the coefficients f[1..k] and the
// polynomial G[0..m]. Presumably prints term n of a_i = sum_j f_j a_{i-j} + G(i)
// (TODO confirm against the problem statement).
int main(){
init(100000);
read(n),read(m),read(k);
for(reg int i=0;i!=k;++i) read(a[i]);
for(reg int i=1;i<=k;++i) read(f[i]);
for(reg int i=0;i<=m;++i) read(G[i]);
// evaluation points k, k+1, ..., k+m
for(reg int i=0;i<=m;++i) d[i+1] = i+k;
// upper part of B: B[i] = G(i) for i in [k, k+m]
evaluation(G,d,m,m+1,B);
for(reg int i=k+m;i>=k;--i) B[i] = B[i-k+1];
for(reg int i=0;i!=k;++i) B[i] = 0;
// lower part of B from one convolution: B[i] = a[i] - (f*a)[i] for i < k
multiply(f,a,k,k,k-1,d);
for(reg int i=0;i!=k;++i) B[i] = dec(a[i],d[i]);
// F = 1 - sum f_i x^i; A = B / F via the series inverse gives the first m+k+1 terms
f[0] = p-1;
for(reg int i=0;i<=k;++i) F[i] = p-f[i];
inverse(F,F,m+k);
multiply(F,B,m+k,m+k,m+k,A);
for(reg int i=0;i<=m+k;++i) a[i] = A[i];
memset(F,0,sizeof(F));
// F = (x-1)^{m+1}: the coefficients of the (m+1)-th order difference
for(reg int i=0;i<=m+1;++i) F[i] = (m+1-i)&1?p-binom(m+1,i):binom(m+1,i);
T = m+k+1;
// multiplying by the difference operator turns it into a homogeneous recurrence of order T
multiply(F,f,m+1,k,T,f);
memset(G,0,sizeof(G));
for(reg int i=0;i<=T;++i) G[T-i] = p-f[i];
// answer = sum_i [x^n mod G](i) * a[i]
mod_power(G,T,n,F);
for(reg int i=0;i!=T;++i) ans = (ans+(ll)F[i]*a[i])%p;
print(ans);
return 0;
}

字符串

最小表示法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=3e5+10;
int a[N];
/// Minimum representation: prints the lexicographically smallest rotation of the
/// cyclic sequence a[0..n), found with the O(n) two-pointer algorithm.
int main(){
	int n;
	scanf("%d",&n);
	for(int i=0;i<n;i++)
		scanf("%d",&a[i]);
	// i, j: two candidate start positions; k: length of their common prefix
	int i=0,j=1,k=0;
	while(i<n&&j<n&&k<n){
		const int x=a[(i+k)%n],y=a[(j+k)%n];
		if(x==y){
			k++;
			continue;
		}
		// Compare rather than subtract: a[] holds arbitrary ints, and their
		// difference can overflow (undefined behaviour).
		if(x>y)i+=k+1;
		else j+=k+1;
		if(i==j)j++;
		k=0;
	}
	k=std::min(i,j);
	for(int t=0;t<n;t++)
		printf("%d ",a[(t+k)%n]);
	return 0;
}

kmp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include<bits/stdc++.h>
using namespace std;
// `register` is ill-formed since C++17 (clang rejects it), so the shorthands are
// now plain types; behaviour is unchanged.
#define rint int
#define ll long long
#define rll long long
const int N=1e6+10;
// nxt[i]: length of the longest proper border of s2[1..i] (KMP failure function)
int nxt[N];
char s1[N],s2[N];
/// KMP: prints every 1-based position where s2 occurs in s1 (one per line),
/// then the border array nxt[1..|s2|].
int main(){
	scanf("%s%s",s1+1,s2+1);
	rint n=strlen(s1+1),m=strlen(s2+1);
	for(rint i=2,j=0;i<=m;i++){
		while(j&&s2[j+1]!=s2[i])j=nxt[j];
		if(s2[i]==s2[j+1])j++;
		nxt[i]=j;
	}
	for(rint i=1,j=0;i<=n;i++){
		while(j&&s2[j+1]!=s1[i])j=nxt[j];
		if(s2[j+1]==s1[i])j++;
		if(j==m){
			printf("%d\n",i-m+1);
			j=nxt[j];
		}
	}
	for(rint i=1;i<=m;i++)
		printf("%d ",nxt[i]);
	return 0;
}

exkmp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <iostream>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 2e7 + 10;
// nxt[i]: LCP of b[i..] with b (Z-function of the pattern);
// ext[i]: LCP of a[i..] with b.
ll nxt[N], ext[N];

/// Z-function of c: nxt[i] = length of the longest common prefix of c and c[i..],
/// with nxt[0] = |c|. nxt[1] is always written, even when |c| <= 1.
void qnxt(char *c)
{
	const int n = strlen(c);
	nxt[0] = n;  // the suffix starting at 0 is c itself
	// Compute nxt[1] by direct comparison; it seeds the window [k, k + nxt[k] - 1].
	// (Starting the window at k = 0 would pin it to |c| and make it useless.)
	int run = 0;
	while (run + 1 < n && c[run] == c[run + 1]) ++run;
	nxt[1] = run;
	int k = 1;  // start of the match window that reaches farthest to the right
	for (int i = 2; i < n; ++i)
	{
		const int right = k + nxt[k] - 1;  // last index covered by that window
		const int mirrored = nxt[i - k];
		if (i + mirrored <= right)
		{
			nxt[i] = mirrored;  // fully inside the window: reuse the mirrored value
			continue;
		}
		int j = right - i + 1;
		if (j < 0) j = 0;
		while (i + j < n && c[i + j] == c[j]) ++j;  // extend character by character
		nxt[i] = j;
		k = i;  // this window now reaches farthest right
	}
}

/// Extended KMP: ext[i] = LCP of a[i..] with b. Requires qnxt(b) to have run.
void exkmp(char *a, char *b)
{
	const int la = strlen(a), lb = strlen(b);
	int first = 0;
	while (first < la && first < lb && a[first] == b[first]) ++first;
	ext[0] = first;  // seed value for the recurrence
	int k = 0;
	for (int i = 1; i < la; ++i)  // same window logic as qnxt
	{
		const int right = k + ext[k] - 1;
		const int mirrored = nxt[i - k];
		if (i + mirrored <= right)
		{
			ext[i] = mirrored;
			continue;
		}
		int j = right - i + 1;
		if (j < 0) j = 0;
		while (i + j < la && j < lb && a[i + j] == b[j]) ++j;
		ext[i] = j;
		k = i;
	}
}
int la, lb;
char a[N], b[N];
ll ans;
/// Reads text a and pattern b, then prints two checksums: the xor over 1-based
/// positions pos of pos * (nxt[pos-1] + 1) for the pattern, and the same with
/// ext[] for the text.
int main()
{
	cin.tie(0); cout.tie(0);
	ios::sync_with_stdio(0);
	cin >> a >> b;
	qnxt(b);      // Z-function of the pattern
	exkmp(a, b);  // LCP of every suffix of the text with the pattern
	la = strlen(a), lb = strlen(b), ans = 0;
	for (int pos = 1; pos <= lb; ++pos)
		ans ^= pos * (nxt[pos - 1] + 1);
	cout << ans << "\n";
	ans = 0;
	for (int pos = 1; pos <= la; ++pos)
		ans ^= pos * (ext[pos - 1] + 1);
	cout << ans;
	return 0;
}

manacher

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include<bits/stdc++.h>
using namespace std;
// `register` is ill-formed since C++17 (clang rejects it), so the shorthands are
// now plain types; behaviour is unchanged.
#define rint int
#define ll long long
#define rll long long
const int N=3e7+10;
char s1[N],s2[N];
int n,m,f[N];
/// Manacher's algorithm: prints the length of the longest palindromic substring.
int main(){
	scanf("%s",s1+1);
	n=strlen(s1+1);
	// s2 = "~#c1#c2#...#cn#". The '#' separators make every palindrome odd-length.
	// s2[0] = '~' stops expansion on the left; the '\0' after s2[m] stops it on the right.
	s2[m]='~';s2[++m]='#';
	for(rint i=1;i<=n;i++)
		s2[++m]=s1[i],s2[++m]='#';
	rint res=0;
	// f[i]: palindrome radius in s2, counting the centre. [2*mid-r, r] is the
	// palindrome reaching farthest right so far.
	for(rint i=1,mid=0,r=0;i<=m;i++){
		if(i<=r)f[i]=min(r-i+1,f[2*mid-i]);
		while(s2[i+f[i]]==s2[i-f[i]])f[i]++;
		if(i+f[i]-1>r)r=i+f[i]-1,mid=i;
		// radius f[i] in s2 corresponds to a palindrome of length f[i]-1 in s1
		res=max(res,f[i]-1);
	}
	printf("%d\n",res);
	return 0;
}

ac自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include<cstdio>
#include<queue>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e6+10;
// Trie[u][c]: child, completed into a goto function by bfs(); val[u]: number of
// patterns ending at u (-1 once counted by query); fail[u]: suffix link. Node 0 is the root.
int Trie[N][26],cnt,val[N],fail[N];
char s[N];
/// Inserts a lowercase pattern into the trie.
void Ins(char s[]){
	int node=0;
	for(const char *it=s;*it;++it){
		int &child=Trie[node][*it-'a'];
		if(!child)child=++cnt;
		node=child;
	}
	val[node]++;
}
/// Builds the fail links breadth-first and fills missing transitions from the fail target.
void bfs(){
	std::queue<int> order;
	for(int c=0;c<26;c++)
		if(Trie[0][c])order.push(Trie[0][c]);
	while(!order.empty()){
		const int u=order.front();
		order.pop();
		for(int c=0;c<26;c++){
			int &v=Trie[u][c];
			const int via=Trie[fail[u]][c];
			if(v){
				fail[v]=via;
				order.push(v);
			}else{
				v=via;
			}
		}
	}
}
/// Counts how many inserted patterns occur in s. Each trie node is counted at
/// most once: visited nodes get val = -1, so the automaton is consumed.
int query(char s[]){
	int node=0,total=0;
	for(const char *it=s;*it;++it){
		node=Trie[node][*it-'a'];
		for(int j=node;j&&val[j]!=-1;j=fail[j]){
			total+=val[j];
			val[j]=-1;
		}
	}
	return total;
}
/// Reads n patterns and a text; prints how many of the patterns occur in the text.
int main(){
	int n;
	scanf("%d",&n);
	for(int left=n;left>0;--left){
		scanf("%s",s);
		Ins(s);
	}
	bfs();
	scanf("%s",s);
	printf("%d\n",query(s));
	return 0;
}

回文自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=5e5+10;
// s[1..n]: characters mapped to 0..25; s[0] = -1 is a sentinel (set in main).
char s[N];
// Palindromic tree: node 0 = even root (len 0), node 1 = odd root (len -1).
// siz[u] = siz[fail[u]] + 1: the number of palindromic suffixes of u's palindrome.
int ch[N][26],siz[N],fail[N],len[N],lst,cnt;
/// Follows fail links from x until the palindrome at x can be extended by s[n].
int get(int x,int n){
	for(;s[n]!=s[n-len[x]-1];x=fail[x]);
	return x;
}
/// Appends character c (= s[n]) and moves lst to the longest palindromic suffix.
void extend(int n,int c){
	const int par=get(lst,n);
	if(ch[par][c]==0){
		const int node=++cnt;
		// the fail link is computed before ch[par][c] is set
		fail[node]=ch[get(fail[par],n)][c];
		len[node]=len[par]+2;
		ch[par][c]=node;
		siz[node]=siz[fail[node]]+1;
	}
	lst=ch[par][c];
}
int main(){
scanf("%s",s+1);
int n=strlen(s+1);
cnt++;len[1]=-1;fail[0]=1;s[0]=-1;
for(int i=1,res=0;i<=n;i++){
s[i]=(s[i]-97+res)%26;
extend(i,s[i]);
printf("%d ",res=siz[lst]);
}
}

后缀数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=6e5+10;
typedef long long ll;
char s[N];
// rk: rank of each suffix; tp: positions in second-key order (also scratch);
// sa: suffix array; t: counting-sort buckets; ht: height array.
int rk[N],tp[N],sa[N],n,m,t[N],ht[N];
/// Stable counting sort of the positions listed in tp[1..n] by rk (keys 1..m) into sa.
void suf(){
	std::fill(t+1,t+m+1,0);
	for(int i=1;i<=n;++i)++t[rk[i]];
	for(int v=1;v<=m;++v)t[v]+=t[v-1];
	for(int i=n;i>=1;--i){
		const int pos=tp[i];
		sa[t[rk[pos]]--]=pos;
	}
}
// a: values attached to positions; stmx/stmi: sparse tables of a[sa[i]] indexed by rank i.
ll a[N],stmx[N][25],stmi[N][25];
// b[len] = floor(log2(len)); stk/top: monotonic stack; L/R: ranges it produces.
int b[N],stk[N],top,L[N],R[N];
/// Maximum of a[sa[l..r]] from two overlapping sparse-table windows.
ll qmax(int l,int r){
	const int lg=b[r-l+1];
	const ll lhs=stmx[l][lg],rhs=stmx[r-(1<<lg)+1][lg];
	return lhs<rhs?rhs:lhs;
}
/// Minimum of a[sa[l..r]] from two overlapping sparse-table windows.
ll qmin(int l,int r){
	const int lg=b[r-l+1];
	const ll lhs=stmi[l][lg],rhs=stmi[r-(1<<lg)+1][lg];
	return rhs<lhs?rhs:lhs;
}
ll ans1[N],ans2[N];
// Reads n, a lowercase string s and values a[1..n]. For every r in [0, n) it prints
// ans1[r] = number of suffix pairs whose LCP is at least r, and ans2[r] = the
// largest product a_i * a_j over those pairs (0 if there are none).
// NOTE(review): this matches the NOI2015 "tasting" problem — confirm.
int main(){
scanf("%d%s",&n,s+1);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
rk[i]=s[i]-'a'+1;tp[i]=i;
}
m=26;
suf();
// prefix doubling: w = current key length, p = number of distinct ranks
for(int w=1,p=0;p!=n;m=p,w<<=1){
p=0;
for(int i=n-w+1;i<=n;i++)tp[++p]=i;
for(int i=1;i<=n;i++)
if(sa[i]>w)tp[++p]=sa[i]-w;
suf();
memcpy(tp,rk,sizeof(rk));
rk[sa[1]]=p=1;
for(int i=2;i<=n;i++)
rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+w]==tp[sa[i-1]+w])?p:++p;
}
// height array (Kasai); also seeds the sparse tables with a[sa[i]]
for(int i=1,j,k=0;i<=n;i++){
if(k)--k;
j=sa[rk[i]-1];
while(s[i+k]==s[j+k])k++;
ht[rk[i]]=k;
stmx[i][0]=stmi[i][0]=a[sa[i]];
}
for(int j=1;(1<<j)<=n;j++){
b[1<<j]=j;
for(int i=1;i+(1<<j)-1<=n;i++){
stmx[i][j]=max(stmx[i][j-1],stmx[i+(1<<j-1)][j-1]);
stmi[i][j]=min(stmi[i][j-1],stmi[i+(1<<j-1)][j-1]);
}
}
for(int i=1;i<=n;i++)if(!b[i])b[i]=b[i-1];
// monotonic stacks: [L[i], R[i]] is the range of heights for which ht[i] is the
// minimum (ties go to one side, so every pair is counted exactly once)
for(int i=2;i<=n;i++){
while(top&&ht[stk[top]]>ht[i])R[stk[top]]=i-1,top--;
stk[++top]=i;
}
while(top)R[stk[top]]=n,top--;
for(int i=n;i>1;i--){
while(top&&ht[stk[top]]>=ht[i])L[stk[top]]=i+1,top--;
stk[++top]=i;
}
memset(ans2,-0x3f,sizeof(ans2));
while(top)L[stk[top]]=2,top--;
// pairs between ranks [L[i]-1, i-1] and [i, R[i]] have LCP exactly ht[i].
// ans1 uses a difference array; ans2 takes max*max or min*min (values may be negative)
for(int i=2;i<=n;i++){
ll ls=i-L[i]+1,rs=R[i]-i+1;
ans1[0]+=ls*rs;
ans1[ht[i]+1]-=ls*rs;
ls=qmax(L[i]-1,i-1);rs=qmax(i,R[i]);
ans2[ht[i]]=max(ans2[ht[i]],ls*rs);
ls=qmin(L[i]-1,i-1);rs=qmin(i,R[i]);
ans2[ht[i]]=max(ans2[ht[i]],ls*rs);
}
for(int i=1;i<=n;i++)ans1[i]+=ans1[i-1];
// a pair with LCP >= r+1 also has LCP >= r: propagate maxima downward
for(int i=n;~i;i--){
if(!ans1[i])ans2[i]=0;
if(ans1[i+1])ans2[i]=max(ans2[i],ans2[i+1]);
}
for(int i=0;i<n;i++)printf("%lld %lld\n",ans1[i],ans2[i]);
return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include<cstdio>
#include<cstring>
#include<algorithm>
const int MAX_LEN = 1e6 + 5;
const int SIGMA = 128;
// Suffix array by induced sorting (SA-IS), linear time.
// The `register` specifiers of the original were dropped: the keyword is
// ill-formed since C++17 and never affected the algorithm.
struct SA{
	// The recursive calls keep their reduced strings, types and LMS lists past the
	// current ones (str + len, type + len, LMS + LMSC), hence the doubled capacity.
	int str[MAX_LEN << 1];
	int type[MAX_LEN << 1];   // 1 = S-type, 0 = L-type
	int LMS[MAX_LEN << 1];    // positions of the LMS characters, in text order
	int sa[MAX_LEN];
	int rank[MAX_LEN];        // index of position i in LMS, or -1 if i is not LMS
	int cnt[MAX_LEN];         // bucket boundaries (prefix counts)
	int cur[MAX_LEN];         // current insertion point of each bucket
	// Induced sort: place the given LMS positions at the ends of their buckets, then
	// induce L-type suffixes left to right and S-type suffixes right to left.
	void inducedsort(int *str, int len, int sigma, int *LMS, int LMSC, int *type) {
#define PUSH_S(x) (sa[cur[str[x]]--] = x)
#define PUSH_L(x) (sa[cur[str[x]]++] = x)
		std::fill_n(sa, len, -1), std::fill_n(cnt, sigma, 0);
		for(int i = 0; i < len; ++i) ++cnt[str[i]];
		for(int i = 1; i < sigma; ++i) cnt[i] += cnt[i - 1];
		for(int i = 0; i < sigma; ++i) cur[i] = cnt[i] - 1;
		for(int i = LMSC - 1; i >= 0; --i) PUSH_S(LMS[i]);
		for(int i = 1; i < sigma; ++i) cur[i] = cnt[i - 1];
		for(int i = 0; i < len; ++i) sa[i] > 0 && type[sa[i] - 1] == 0 && PUSH_L(sa[i] - 1);
		for(int i = 0; i < sigma; ++i) cur[i] = cnt[i] - 1;
		for(int i = len - 1; i >= 0; --i) sa[i] > 0 && type[sa[i] - 1] && PUSH_S(sa[i] - 1);
#undef PUSH_S
#undef PUSH_L
	}
	// Sorts the suffixes of str[0..len); str[len-1] must be a unique smallest symbol.
	void sais(int *str, int len, int sigma, int *LMS, int *type) {
		type[len - 1] = 1;
		for(int i = len - 2; i >= 0; --i) type[i] = (str[i] == str[i + 1] ? type[i + 1] : str[i] < str[i + 1]);
		int LMSC = 0;
		rank[0] = -1;
		for(int i = 1; i < len; ++i) rank[i] = (type[i] && !type[i - 1]) ? (LMS[LMSC] = i, LMSC++) : -1;
		inducedsort(str, len, sigma, LMS, LMSC, type);
		// Name the LMS substrings in sorted order; equal substrings share a name.
		int tot = -1;
		int *s1 = str + len;
		for(int i = 0, now, last; i < len; ++i) {
			if(-1 == (now = rank[sa[i]])) continue;
			if(tot < 1 || LMS[now + 1] - LMS[now] != LMS[last + 1] - LMS[last]) ++tot;
			else for(int j = LMS[now], k = LMS[last]; j < LMS[now + 1]; ++j, ++k) if((str[j] << 1 | type[j]) != (str[k] << 1 | type[k])) {
				++tot;
				break;
			}
			s1[last = now] = tot;
		}
		// Recurse only if some names repeat; otherwise the names are already the order.
		if(tot + 1 < LMSC) sais(s1, LMSC, tot + 1, LMS + LMSC, type + len);
		else for(int i = 0; i < LMSC; ++i) sa[s1[i]] = i;
		for(int i = 0; i < LMSC; ++i) s1[i] = LMS[sa[i]];
		inducedsort(str, len, sigma, s1, LMSC, type);
	}
	// Writes the suffix array of s (1-based starting positions) into ans[1..strlen(s)].
	// NOTE(review): characters are used as bucket indices below SIGMA, so s is
	// assumed to be 7-bit ASCII.
	void Getsa(char *s,int *ans){
		int len = strlen(s);
		for(int i = 0; i < len; ++i) {
			str[i] = s[i];
		}
		// Sentinel: reset explicitly. A previous, longer input (or recursion scratch)
		// may have left a non-zero value here, and SA-IS needs a unique minimum.
		str[len] = 0;
		sais(str, len + 1, SIGMA, LMS, type);
		// sa[0] is the sentinel suffix; shift the rest to 1-based positions
		for(int i = 0; i < len; ++i) ans[i+1]=sa[i + 1] + 1;
	}
}S;
const int N=1e6+10;
char s[N];
int sa[N];
/// Reads a string and prints its suffix array as 1-based starting positions.
int main(){
	scanf("%s",s+1);
	const int n=strlen(s+1);
	S.Getsa(s+1,sa);
	int rankPos=1;
	while(rankPos<=n)printf("%d ",sa[rankPos++]);
}

树上后缀数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<iostream>
#include<vector>
#include<utility>
#include<algorithm>
using namespace std;
typedef long long LL;
// base/modulus pairs of the double hash
const int N=5e5+5,base1=20041001,base2=20040607,md1=167772161,md2=104857601;
// F[k][v]: 2^k-th ancestor of v (0 if none); dep: depth (root = 1);
// x/y: rank arrays of the doubling sort; c: counting buckets; sa: resulting order;
// hx1/hx2: hashes of the characters on the root-to-v path.
int n,F[19][N],head[N],cnt,dep[N],x[N],y[N],c[N],sa[N],hx1[N],hx2[N];
char s[N];
// vec[v]: nodes whose 2^k-th ancestor is v (rebuilt each round); vd[d]: nodes at depth d;
// pd[d]: [first, last] ranges in sa of equal-hash groups whose last member has depth d
vector<int>vec[N],vd[N];
vector<pair<int,int>>pd[N];
struct edge{
int to,nxt;
}e[N];
// Fills the path hashes, depths and parents for the subtree of `now`.
// NOTE(review): recursion depth equals the tree height (up to n); a long chain may
// need a larger stack.
void dfs(int now){
hx1[now]=((LL)hx1[F[0][now]]*base1+s[now])%md1,hx2[now]=((LL)hx2[F[0][now]]*base2+s[now])%md2;
for(int i=head[now];i;i=e[i].nxt)dep[e[i].to]=dep[now]+1,F[0][e[i].to]=now,dfs(e[i].to);
}
// Prefix-doubling sort of the node strings s[v], s[parent(v)], ..., s[root].
// After round k, x[v] ranks nodes by their first 2^(k+1) characters. Nodes with
// equal strings keep equal ranks here; main() breaks those ties afterwards.
void ssort(){
int m=256;
// round 0: counting sort by the node's own character
for(int i=1;i<=m;++i)c[i]=0;
for(int i=1;i<=n;++i)++c[x[i]=s[i]];
for(int i=1;i<=m;++i)c[i]+=c[i-1];
for(int i=n;i;--i)sa[c[x[i]]--]=i;
for(int k=0;1<<k<=n;++k){
// group nodes by their 2^k-th ancestor
for(int i=1;i<=n;++i)vec[i].clear();
for(int i=1;i<=n;++i)if(F[k][i])vec[F[k][i]].push_back(i);
// second-key order: nodes without such an ancestor (depth <= 2^k) first,
// then the others in the current order of their ancestor
int p=0;
for(int i=1;i<=1<<k;++i)
for(int j:vd[i])y[++p]=j;
for(int i=1;i<=n;++i)
for(int j:vec[sa[i]])y[++p]=j;
// stable counting sort by the first key
for(int i=1;i<=m;++i)c[i]=0;
for(int i=1;i<=n;++i)++c[x[i]];
for(int i=1;i<=m;++i)c[i]+=c[i-1];
for(int i=n;i;--i)sa[c[x[y[i]]]--]=y[i];
// re-rank: equal only if both halves are equal
std::swap(x,y);
x[sa[1]]=p=1;
for(int i=2;i<=n;++i)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[F[k][sa[i]]]==y[F[k][sa[i-1]]]?p:++p);
if(p==n)break;
m=p;
}
}
inline int cmp(int a,int b){return F[0][a]!=F[0][b]?x[F[0][a]]<x[F[0][b]]:a<b;}
// Reads a rooted tree (parents of nodes 2..n; node 1 is the root) and node characters.
// Prints the nodes sorted by their upward strings; equal strings are ordered by
// the parent's rank, then by node id.
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n;
for(int i=2;i<=n;++i){
int&f=F[0][i];
cin>>f;
e[++cnt]=(edge){i,head[f]},head[f]=cnt;
}
cin>>(s+1);
dfs(dep[1]=1);
for(int i=1;i<=n;++i)vd[dep[i]].push_back(i);
// binary-lifting table
for(int i=1;i<19;++i)
for(int j=1;j<=n;++j)
F[i][j]=F[i-1][F[i-1][j]];
ssort();
// split sa into runs of equal strings (equal double hash), keyed by depth
int lst=1;
for(int i=2;i<=n+1;++i)
if(hx1[sa[i]]!=hx1[sa[i-1]]||hx2[sa[i]]!=hx2[sa[i-1]])pd[dep[sa[i-1]]].emplace_back(lst,i-1),lst=i;
// process by increasing depth so parent ranks are already final; sort each run
// with cmp and write back final ranks
for(int i=1;i<=n;++i)
for(auto j:pd[i]){
sort(sa+j.first,sa+j.second+1,cmp);
for(int k=j.first;k<=j.second;++k)
x[sa[k]]=k;
}
for(int i=1;i<=n;++i)cout<<sa[i]<<' ';
return 0;
}

后缀自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// `register` is ill-formed since C++17 (clang rejects it), so the shorthand is
// now a plain int; behaviour is unchanged.
#define rint int
const int N=2e6+10;
char s[N];
// Suffix automaton (state 1 is the initial state): ch = transitions, len = length
// of the longest string in a state, fa = suffix link. siz is set to 1 for every
// non-clone state here; main accumulates it along suffix links.
int ch[N][26],len[N],fa[N],lst,cnt;
long long siz[N];
/// Appends character c (0..25) to the automaton.
void extend(rint c){
	rint p=lst;rint np=lst=++cnt;
	len[np]=len[p]+1;siz[np]=1;
	for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
	if(!p)fa[np]=1;
	else {
		rint t=ch[p][c];
		if(len[t]==len[p]+1)fa[np]=t;
		else {
			// split t: the clone takes the shorter strings and t's transitions
			rint ct=++cnt;
			memcpy(ch[ct],ch[t],sizeof(ch[t]));
			fa[ct]=fa[t];fa[t]=fa[np]=ct;
			len[ct]=len[p]+1;
			for(;p&&ch[p][c]==t;p=fa[p])ch[p][c]=ct;
		}
	}
}
int a[N],t[N];
/// Builds the suffix automaton of s and prints the maximum of |endpos| * len
/// over states that occur more than once.
int main(){
	scanf("%s",s+1);
	const int n=strlen(s+1);
	lst=++cnt;
	for(int i=1;i<=n;i++)extend(s[i]-'a');
	// counting sort of states by len; scanning backwards visits children before parents
	for(int i=1;i<=cnt;i++)++t[len[i]];
	for(int i=1;i<=n;i++)t[i]+=t[i-1];
	for(int i=1;i<=cnt;i++)a[t[len[i]]--]=i;
	long long best=0;
	for(int i=cnt;i>=1;i--){
		const int u=a[i];
		siz[fa[u]]+=siz[u];
		if(siz[u]>1&&siz[u]*len[u]>best)best=siz[u]*len[u];
	}
	printf("%lld\n",best);
	return 0;
}

后缀平衡树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/stdc++.h>
using namespace std;
const int N=8e5+5,M=3e6+6;
// buffered stdin reader: gc yields the next byte, or EOF
char buf[M+5],*p1,*p2,c;
#define gc (p1==p2&&(p2=(p1=buf)+fread(buf,1,M,stdin),p1==p2)?EOF:*p1++)
// Reads a non-negative integer.
inline void read(int &x){
for(c=gc;c<'0'||c>'9';c=gc);
for(x=0;c>='0'&&c<='9';x=x*10+c-'0',c=gc);
}
// s[1..tot]: the maintained string; st: command / query buffer (1-based).
// t: children (node id = string position); sz: subtree sizes; rb/bt: rebuild scratch;
// nrb: slot of the subtree picked for rebuilding; val: node being inserted or deleted.
// d[x] = {L, mid, R}: the tag interval of node x; d[x][1] is x's order tag.
char s[N],st[M];
int t[N][2],Q,n,tot,mask,ans,rt,*nrb,sz[N],val,rb[N],bt;
typedef long double ld;
ld d[N][3],lv,rv;
// Reads a run of letters into s[1..n] and sets the global n.
inline void read(char *s){
while(!isalpha(c=gc));n=0;
do s[++n]=c;while(isalpha(c=gc));
}
/// Unscrambles st[0..n) in place, where n is the global length. For each i,
/// mask <- (mask * 131 + i) mod n and st[i] is swapped with st[mask].
inline void decode(char *st,int mask){
	int i=0;
	while(i<n){
		mask=(mask*131+i)%n;
		std::swap(st[i],st[mask]);
		++i;
	}
}
// Appends the nodes of subtree x to rb[] in in-order.
void Rem(int x){if(t[x][0])Rem(t[x][0]);rb[++bt]=x;if(t[x][1])Rem(t[x][1]);}
// Builds a perfectly balanced tree from rb[l..r], reassigning tag intervals inside (L, R).
int Mak(int l,int r,ld L,ld R){
if(l>r)return 0;int md=l+r>>1,x=rb[md];ld Md=(L+R)/2;
t[x][0]=Mak(l,md-1,L,Md),t[x][1]=Mak(md+1,r,Md,R);
d[x][0]=L,d[x][1]=Md,d[x][2]=R,sz[x]=sz[t[x][0]]+sz[t[x][1]]+1;
return x;
}
// Rebuilds the subtree in slot x in place, keeping its outer tag interval.
void Reb(int &x){bt=0,Rem(x),x=Mak(1,bt,d[x][0],d[x][2]);}
/// Inserts node `val` (a global) under slot x, whose tag interval is (L, R).
/// Position y stands for the string s[y], s[y-1], ..., s[1]: nodes are compared by
/// their first character, then by the tag of position y-1.
/// Each node on the path is picked for rebuilding (nrb) with probability about 1/sz.
void Ins(int &x,ld L,ld R){
	if(!x){
		x=val,t[x][0]=t[x][1]=0,sz[x]=1;
		d[x][0]=L,d[x][1]=(L+R)/2,d[x][2]=R;
	}else{
		int k=s[x]==s[val]?d[x-1][1]<d[val-1][1]:s[x]<s[val];
		Ins(t[x][k],d[x][k],d[x][k+1]),++sz[x];
		// rand()+rand() can exceed INT_MAX (RAND_MAX is 2^31-1 on glibc), which is
		// signed overflow. Summing as unsigned gives the same value whenever the
		// int sum would have fit, and stays well defined otherwise.
		if((1u*rand()+rand())%sz[x]==0)nrb=&x;
	}
}
// Removes node `val` (a global) from the subtree in slot x, finding it by its tag
// d[val][1]. If the node has a right subtree, its in-order successor takes its
// place, inheriting its children and tag interval. Otherwise its left child
// replaces it. val is 0 once the removal is complete.
void Del(int &x){
int k=d[val][1]>=d[x][1],&y=t[x][k],&z=t[x][!k];
if(y){
// on return from the right subtree, a non-zero val is the detached successor
Del(y);if(k&&val){
t[val][0]=t[x][0],t[val][1]=t[x][1];
d[val][0]=d[x][0],d[val][1]=d[x][1];
d[val][2]=d[x][2],x=val,val=0;
}sz[x]=sz[y]+sz[z]+1;
}else {val=(x==val)?0:x;x=z;}
return;
}
// Inserts position x, then rebuilds the subtree picked during insertion (if any).
void Ins(int x){val=x,nrb=NULL,Ins(rt,1,1e18);if(nrb)Reb(*nrb);}
// Counts the nodes in subtree x whose string s[y], s[y-1], ..., s[1] compares below
// the query st[n], st[n-1], ..., st[1]. A node string that runs out first counts as
// smaller; one that has the whole query as a prefix does not.
int lwb(int x){
if(!x)return 0;int i,k=0;
for(i=0;i<n;++i)
if(i==x){k=1;break;}
else if(s[x-i]!=st[n-i]){k=s[x-i]<st[n-i];break;}
return k?sz[t[x][0]]+1+lwb(t[x][1]):lwb(t[x][0]);
}
// Driver: reads Q and the initial string, inserts every position, then
// processes Q commands identified by the first letter of the command word.
int main(){
srand(1919810u^time(0));
// read(s) also sets n to the initial length. NOTE(review): j and k are unused.
read(Q),read(s);int i,j,k;
for(i=1;i<=n;++i)Ins(i);tot=n;
while(Q--){
// Command word; st[1] is its first letter (n is overwritten with its length).
read(st);
if(st[1]=='A'){
// Append: decode the string with the current mask, then insert each new position.
read(st),decode(st+1,mask);
for(i=1;i<=n;++i)s[++tot]=st[i],Ins(tot);
}else if(st[1]=='Q'){
// Query: lwb counts nodes whose reversed prefix is below the reversed query;
// bumping st[1] (the last compared character) gives the upper bound, so the
// difference counts positions at which the query string ends.
read(st),decode(st+1,mask),ans=lwb(rt),++st[1];
ans=lwb(rt)-ans;printf("%d\n",ans),mask^=ans;
// Otherwise: read a count into n and delete that many trailing positions.
}else for(i=1,read(n);i<=n;++i)val=tot--,Del(rt);
}
return 0;
}

杂项

LGV引理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=410;
const int p=998244353;
typedef long long ll;
ll a[N][N],fac[2100000],inv[2100000];
// Modular exponentiation: a^b mod p.
ll fpow(ll a,ll b){
    ll res=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)res=res*a%p;
    return res;
}
// Binomial coefficient C(n,m) mod p from the factorial (fac) and inverse
// factorial (inv) tables; 0 when n<m.
ll C(int n,int m){
    if(n<m)return 0;
    return fac[n]*inv[m]%p*inv[n-m]%p;
}
int x[N],y[N];
// Determinant of a[1..n][1..n] modulo p by Gaussian elimination.
// Overwrites a; the result is in [0,p).
ll calc(int n){
    ll res=1;
    for(int i=1;i<=n;i++){
        if(a[i][i]==0){
            int j=i+1;
            while(j<=n&&!a[j][i])j++;
            if(j>n)return 0;          // the column has no pivot: singular
            std::swap(a[i],a[j]);
            res=-res;                 // a row swap flips the sign
        }
        // The pivot inverse is loop-invariant for this column; the original
        // recomputed fpow(a[i][i],p-2) for every row below the pivot.
        const ll pinv=fpow(a[i][i],p-2);
        for(int j=i+1;j<=n;j++){
            if(!a[j][i])continue;     // nothing to eliminate in this row
            ll r=a[j][i]*pinv%p;
            for(int k=i;k<=n;k++)
                a[j][k]=(a[j][k]-a[i][k]*r%p+p)%p;
        }
    }
    for(int i=1;i<=n;i++)res=res*a[i][i]%p;
    return (res+p)%p;
}
int main(){
    // Factorials and inverse factorials up to 2e6 for C(n,m).
    const int LIM=2000000;
    fac[0]=fac[1]=inv[0]=inv[1]=1;
    for(int i=2;i<=LIM;++i){
        fac[i]=(ll)fac[i-1]*i%p;
        inv[i]=p-(ll)p/i*inv[p%i]%p;    // inverse of i from p = (p/i)*i + p%i
    }
    for(int i=2;i<=LIM;++i)
        inv[i]=inv[i-1]*inv[i]%p;       // prefix products: inverse factorials
    int T;
    scanf("%d",&T);
    while(T--){
        int n,m;
        scanf("%d%d",&n,&m);
        for(int i=1;i<=m;++i)scanf("%d%d",&x[i],&y[i]);
        // LGV matrix: entry (i,j) is the binomial count for start i and end j
        // (presumably the number of monotone lattice paths between them).
        for(int i=1;i<=m;++i)
            for(int j=1;j<=m;++j)
                a[i][j]=C(y[j]-x[i]+n-1,n-1);
        printf("%lld\n",calc(m));
    }
    return 0;
}

树哈希

$h(x)=1+\sum_{v\in son(x)}f(h(v))$,其中 $f$ 对应下面代码中的 `f`(代码里的 `h` 只是 `f` 内部使用的混合函数,与公式中的 $h$ 不是同一个)。

1
2
3
4
5
6
7
// Mixing step for the tree hash. The arithmetic is done in unsigned 64-bit,
// so the intended wrap-around is well defined: the original computed
// x*x*x*1237123 in signed long long, which overflows (undefined behaviour).
// Signatures use `long long`, the same type as `ll`.
long long h(long long x) {
    const unsigned long long u = x;
    return (long long)(u * u * u * 1237123ULL + 19260817ULL);
}
// Hash mapping used in h(x) = 1 + sum f(h(child)): mixes the low 31 bits and
// the remaining high bits of x separately and adds the results.
long long f(long long x) {
    // The original wrote (1 << 31) - 1, which overflows int.
    const long long low_mask = (1LL << 31) - 1;
    const unsigned long long lo = h(x & low_mask);
    const unsigned long long hi = h(x >> 31);
    return (long long)(lo + hi);   // sum in unsigned to avoid signed overflow
}

spfa优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include<ctime>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define reg register
#define uni unsigned
using namespace std;
const uni N=100010,W=5,WT=10,INF=0x7fffffff; //W为取首位若干项数,WT为随机次数
uni n,m,s,d[N<<4],a[N],b[N<<2][3],t[N],seed=time(0),w[N],OnF,st[10];
bool vis[N];
char BuF[1<<24],*InF=BuF,WuF[1<<23];
// Parse the next unsigned decimal number from the in-memory input BuF,
// skipping any non-digit characters in front of it.
// Fixes:
// - The original digit bounds were 47/58, so '/' and ':' stopped the skip
//   loop and were then returned as 0 without advancing, stalling every later
//   read on the same character.
// - The skip loop now stops at a zero byte. BuF is a zero-filled global, so
//   this ends the scan after the data instead of running past the buffer.
uni read(){
    uni x=0;
    while(*InF&&(*InF<'0'||*InF>'9'))++InF;
    while(*InF>='0'&&*InF<='9')x=x*10+(*InF++^48);
    return(x);
}
// Append x in decimal, followed by one space, to the output buffer WuF.
// st[] is a digit stack whose size lives in st[0].
void write(uni x){
    if(x==0)WuF[OnF++]=48;
    while(x){
        st[++st[0]]=x%10+48;
        x/=10;
    }
    while(st[0])WuF[OnF++]=st[st[0]--];
    WuF[OnF++]=' ';
}
// Advance the xorshift state in the global seed and map it into [l, r].
uni rnd(uni l,uni r){
    seed^=seed<<17;
    seed^=seed>>15;
    seed^=seed<<3;
    const uni span=r-l+1;
    return l+seed%span;
}
// SPFA from s over the adjacency lists a[]/b[] (b[e] = {next edge, target,
// weight}), with heuristics that try to bring smaller-distance vertices to the
// queue head. Distances go to w[]; unreached vertices keep the memset value
// 0x7f7f7f7f. The queue is the plain array d[] with head h and tail t (the
// local t shadows the global array t[]); it never wraps around.
// NOTE(review): no bound check against d's capacity; vis[s] is never set.
void spfa(uni s){
memset(w,127,sizeof(w));
w[d[0]=s]=0;
for(reg uni h=0,t=0;h<=t;++h){
// Up to WT+1 random probes strictly inside the queue (away from both ends).
for(reg uni i=0;i<=WT&&h+W+W+1<t;++i){
reg uni x=rnd(h+W+1,t-W-1);
if(w[d[h]]>w[d[x]]){
swap(d[h],d[x]); //random compare-and-swap with the head
}
}
// Compare the head with the first and last W+1 queued vertices.
for(reg uni i=0;i<=W&&h+i+i<t;++i){
if(w[d[h]]>w[d[h+i+1]]){
swap(d[h],d[h+i+1]); //compare with elements near the queue front
}
if(w[d[h]]>w[d[t-i]]){
swap(d[h],d[t-i]); //compare with elements near the queue tail
}
}
// Relax the out-edges of the head; mi tracks the smallest distance pushed in this pass.
for(reg uni i=a[d[h]],mi=INF;i;i=b[i][0]){
reg uni nxt=b[i][1];
if(w[nxt]>w[d[h]]+b[i][2]){
w[nxt]=w[d[h]]+b[i][2];
if(!vis[nxt]){
d[++t]=nxt;
if(mi>w[nxt]){
if(mi<INF){
swap(d[t],d[t-1]); //intended: keep this pass's minimum at the tail (NOTE(review): this moves the new vertex to slot t-1)
}
mi=w[nxt];
}
vis[nxt]=1;
}
}
}
// The head vertex leaves the queue.
vis[d[h]]=0;
}
}
int main(){
fread(BuF,1,1<<24,stdin);
n=read();m=read();s=read();
for(reg uni last=0;m;--m){
reg uni x=read(),y=read(),z=read();
b[++last][0]=a[x];
b[a[x]=last][1]=y;
b[last][2]=z;
}
spfa(s);
for(reg uni *i=w+1,*r=w+n+1;i!=r;++i){
write(*i);
}
fwrite(WuF,1,OnF,stdout);
fclose(stdin);
fclose(stdout);
return(0);
}

2-sat

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define reset(a) memset((a),0,sizeof((a)))
using namespace std;
static char buf[100000],*pa,*pd;
#define gc pa==pd&&(pd=(pa=buf)+fread(buf,1,100000,stdin),pa==pd)?EOF:*pa++
inline int read(){
    // Skip to the first digit, then accumulate a non-negative integer.
    char c=gc;
    while(c<'0'||c>'9')c=gc;
    int x=0;
    for(;c>='0'&&c<='9';c=gc)x=x*10+c-48;
    return x;
}
const int N=4001000;
struct edge{
int to,next;
}e[N];
int head[N],tot,HEAD[N];
int n,m,cnt,turn[N],belong[N],vis[N];
// Add x->y to the forward graph (head[]) and y->x to the reverse graph
// (HEAD[]); both lists share the edge pool e[].
void add(int x,int y){
    ++tot;
    e[tot].to=y;
    e[tot].next=head[x];
    head[x]=tot;
    ++tot;
    e[tot].to=x;
    e[tot].next=HEAD[y];
    HEAD[y]=tot;
}
void dfs1(int u){
int i;
vis[u]=1;
for(i=head[u];i;i=e[i].next)
if(!vis[e[i].to])
dfs1(e[i].to);
turn[++cnt]=u;
}
void dfs2(int u){
belong[u]=cnt;vis[u]=1;
for(int i=HEAD[u];i;i=e[i].next)
if(!vis[e[i].to])
dfs2(e[i].to);
}
void kosaraju(){
for(int i=1;i<=2*n;i++)
if(!vis[i])dfs1(i);
reset(vis);cnt=0;
for(int i=2*n;i>=1;i--)
if(!vis[turn[i]])
cnt++,dfs2(turn[i]);
}
int main(){
    n=read();
    m=read();
    for(int k=1;k<=m;++k){
        int x=read();
        int f1=read();
        int y=read();
        int f2=read();
        // Two implications per clause; vertex v stands for x_v=1 and v+n for
        // x_v=0 (as the edge endpoints below encode it).
        add(x+n*(f1&1),y+n*(f2^1));
        add(y+n*(f2&1),x+n*(f1^1));
    }
    kosaraju();
    for(int i=1;i<=n;++i){
        if(belong[i]==belong[i+n]){
            cout<<"IMPOSSIBLE";
            return 0;
        }
    }
    cout<<"POSSIBLE\n";
    for(int i=1;i<=n;++i)cout<<(belong[i]>belong[i+n])<<' ';
    return 0;
}

差分约束

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
const int N=5e3+10;
struct Edge{
    int to,nxt,val;
}e[N<<1];
int h[N],idx;
// Add directed edge x->y with weight z to the forward-star list.
// (`register` parameters are ill-formed since C++17, so plain int is used.)
void Ins(int x,int y,int z){
    e[++idx].to=y;e[idx].val=z;
    e[idx].nxt=h[x];h[x]=idx;
}
int dis[N],cnt[N],n,m;
bool inq[N];
std::queue<int> q;
// SPFA shortest paths from s over vertices 0..n.
// Returns false as soon as some vertex has been enqueued more than n+1
// times, i.e. a negative cycle is reachable (the constraints are infeasible).
bool spfa(int s){
    std::memset(dis,0x3f,sizeof(dis));
    dis[s]=0;
    q.push(s);inq[s]=1;              // the source is in the queue (was inq[s]=0)
    while(!q.empty()){
        int u=q.front();q.pop();inq[u]=0;
        for(int i=h[u];i;i=e[i].nxt){
            int v=e[i].to;
            if(dis[v]>dis[u]+e[i].val){
                dis[v]=dis[u]+e[i].val;
                if(!inq[v]){
                    inq[v]=1;
                    q.push(v);
                    if(++cnt[v]>n+1)return 0;
                }
            }
        }
    }
    return 1;
}
int main(){
    scanf("%d%d",&n,&m);
    // Each input triple (x, y, z) becomes edge y -> x with weight z
    // (the usual encoding of x - y <= z).
    for(int k=0;k<m;++k){
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        Ins(y,x,z);
    }
    // Super source 0 with zero-weight edges to every variable.
    for(int i=1;i<=n;++i)Ins(0,i,0);
    if(!spfa(0)){
        printf("NO");
        return 0;
    }
    for(int i=1;i<=n;++i)printf("%d ",dis[i]);
    return 0;
}

prufer 序列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include<bits/stdc++.h>
using namespace std;
const int maxn=5000005;
int read(){// fast reader for a possibly negative integer
    int sign=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    int res=ch-48;
    for(ch=getchar();ch>='0'&&ch<='9';ch=getchar())res=res*10+ch-48;
    return res*sign;
}
int n,type;
int f[maxn],d[maxn],topss[maxn],x[maxn],top,maxv[maxn],p[maxn];
// Read the father array f[1..n-1] of a tree on vertices 1..n (vertex n has no
// father) and build its Prufer sequence p[1..n-2] in linear time.
// Returns the xor over i of i*p[i].
// Change: dropped the unused local `headss`.
long long make_prufer(int n){
    long long ans=0;
    // d[v]: number of children of v.
    for(int i=1;i<n;++i)f[i]=read(),++d[f[i]];
    for (int i=1,j=1;i<=n-2;++i,++j){
        while (d[j])++j;      // j: smallest leaf not yet removed
        p[i]=f[j];
        // Removing leaf j may make its father a leaf; if that father is smaller
        // than j it is the next leaf to remove, so emit it immediately (this
        // replaces re-scanning from the smallest index).
        while (i<=n-2&&!--d[p[i]]&&p[i]<j)p[i+1]=f[p[i]],++i;
    }
    for(int i=1;i<=n-2;++i)ans^=(long long)i*p[i];
    return ans;
}
// Reverse direction of make_prufer: read a Prufer sequence p[1..n-2] and
// rebuild the father array f[1..n-1] (vertex n has no father) in linear time.
// Returns the xor over i of i*f[i].
long long make_fa(int n){
long long ans=0;
// d[v]: number of occurrences of v in the sequence.
for(int i=1;i<=n-2;++i)p[i]=read(),++d[p[i]];
// Sentinel: the last remaining vertex is attached to n.
p[n-1]=n;
for(int i=1,j=1;i<n;++i,++j){
// j: smallest vertex with no remaining occurrences (the current leaf).
while(d[j])++j;
f[j]=p[i];
// When the father just used runs out of occurrences and is smaller than j, it
// is the next leaf: attach it to the next sequence element directly.
while(i<n&&!--d[p[i]]&&p[i]<j)f[p[i]]=p[i+1],++i;
}
for(int i=1;i<n;++i)ans^=(long long)i*f[i];
return ans;
}
int main(){
    n=read();
    type=read();
    // type 1: father array -> Prufer sequence; type 2: the reverse.
    if(type==1)printf("%lld",make_prufer(n));
    else if(type==2)printf("%lld",make_fa(n));
    return 0;
}

长度为 $n-2$ 的 Prüfer 序列与 $n$ 个结点的有标号无根树一一对应。

竞赛图相关

兰道定理:用于判断一个出度序列能否成为某个竞赛图的出度序列。

其中 $s_i$ 是将各点出度从小到大排序后的第 $i$ 个值,且 $k=n$ 时必须取等号:

$$\forall 1 \le k \le n,\quad \sum_{i=1}^{k}s_i\ge \binom{k}{2}$$

若竞赛图有环,则一定存在三元环

一般图最大匹配

带花树算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include<bits/stdc++.h>
using namespace std;
#define I inline int
#define V inline void
#define FOR(i,a,b) for(int i=a;i<=b;i++)
#define REP(u) for(int i=h[u],v;v=e[i].t,i;i=e[i].n)
const int N=1e3+1,M=1e5+1;
queue<int>q;
int n,m,tot,qwq,ans;
int h[N],lk[N],tag[N],fa[N],pre[N],dfn[N];
struct edge{int t,n;}e[M];
// Match x and y with each other.
V link(int x,int y){
    lk[x]=y;
    lk[y]=x;
}
// Add undirected edge x-y; greedily match it immediately when both ends are
// still free (this seeds the answer before the blossom search).
V add_edge(int x,int y){
    if(lk[x]==0&&lk[y]==0){
        link(x,y);
        ++ans;
    }
    ++tot;e[tot].t=y;e[tot].n=h[x];h[x]=tot;
    ++tot;e[tot].t=x;e[tot].n=h[y];h[y]=tot;
}
// Flip the augmenting path that ends at x: first recurse on lk[pre[x]]
// (read before any link is changed), then match x with pre[x].
V rev(int x){
    if(!x)return;
    rev(lk[pre[x]]);
    link(x,pre[x]);
}
// Union-find representative of x's contracted blossom (path compression).
I find(int x){
    if(fa[x]!=x)fa[x]=find(fa[x]);
    return fa[x];
}
// Walk up from x and y in turn (step: x -> pre[lk[x]]), stamping visited
// representatives with a fresh timestamp; the first representative seen
// twice is returned as the common base.
I lca(int x,int y){
    ++qwq;
    for(;;){
        x=find(x);
        if(dfn[x]==qwq)return x;
        if(x)dfn[x]=qwq;
        x=pre[lk[x]];
        std::swap(x,y);
    }
}
// Contract the path from x up to base p into p's blossom: set pre links,
// merge the union-find sets, and re-queue tag-2 vertices as tag 1.
V shrink(int x,int y,int p){
    while(find(x)!=p){
        pre[x]=y;
        y=lk[x];
        fa[x]=fa[y]=p;
        if(tag[y]==2){
            tag[y]=1;
            q.push(y);
        }
        x=pre[y];
    }
}
// One augmentation attempt from free vertex u: grow an alternating BFS tree
// (tag 1 = outer/even, tag 2 = inner/odd), contracting odd cycles with
// lca/shrink. Returns 1 if an augmenting path was found and flipped, else 0.
I blossom(int u){
// Reset tags, tree links and the blossom union-find.
FOR(i,1,n)tag[i]=pre[i]=0,fa[i]=i;
tag[u]=1,q=queue<int>(),q.push(u);
// REP(u=q.front()) scans every edge (u,v) of the queue front u.
for(int p;!q.empty();q.pop())REP(u=q.front())
// Outer-outer edge: odd cycle, contract it into a blossom based at p.
if(tag[v]==1)
p=lca(u,v),shrink(u,v,p),shrink(v,u,p);
// Unvisited v becomes inner with parent u; a free v ends an augmenting path.
else if(!tag[v]){
pre[v]=u,tag[v]=2;
if(!lk[v])return rev(v),1;
else tag[lk[v]]=1,q.push(lk[v]);
}
return 0;
}
int main(){
scanf("%d%d",&n,&m);
for(int x,y;m--;add_edge(x,y))scanf("%d%d",&x,&y);
FOR(i,1,n)ans+=!lk[i]&&blossom(i);
cout<<ans<<'\n';
FOR(i,1,n)cout<<lk[i]<<' ';
return 0;
}

一般图最大权匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200


#include<bits/stdc++.h>
#define For(i,a,b) for(int i=(a);i<=(b);++i)
#define Rep(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
inline int read()
{
    // Signed integer reader: every '-' before the digits toggles the sign.
    char c=getchar();
    bool neg=false;
    while(!isdigit(c)){
        if(c=='-')neg=!neg;
        c=getchar();
    }
    int x=0;
    while(isdigit(c)){
        x=x*10+(c-'0');
        c=getchar();
    }
    return neg?-x:x;
}

#define fi first
#define se second
#define pb push_back
#define mkp make_pair
typedef pair<int,int>pii;
typedef vector<int>vi;

#define maxn 4000005
#define inf 0x3f3f3f3f

#define N 805
int n,m;
struct edge{
int u,v,w;
}e[N][N];
int lab[N],sl[N],q[N],hd,tl,rt[N],col[N],mat[N],frm[N][N],fa[N];
int vis[N],tim;
vi p[N];
// Slack of edge x: lab[u] + lab[v] - 2*w(u,v) with the weight read from
// e[x.u][x.v]; zero slack means the edge is tight.
int d(edge x){return lab[x.u]+lab[x.v]-e[x.u][x.v].w*2;}
// Keep in sl[v] the neighbour u whose edge to v has the smallest slack.
void upd(int u,int v){if(!sl[v]||d(e[u][v])<d(e[sl[v]][v]))sl[v]=u;}
// Recompute sl[u] over original vertices i joined to u by an edge, lying
// outside top-level blossom u, whose top-level blossom has col 0.
void updall(int u){
sl[u]=0;
For(i,1,n)if(e[i][u].w>0 && rt[i]!=u && col[rt[i]]==0) upd(i,u);
}
// Append every original vertex contained in (possibly nested) blossom u to q.
void push(int u){
if(u<=n) q[++tl]=u;
else for(int x:p[u]) push(x);
}
// Make r the top-level blossom of u and of everything nested inside it.
void setrt(int u,int r){
rt[u]=r;
if(u>n) for(int x:p[u]) setrt(x,r);
}
// Locate sub-blossom x in blossom b's cycle list p[b]. If its index is odd,
// reverse p[b][1..] so that the index becomes even; return the even index.
int rot(int b,int x){
int pos=find(p[b].begin(),p[b].end(),x)-p[b].begin();
if(pos%2==0)return pos;
reverse(p[b].begin()+1,p[b].end());
return p[b].size()-pos;
}
// Set the mate of (sub)blossom u through edge e[u][v]. For a non-trivial
// blossom, x = frm[u][w.u] is the child containing the edge's endpoint: pair
// up the children before it, recurse into x, then rotate p[u] so x is first.
void match(int u,int v){
mat[u]=e[u][v].v;
if(u<=n)return;
auto w=e[u][v];
int x=frm[u][w.u],pos=rot(u,x);
For(i,0,pos-1) match(p[u][i],p[u][i^1]);
match(x,v);
rotate(p[u].begin(),p[u].begin()+pos,p[u].end());
}
// Augment upward: match u through v; if u's old mate blossom t exists,
// re-match t with its tree parent and continue from that parent.
void aug(int u,int v){
assert(u&&v);
int t=rt[mat[u]];
match(u,v);
if(t) match(t,rt[fa[t]]),aug(rt[fa[t]],t);
}

// Walk upward from u (mate blossom, then its tree parent), stamping top-level
// blossoms with tim; when that walk runs out, continue from v. Returns the
// first blossom reached twice, or 0 if the walks never meet.
int lca(int u,int v){
++tim;
while(u||v){
if(vis[u]==tim)return u;
vis[u]=tim;
u=rt[mat[u]];
if(u)u=rt[fa[u]];
if(!u)swap(u,v);
}return 0;
}

// Allocate an id for a new blossom: the first unused slot (rt==0) in (n, m],
// otherwise m+1 (m grows). The slot's list and fields are reset.
int newn(){
int x=n+1;
while(x<=m&&rt[x])++x; if(x>m)++m;
p[x].clear(),lab[x]=mat[x]=rt[x]=col[x]=0;
return x;
}
// Contract the cycle closed by edge (u,v) with base r into a new blossom x.
// p[x] = r, the path from r down to u (the u-side is collected upward and
// then reversed), then the path from v back toward r. The mate blossom of
// every path step is pushed to the queue. Row/column x of e[][] keeps, per
// node i, the child edge with minimum slack; frm[x][i] records the child
// containing original vertex i.
void blossom(int u,int v,int r){
int x=newn();
mat[x]=mat[r];
p[x].pb(r);
for(int i=u;i!=r;i=rt[fa[i]]){
p[x].pb(i),p[x].pb(i=rt[mat[i]]);
push(i);
}
reverse(p[x].begin()+1,p[x].end());
for(int i=v;i!=r;i=rt[fa[i]]){
p[x].pb(i),p[x].pb(i=rt[mat[i]]);
push(i);
}
setrt(x,x);
For(i,1,m)e[x][i].w=e[i][x].w=0,frm[x][i]=0;
for(int t:p[x]){
For(i,1,m){
if(!e[x][i].w||d(e[t][i])<d(e[x][i]))
e[x][i]=e[t][i],e[i][x]=e[i][t];
}
For(i,1,n)if(frm[t][i])frm[x][i]=t;
}
updall(x);
}
// Dissolve top-level blossom x; its children become top-level again.
// u is the child containing the endpoint of x's tree edge (to fa[x]). The
// children before u in the rotated cycle are relabelled in pairs (col 1 /
// col 0, the col-0 ones re-queued), u gets col 1 and x's tree parent, and the
// remaining children become unlabelled (col -1). rt[x]=0 frees the id.
void expand(int x){
for(int u:p[x]) setrt(u,u);
int u=frm[x][e[x][fa[x]].u],pos=rot(x,u);
for(int i=0;i<pos;i+=2){
int a=p[x][i],b=p[x][i+1];
fa[a]=e[b][a].u;
col[a]=1,col[b]=0;
sl[a]=0;
updall(b),push(b);
}
col[u]=1;
fa[u]=fa[x];
for(int i=pos+1;i<p[x].size();++i)
col[p[x][i]]=-1,updall(p[x][i]);
rt[x]=0;
}

// Process tight edge e (asserted: zero slack). If the far blossom v is
// unlabelled, grow the tree: v gets col 1 with parent e.u, and its mate
// blossom t gets col 0 and is queued. If v has col 0, lca decides: 0 means
// the two trees differ, so augment both sides and return true; otherwise
// contract a blossom at the common base. Returns true only after augmenting.
bool path(edge e){
int u=rt[e.u],v=rt[e.v];
assert(!d(e));
if(col[v]==-1){
fa[v]=e.u;
col[v]=1;
int t=rt[mat[v]];
sl[v]=sl[t]=col[t]=0;
push(t);
}
else if(!col[v]){
int t=lca(u,v);
if(!t)return aug(u,v),aug(v,u),1;
else blossom(u,v,t);
}
return 0;
}
// One search phase of the primal-dual weighted matching.
// Returns 1 after an augmentation; returns 0 if there is no free top-level
// blossom or some vertex label drops to <= 0 (main stops on 0).
bool bfs()
{
For(i,0,m) col[i]=-1,sl[i]=0;
hd=1,tl=0;
// Free top-level blossoms start with col 0 and are queued.
For(i,1,m)if(rt[i]==i&&!mat[i])fa[i]=col[i]=0,push(i);
if(!tl)return 0;
while(1){
// Follow tight edges via path(); other edges only update slack candidates.
while(hd<=tl){
int u=q[hd++];
For(v,1,n)
if(e[u][v].w>0&&rt[u]!=rt[v]){
if(d(e[u][v])) upd(u,rt[v]);
else if(path(e[u][v])) return 1;
}
}
// Dual step tmp: min over lab of col-0 vertices, half of the lab of col-1
// blossoms, and each candidate edge's slack (col -1: full, col 0: half).
int tmp=inf;
For(i,1,n) if(col[rt[i]]==0) tmp=min(tmp,lab[i]);
For(i,n+1,m) if(rt[i]==i && col[i]==1) tmp=min(tmp,lab[i]>>1);
For(i,1,m) if(rt[i]==i && sl[i] && col[i]!=1) tmp=min(tmp,d(e[sl[i]][i])>>(col[i]+1));
// Vertex labels: col 0 down, col 1 up by tmp.
For(i,1,n){
if(col[rt[i]]==0) lab[i]-=tmp;
if(col[rt[i]]==1) lab[i]+=tmp;
if(lab[i]<=0) return 0;
}
// Blossom labels move by 2*tmp in the opposite direction.
For(i,n+1,m){
if(rt[i]==i){
if(col[i]==0) lab[i]+=tmp*2;
if(col[i]==1) lab[i]-=tmp*2;
}
}
// Retry edges that just became tight, then expand col-1 blossoms whose label hit 0.
hd=1,tl=0;
For(i,1,m)
if(rt[i]==i && sl[i] && rt[sl[i]]!=i && !d(e[sl[i]][i]) && path(e[sl[i]][i])) return 1;
For(i,n+1,m)
if(rt[i]==i && col[i]==1 && !lab[i]) expand(i);
}
}

signed main()
{
    n=read();
    m=n;                               // highest node id in use; blossoms get ids > n
    const int M=read();
    int mx=0;
    For(i,1,n)For(j,1,n){
        e[i][j]=edge{i,j,0};
        frm[i][j]=0;
    }
    // Keep the heaviest edge per vertex pair.
    For(i,1,M){
        int u=read();
        int v=read();
        int w=read();
        mx=max(mx,w);
        const int best=max(e[u][v].w,w);
        e[u][v].w=best;
        e[v][u].w=best;
    }
    For(i,1,n)frm[i][i]=i;
    // Every vertex starts as its own top-level blossom with label mx.
    For(i,1,n){
        lab[i]=mx;
        rt[i]=i;
        p[i].clear();
    }
    while(bfs());
    long long res=0;
    For(i,1,n)
        if(mat[i]&&i<mat[i])res+=e[i][mat[i]].w;
    cout<<res<<"\n";
    For(i,1,n)cout<<mat[i]<<" ";
    return 0;
}

最小树形图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include<bits/stdc++.h>
using namespace std;
#define rint register int
#define ll long long
#define rll register long long
#define db long double
const int N=1e4+10;
struct Edge{
int fr,to,val;
}e[N<<1];
int Min[N],fa[N],col[N],vis[N],exis[N];
// Minimum-weight arborescence rooted at r (Chu-Liu/Edmonds, contraction form).
// Input: n, m, r, then m directed edges (from, to, weight). Prints the total
// weight, or -1 when some live non-root vertex has no incoming edge.
int main(){
rint n,m,r;
scanf("%d%d%d",&n,&m,&r);
for(rint i=1;i<=m;i++)
scanf("%d%d%d",&e[i].fr,&e[i].to,&e[i].val);
rint res=0;
// exis[i]: vertex i is still a live (representative) vertex.
for(rint i=1;i<=n;i++)
exis[i]=1;
while(1){
// col[i]: representative after this round; vis[i]: walk stamp (-1 = none);
// Min[i]/fa[i]: cheapest incoming edge weight and its source.
for(rint i=1;i<=n;i++)
col[i]=i,vis[i]=-1,Min[i]=1e9;
// Choose the cheapest non-self-loop incoming edge of every vertex.
for(rint i=1;i<=m;i++){
if(e[i].fr==e[i].to)continue;
if(e[i].val<Min[e[i].to])
Min[e[i].to]=e[i].val,fa[e[i].to]=e[i].fr;
}
bool chk=1;
for(rint i=1;i<=n;i++){
if(i==r||!exis[i])continue;
// No incoming edge at all: the vertex cannot be reached from r.
if(Min[i]==1e9)return puts("-1"),0;
res+=Min[i];
// Follow chosen in-edges from i, stamping with i; if the walk stops on a
// vertex stamped i, a new cycle through x has been found.
rint x;
for(x=i;x!=r&&vis[x]==-1;x=fa[x])
vis[x]=i;
if(vis[x]==i){
chk=0;
// Contract the cycle: every other vertex on it maps to x.
for(rint t=fa[x];t!=x;t=fa[t])
col[t]=x;
}
}
// No cycle: the chosen edges already form the arborescence.
if(chk)break;
for(rint i=1;i<=n;i++)
exis[i]=0;
// Reweight each edge by its target's chosen minimum, then relabel endpoints
// onto representatives and mark them live.
for(rint i=1;i<=m;i++){
e[i].val-=Min[e[i].to],e[i].fr=col[e[i].fr],e[i].to=col[e[i].to];
exis[e[i].fr]=exis[e[i].to]=1;
}
r=col[r];
}
printf("%d\n",res);
return 0;
}

exgcd

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<bits/stdc++.h>
using namespace std;
#define LL long long
int qread(){
    // Integer reader; only a '-' immediately before the digits makes it
    // negative (any other skipped character resets the sign to +).
    int sign=1;
    int ch=getchar();
    while(ch<'0'||ch>'9'){
        sign=(ch=='-')?-1:1;
        ch=getchar();
    }
    int val=ch-'0';
    for(ch=getchar();ch>='0'&&ch<='9';ch=getchar())val=val*10+ch-'0';
    return val*sign;
}
// Extended Euclid: returns g = gcd(a,b) and sets x, y with a*x + b*y = g.
// Spelled with `long long`, the same type the LL macro expands to.
long long exgcd(long long a,long long b,long long &x,long long &y){
    if(!b){
        // gcd(a,0) = a = a*1 + 0*0
        x=1;
        y=0;
        return a;
    }
    // Solve (b, a mod b) with x and y swapped, then correct y.
    const long long g=exgcd(b,a%b,y,x);
    y-=a/b*x;
    return g;
}
int main(){
int T=qread();
while(T--){
LL a=qread(),b=qread(),c=qread(),x,y;
LL d=exgcd(a,b,x,y);
if(c%d!=0) puts("-1"); else{
x*=c/d,y*=c/d; LL p=b/d,q=a/d,k;
if(x<0) k=ceil((1.0-x)/p),x+=p*k,y-=q*k; else //将x提高到最小正整数
if(x>=0)k=(x-1)/p ,x-=p*k,y+=q*k; //将x降低到最小正整数
if(y>0){ //有正整数解
printf("%lld ",(y-1)/q+1); //将y减到1的方案数即为解的个数
printf("%lld ",x); //当前的x即为最小正整数x
printf("%lld ",(y-1)%q+1); //将y取到最小正整数
printf("%lld ",x+(y-1)/q*p); //将x提升到最大
printf("%lld ",y); //特解即为y最大值
} else{ //无整数解
printf("%lld " ,x); //当前的x即为最小的正整数x
printf("%lld",y+q*(LL)ceil((1.0-y)/q)); //将y提高到正整数
}
puts("");
}
}
return 0;
}

fwt

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include<cstdio>
const int N=1e6+10;
const int p=998244353;
int a[N],b[N],t1[N],t2[N];
#define rg register
// OR transform: typ=1 forward, typ=-1 inverse. At every level each lower-half
// value is added (times typ) into its upper-half partner. n is expected to be
// a power of two.
void fwtor(int *a,int n,int typ){
    for(int half=1;(half<<1)<=n;half<<=1)
        for(int base=0;base<n;base+=half<<1)
            for(int j=base;j<base+half;++j)
                a[j+half]=((a[j+half]+a[j]*typ)%p+p)%p;
}
// AND transform: same structure, upper-half values flow into the lower half.
void fwtand(int *a,int n,int typ){
    for(int half=1;(half<<1)<=n;half<<=1)
        for(int base=0;base<n;base+=half<<1)
            for(int j=base;j<base+half;++j)
                a[j]=((a[j]+a[j+half]*typ)%p+p)%p;
}
// XOR transform: butterfly (lo+hi, lo-hi); a nonzero typ additionally scales
// both outputs by typ (the inverse passes inv(2)).
void fwtxor(int *a,int n,int typ){
    for(int half=1;(half<<1)<=n;half<<=1)
        for(int base=0;base<n;base+=half<<1)
            for(int j=base;j<base+half;++j){
                const int lo=a[j],hi=a[j+half];
                int sum=(lo+hi)%p;
                int dif=(lo-hi+p)%p;
                if(typ){
                    sum=1ll*sum*typ%p;
                    dif=1ll*dif*typ%p;
                }
                a[j]=sum;
                a[j+half]=dif;
            }
}
int main(){
    int lg;
    scanf("%d",&lg);
    const int n=1<<lg;                         // transform length
    for(int i=0;i<n;++i)scanf("%d",&t1[i]);
    for(int i=0;i<n;++i)scanf("%d",&t2[i]);
    // OR convolution
    for(int i=0;i<n;++i){a[i]=t1[i];b[i]=t2[i];}
    fwtor(a,n,1);
    fwtor(b,n,1);
    for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%p;
    fwtor(a,n,-1);
    for(int i=0;i<n;++i)printf("%d ",a[i]);
    puts("");
    // AND convolution
    for(int i=0;i<n;++i){a[i]=t1[i];b[i]=t2[i];}
    fwtand(a,n,1);
    fwtand(b,n,1);
    for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%p;
    fwtand(a,n,-1);
    for(int i=0;i<n;++i)printf("%d ",a[i]);
    puts("");
    // XOR convolution: forward with typ=0, inverse scales by inv(2) per level
    for(int i=0;i<n;++i){a[i]=t1[i];b[i]=t2[i];}
    fwtxor(a,n,0);
    fwtxor(b,n,0);
    for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%p;
    fwtxor(a,n,(p+1)>>1);
    for(int i=0;i<n;++i)printf("%d ",a[i]);
    return 0;
}

算法总览
https://suzipei.github.io/2023/02/22/template/
作者
Su_Zipei
发布于
2023年2月22日
许可协议