多项式板子

某一天发现自己的多项式板子好像全机房最慢唉。

于是找了些博客来学习卡常。

然后喜提目前全机房最快。

如何卡常

取模优化

众所周知 C++ 取模慢得出奇,想优化常数,自然而然想到在取模上下工夫。

曾经为了避免加减取模写了这样的东西:

1
2
3
const int mod = 998244353;
inline void inc(int &x, const int &y){if((x += y) >= mod) x -= mod;}
inline void dec(int &x, const int &y){if((x -= y) < 0) x += mod;}

后来发现加一加减一减比一比也挺慢,于是就有了这样的东西:

1
2
const int mod = 998244353;
inline int qmo(const int &x){return x + ((x >> 31) & mod);}

大概就是用了下对int型右移 31 位会使其全部变成符号位的性质。

预处理原根

每次还要根据长度重新处理蝴蝶变换的数组?这里要做大量的乘法和取模,不如直接一次处理出来优化常数。

预处理要 $O(n\log n)$?长度总是 2 的次幂,完全可以只用 $O(n)$。

清空与移动数组

memsetmemcpy无论是否在-O2下都有优秀的表现。

DFT时使用64位无符号整数

稍微算算,发现在模数是int范围内是不会爆unsigned long long的。

无符号整数在做加减乘的时候会稍微快一些。

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#include <bits/stdc++.h>
namespace my_std {
using namespace std;
#define reg register
#define Rint register int
#define FOR(i, a, b) for (register int i = (a), ed_##i = (b); i <= ed_##i; ++i)
#define ROF(i, a, b) for (register int i = (a), ed_##i = (b); i >= ed_##i; --i)
typedef unsigned int u32;
typedef long long i64;
typedef unsigned long long u64;
#define Templ(T) template <typename T>
#ifdef LOCAL
#define gcr() getchar()
#define pcr(x) putchar(x)
#define F_In void()
#define F_Out void()
#else
static char InB[1 << 24], *In_s = InB;
static char OutB[1 << 24], *Out_s = OutB;
#define gcr() (*In_s ++)
#define pcr(x) (*Out_s ++ = x)
#define F_In (fread(InB, 1, 1 << 24, stdin))
#define F_Out (fwrite(OutB, 1, Out_s - OutB, stdout), Out_s = OutB)
#endif
inline int read() {
// reg int ans = 0, f = 1;
reg int ans = 0;
reg char c = gcr();
// while (!isdigit(c)) f ^= (c == '-'), c = gcr();
while (!isdigit(c)) c = gcr();
for (; isdigit(c); c = gcr()) ans = (ans << 1) + (ans << 3) + (c ^ 48);
// return f ? ans : -ans;
return ans;
}
Templ(T) inline void write(T x){
static char sta[20];
reg char *t = sta;
// if(x < 0) return pcr('-'), write(-x);
do{
*++t = (x % 10) ^ 48, x /= 10;
}while(x);
while(t != sta) pcr(*t--);
}
Templ(_Tp) inline int chkmin(_Tp &x, _Tp y) { return x > y ? x = y, 1 : 0; }
Templ(_Tp) inline int chkmax(_Tp &x, _Tp y) { return x < y ? x = y, 1 : 0; }
#define using_mod
const int mod = 998244353;
#ifdef using_mod
inline void inc(int &x, const int &y) { if ((x += y) >= mod) x -= mod; }
inline void dec(int &x, const int &y) { if ((x -= y) < 0) x += mod; }
inline int ksm(int x, int y) {
reg int res = 1;
for (; y; y >>= 1, x = 1ll * x * x % mod)
if (y & 1) res = 1ll * res * x % mod;
return res;
}
inline int qmo(const int &x){ return x + ((x >> 31) & mod); }
#endif
} // namespace my_std
using namespace my_std;

#define swap(x, y) (x ^= y ^= x ^= y)
// const int N = 270010;
const int N = 2100010;

int LMT = 1;
int rev[N], omg[N], inv[N];
int l2g[N];
inline void init(const int &n){
inv[1] = 1;
FOR(i, 2, n) inv[i] = (i64)(mod - mod / i) * inv[mod % i] % mod;
l2g[1] = 0;
FOR(i, 2, n << 1){
l2g[i] = l2g[i >> 1] + 1;
}
}

inline int get_len(const int &n){
return 1 << (l2g[n] + 1);
}

inline void poly_init(const int &n){
Rint l = 0;
while(LMT <= n) LMT <<= 1, ++ l;
FOR(i, 1, LMT - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
reg const int t = ksm(3, (mod - 1) >> l);
omg[LMT >> 1] = 1;
FOR(i, (LMT >> 1) + 1, LMT - 1) omg[i] = (i64)omg[i - 1] * t % mod;
ROF(i, (LMT >> 1) - 1, 1) omg[i] = omg[i << 1];
LMT = l;
}

inline void DFT(int *a, const int &n){
// static i64 tmp[N];
static u64 tmp[N];
reg const int fix = LMT - l2g[n];
Rint t;
FOR(i, 0, n - 1) tmp[i] = a[rev[i] >> fix];
for(Rint i = 1; i < n; i <<= 1){
for(Rint j = 0; j < n; j += i << 1){
FOR(k, j, j + i - 1){
t = tmp[i + k] * omg[i + k - j] % mod;
// tmp[i + k] = qmo(tmp[k] - t);
// tmp[k] = qmo(tmp[k] - mod + t);
tmp[i + k] = tmp[k] + mod - t;
tmp[k] += t;
}
}
}
FOR(i, 0, n - 1) a[i] = tmp[i] % mod;
}
inline void IDFT(int *a, const int &n){
reverse(a + 1, a + n);
DFT(a, n);
reg const int bk = mod - (mod - 1) / n;
FOR(i, 0, n - 1) a[i] = (i64)a[i] * bk % mod;
}

//c <- a * b
inline void poly_mul(int *a, int *b, int *c, const int &deg){
static int tmp1[N], tmp2[N];
reg const int len = get_len(deg);
memcpy(tmp1, a, sizeof(int) * len), memcpy(tmp2, b, sizeof(int) * len);
DFT(tmp1, len), DFT(tmp2, len);
FOR(i, 0, len - 1) c[i] = (i64)tmp1[i] * tmp2[i] % mod;
IDFT(c, len);
memset(c + deg, 0, sizeof(int) * (len - deg));
}
//b <- a ^ (-1)
inline void poly_inv(int *a, int *b, const int &deg){
static int tmp[N];
if(deg == 1){
b[0] = ksm(a[0], mod - 2);
return;
}
poly_inv(a, b, (deg + 1) >> 1);
reg const int len = get_len(deg << 1);
memcpy(tmp, a, sizeof(int) * deg);
memset(tmp + deg, 0, sizeof(int) * (len - deg));
DFT(b, len), DFT(tmp, len);
FOR(i, 0, len - 1){
b[i] = (i64)qmo(2 - (i64)b[i] * tmp[i] % mod) * b[i] % mod;
}
IDFT(b, len);
memset(b + deg, 0, sizeof(int) * (len - deg));
}
//b(x) <- \d a(x) / \d x
inline void poly_der(int *a, int *b, const int &deg){
FOR(i, 0, deg - 2) b[i] = (i64)a[i + 1] * (i + 1) % mod;
b[deg - 1] = 0;
}
//b(x) <- \int a(x) \d x
inline void poly_int(int *a, int *b, const int &deg){
FOR(i, 1, deg - 1) b[i] = (i64)a[i - 1] * inv[i] % mod;
b[0] = 0;
}
//b <- \ln a
inline void poly_ln(int *a, int *b, const int &deg){
static int tmp[N];
poly_inv(a, tmp, deg);
poly_der(a, b, deg);
reg const int len = get_len(deg << 1);
DFT(b, len), DFT(tmp, len);
FOR(i, 0, len - 1) tmp[i] = (i64)tmp[i] * b[i] % mod;
IDFT(tmp, len);
poly_int(tmp, b, deg);
memset(b + deg, 0, sizeof(int) * (len - deg));
memset(tmp, 0, sizeof(int) * len);
}
//b <- \exp a
inline void poly_exp(int *a, int *b, const int &deg){
static int tmp[N];
if(deg == 1){
b[0] = 1;
return;
}
poly_exp(a, b, (deg + 1) >> 1);
poly_ln(b, tmp, deg);
reg const int len = get_len(deg << 1);
FOR(i, 0, len - 1){
if(i < deg) tmp[i] = qmo(a[i] - tmp[i]);
else tmp[i] = 0;
}
++tmp[0];
DFT(b, len), DFT(tmp, len);
FOR(i, 0, len - 1){
b[i] = (i64)tmp[i] * b[i] % mod;
}
IDFT(b, len);
memset(b + deg, 0, sizeof(int) * (len - deg));
memset(tmp + deg, 0, sizeof(int) * (len - deg));
}

int main() {
F_In;
F_Out;
return 0;
}

即使在你谷巨慢的评测机下都能够跑过机房其他人在你谷评测机还没那么慢的时候跑的 exp 。

参考资料

yurzhang’s blog