画图可以发现,多了一条边过后的图是串并联图。(暂时不确定)
然后我们考虑把问题变成,若生成树包含一条边\(e\),则使生成树权值乘上\(a_e\),否则乘上\(b_e\),求最终的生成树权值之和。我们只需要支持删去度数为\(1\)的点,同时删去和它相连的那条边;删去度数为2的点,把两条边合并为一条边;合并重边三种操作。
对于第一种操作,把答案乘上\(a_e\),并删去即可。对于第二种操作,把和这个点相邻的两条边记作\(e_1,e_2\),其中\(e_1\)连接\(u, v\),\(e_2\)连接\(u, w\)。则删去两条边,连接一条端点为\(v, w\),\(a_e = a_{e_1}a_{e_2}, b_e = a_{e_1}b_{e_2} + a_{e_2}b_{e_1}\)的边。对于第三种操作,把\(e_1, e_2\)合并为\(e\)时,\(a_e = a_{e_1}b_{e_2} + a_{e_2}b_{e_1}, b_e = b_{e_1}b_{e_2}\)。
代码如下:
#include <bits/stdc++.h>
using namespace std;
const int N = 500005;
const long long mod = 998244353ll;
template <class T>
void read (T &x) {
int sgn = 1;
char ch;
x = 0;
for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
if (ch == '-') ch = getchar(), sgn = -1;
for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
x *= sgn;
}
template <class T>
void write (T x) {
if (x < 0) putchar('-'), write(-x);
else if (x < 10) putchar(x + '0');
else write(x / 10), putchar(x % 10 + '0');
}
int n, m, deg[N];
long long ans = 1ll;
bool vis[N];
vector<int> g[N];
set<pair<int, int> > se;
map<pair<int, int>, pair<long long, long long> > id;
queue<int> que;
int main () {
read(n), read(m);
for (int i = 1; i <= n; i++) deg[i] = 0;
for (int i = 0; i < m; i++) {
int u, v;
read(u), read(v);
if (u > v) swap(u, v);
if (u == v) continue;
if (se.count(make_pair(u, v))) id[make_pair(u, v)].second = (id[make_pair(u, v)].second + 1ll) % mod;
else {
deg[u]++, deg[v]++;
g[u].push_back(v), g[v].push_back(u);
se.insert(make_pair(u, v));
id[make_pair(u, v)].first = id[make_pair(u, v)].second = 1ll;
}
}
for (int i = 1; i <= n; i++) {
vis[i] = false;
if (deg[i] <= 2) que.push(i);
}
while (!que.empty()) {
int u = que.front();
que.pop();
if (vis[u]) continue;
vis[u] = true;
vector<int> adj;
for (int i = 0; i < g[u].size(); i++) {
if (!vis[g[u][i]]) adj.push_back(g[u][i]);
}
if (adj.size() == 1) {
deg[u]--;
if (--deg[adj[0]] <= 2) que.push(adj[0]);
ans = ans * id[make_pair(min(u, adj[0]), max(u, adj[0]))].second % mod;
}
else if (adj.size() == 2) {
pair<int, int> e0(min(u, adj[0]), max(u, adj[0]));
pair<int, int> e1(min(u, adj[1]), max(u, adj[1]));
pair<int, int> e(min(adj[0], adj[1]), max(adj[0], adj[1]));
pair<long long, long long> pi((id[e0].second * id[e1].first + id[e0].first * id[e1].second) % mod, id[e0].second * id[e1].second % mod);
deg[u] -= 2;
if (se.count(e)) {
if (--deg[adj[0]] <= 2) que.push(adj[0]);
if (--deg[adj[1]] <= 2) que.push(adj[1]);
id[e] = make_pair(id[e].first * pi.first % mod, (id[e].first * pi.second + id[e].second * pi.first) % mod);
}
else {
g[adj[0]].push_back(adj[1]);
g[adj[1]].push_back(adj[0]);
se.insert(e);
id[e] = pi;
}
}
}
write(ans), putchar('\n');
return 0;
}
神奇的找规律题,感觉方向不对就很难找出来。
我们仍然考虑打表,把\(k, n\)比较小的情况打出来(记作\(f_{n, k}\))。然后我们发现固定\(n\),\(f_{n, k}\)的取值很少。再仔细观察,发现取值变化的点恰好是斐波那契数列上的数。
这启发我们对于每个数\(i > 1\),找先手第一步最小要取多少,才能保证他获胜。记该最小值为\(a_i\)。然后我们写出这个数列某些前缀:
\(1\)
\(1, 2\)
\(1, 2, 3\)
\(1, 2, 3, 1, 5\)
\(1, 2, 3, 1, 5, 1, 2, 8\)
\(\cdots\)
注意这里我们是找\(1\), \(2\), \(3\), \(5\), \(8\), \(\cdots\)的第一次出现位置为终止的前缀!
然后我们发现第\(i\)个前缀是第\(i - 1, i - 3, i - 5, \cdots\)个前缀拼接上\(F_i\)后的前缀。(\(F_0 = F_1 = 1\))
找到了这个规律后,我们就设\(g_{i, j, k}\)表示在第\(i\)个前缀的前\(j\)个数中\(\leq F_k\)的数字的个数。先通过\(dp\)预处理\(j\)就是第\(i\)个前缀的长度的情况,然后每次询问递归下去做即可。(通过类似线段树的证明,可以知道每次询问递归树的点数是\(O(\log^2 N)\)的。
代码如下:
#include <bits/stdc++.h>
using namespace std;
const int S = 505, T = 100005;
template <class T>
void read (T &x) {
int sgn = 1;
char ch;
x = 0;
for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
if (ch == '-') ch = getchar(), sgn = -1;
for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
x *= sgn;
}
template <class T>
void write (T x) {
if (x < 0) putchar('-'), write(-x);
else if (x < 10) putchar(x + '0');
else write(x / 10), putchar(x % 10 + '0');
}
int t, cnt1 = 2, cnt2 = 1;
long long n[T], m[T], mx, fib[S], a[S], f[S][S];
void init () {
fib[0] = fib[1] = 1;
for (int i = 2; ; i++) {
fib[i] = fib[i - 1] + fib[i - 2];
if (fib[i] > mx) {
cnt1 = i;
break;
}
}
a[1] = 1;
for (int i = 2; ; i++) {
a[i] = 1;
for (int j = i - 1; j >= 1; j -= 2) a[i] += a[j];
if (a[i] > mx) {
cnt2 = i;
break;
}
}
for (int i = 1; i <= cnt1; i++) f[1][i] = 1ll;
for (int i = 2; i <= cnt2; i++) {
for (int j = 1; j <= cnt1; j++) {
f[i][j] = i <= j ? 1ll : 0ll;
for (int k = i - 1; k >= 1; k -= 2) f[i][j] += f[k][j];
}
}
}
long long solve (int x, int y, long long len) {
if (len == a[x]) return f[x][y];
long long ans = 0ll;
for (int i = x - 1; i >= 1; i -= 2) {
ans += solve(i, y, min(len, a[i]));
len -= a[i];
if (len <= 0) break;
}
if (len > 0) ans++;
return ans;
}
int main () {
read(t);
for (int i = 1; i <= t; i++) {
read(m[i]), read(n[i]);
mx = max(mx, max(m[i], n[i]));
}
init();
for (int i = 1; i <= t; i++) {
int cnt = 0;
for (int j = 1; j <= cnt1; j++) {
if (fib[j] <= m[i]) cnt = j;
}
write(solve(cnt2, cnt, n[i] - 1)), putchar('\n');
}
return 0;
}
直接暴力线段树维护最大子段和就过了,正解不会。
代码如下:
#include <bits/stdc++.h>
using namespace std;
const int N = 100005;
template <class T>
void read (T &x) {
int sgn = 1;
char ch;
x = 0;
for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
if (ch == '-') ch = getchar(), sgn = -1;
for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
x *= sgn;
}
template <class T>
void write (T x) {
if (x < 0) putchar('-'), write(-x);
else if (x < 10) putchar(x + '0');
else write(x / 10), putchar(x % 10 + '0');
}
int n, m;
long long a[N];
struct node {
long long sum, pre, suf, val;
} sgt[N << 2];
node merge (node a, node b) {
node ans;
ans.sum = a.sum + b.sum;
ans.pre = max(a.pre, a.sum + b.pre);
ans.suf = max(b.suf, b.sum + a.suf);
ans.val = max(max(a.val, b.val), a.suf + b.pre);
return ans;
}
void pushup (int now) {
sgt[now] = merge(sgt[now << 1], sgt[now << 1 | 1]);
}
void build (int l, int r, int now) {
int mid = l + r >> 1;
if (l == r) {
sgt[now].sum = sgt[now].pre = sgt[now].suf = a[mid];
sgt[now].val = max(a[mid], 0ll);
}
else {
build(l, mid, now << 1), build(mid + 1, r, now << 1 | 1);
pushup(now);
}
}
void change (int pos, int l, int r, int now, long long val) {
int mid = l + r >> 1;
if (l == r) {
sgt[now].sum = sgt[now].pre = sgt[now].suf = val;
sgt[now].val = max(val, 0ll);
}
else {
if (pos <= mid) change(pos, l, mid, now << 1, val);
else change(pos, mid + 1, r, now << 1 | 1, val);
pushup(now);
}
}
node query (int left, int right, int l, int r, int now) {
int mid = l + r >> 1;
if (l == left && r == right) return sgt[now];
else if (right <= mid) return query(left, right, l, mid, now << 1);
else if (left > mid) return query(left, right, mid + 1, r, now << 1 | 1);
else return merge(query(left, mid, l, mid, now << 1), query(mid + 1, right, mid + 1, r, now << 1 | 1));
}
int main () {
read(n), read(m);
for (int i = 1; i <= n; i++) read(a[i]);
build(1, n, 1);
for (int i = 1; i <= m; i++) {
int ty;
read(ty);
if (ty == 0) {
int l, r, x;
read(l), read(r), read(x);
for (int j = l; j <= r; j++) {
if (a[j] < x) a[j] = x, change(j, 1, n, 1, x);
}
}
else {
int l, r;
read(l), read(r);
node ans = query(l, r, 1, n, 1);
write(ans.val), putchar('\n');
}
}
return 0;
}
我们首先考虑一个\(O(n^2)\)做法,我们把\(A\)和\(B\)的所有长度为\(k\)的串建成trie树。然后相当于在trie树上有\(n - k + 1\)个\(A\)类节点和\(B\)类节点,然后你要将\(A\)类节点和\(B\)类节点匹配,使得匹配的总距离之和除以二最小!
这是一个经典的问题,在深度不一致的情况也一样能做。我们只需考虑一条边\(u,v\)所经过的最小的次数,设\(v\)是\(u\)的儿子,\(v\)的子树中有\(a\)个\(A\)类节点,\(B\)个\(B\)类节点,则容易证明至少经过\(\lvert a - b \rvert\)次。且我们可以归纳地构造出方案。
再考虑如何取优化它。只需把\(trie\)树换成把两个字符串拼在一起后构成的后缀树(用sam建立),然后在后缀树上定位\(2(n - k + 1)\)个节点即可。这里定位可以考虑倍增,也可以使用NOI2018你的名字的那种two-pointer的trick。
时间复杂度\(O(n)\)至\(O(n \log n)\)不等,可以获得\(100\)分。
代码如下:
#include <bits/stdc++.h>
using namespace std;
const int N = 150005;
int n, m, len[N << 2], par[N << 2], num[N << 2], last = 0, cnt = 0;
char a[N], b[N];
long long ans = 0ll;
map<char, int> ch[N << 2];
void extend (char c) {
int p = last, np = ++cnt;
len[np] = len[p] + 1;
for (; ~p && !ch[p][c]; p = par[p]) ch[p][c] = np;
if (p < 0) par[np] = 0;
else {
int q = ch[p][c];
if (len[q] == len[p] + 1) par[np] = q;
else {
int nq = ++cnt;
ch[nq] = ch[q], len[nq] = len[p] + 1;
par[nq] = par[q], par[np] = par[q] = nq;
for (; ~p && ch[p][c] == q; p = par[p]) ch[p][c] = nq;
}
}
last = np;
}
vector<int> child[N << 2];
void dfs (int u) {
for (int i = 0; i < child[u].size(); i++) {
dfs(child[u][i]);
num[u] += num[child[u][i]];
}
if (u) ans += 1ll * max(num[u], -num[u]) * (min(len[u], m) - min(len[par[u]], m));
}
int main () {
scanf("%d%d", &n, &m);
scanf("%s%s", &a, &b);
len[0] = 0, par[0] = -1;
for (int i = n - 1; i >= 0; i--) extend(b[i]);
for (int i = n - 1; i >= 0; i--) extend(a[i]);
for (int i = 0; i <= cnt; i++) num[i] = 0;
int tmp = 0, now = 0;
for (int i = n - 1; i >= 0; i--) {
now = ch[now][b[i]], tmp++;
if (tmp > m) {
if (len[par[now]] >= m) now = par[now];
tmp = m;
}
if (i <= n - m) num[now]++;
}
for (int i = n - 1; i >= 0; i--) {
now = ch[now][a[i]], tmp++;
if (tmp > m) {
if (len[par[now]] >= m) now = par[now];
tmp = m;
}
if (i <= n - m) num[now]--;
}
for (int i = 1; i <= cnt; i++) child[par[i]].push_back(i);
dfs(0);
ans >>= 1;
printf("%lld\n", ans);
return 0;
}
这道毒瘤的题目性质太多,细节也太多,我可能难以给出详细的证明和解释,请见谅。
首先我们设\(0, 1, 2, …, n, n + 1\)中已经被填的数的集合为\(A\)(约定\(0, n + 1 \in A\)),没有被填的数的集合为\(B\)。我们把\(0, 1, …, n, n + 1\)按照已填和未填分段,设为\(A_0, B_1, A_1, B_2, …, B_m, A_m\)。若\(1 \leq l \leq r \leq k\),则区间内部的数已经固定,我们无需考虑。我们只需考虑\(k < l \leq r \leq n,l \leq k < r \leq n\)的部分。为了同时让这两部分最大化,我们猜测有如下结论
存在一个最优解满足:
1.对于\(B_1, B_2, …, B_m\),它在排列中一定构成递增或递减的连续的一段。
2.对于每个\(k < i \leq n\),\(p_{k + 1}, p_{k + 2}, …, p_i\)必定构成\(B\)的连续的一段。(例如\(B = \{2, 4, 6, 8\}\),则\(4,6,2\)算连续的一段,\(2,4,8\)不算)
有了第二条,我们可以统计未填数的每一种前缀填法对答案的贡献是多少,然后做一个\(O(n^2)\)的区间dp。然而毒瘤的出题人有一个神仙的做法,没有考虑到有这种辣鸡想法,所以就没有给\(O(n^2)\)的部分分
为了优化这两个dp,我们需要考虑性质\(1\)了。若\(\{ p_{k + 1}, p_{k + 2}, …, p_i \}\)满足\(B_1, …, B_m\)要么全在里面,要么全不在,则称这个集合是好的,反之称为坏的集合。结果我们会发现,每一个未填的后缀最多只对4个好的集合有贡献。
再考虑坏的集合必定恰好夹在两个好的集合之间。而每一个未填的后缀又最多只对4个好集合之间的坏集合有贡献。
我们把所有有贡献的点(对应的是\(O(n)\)种好集合)称之为关键点,按左端点从大往小排序,相同则按右端点从小到大排序。注意到我们的\(dp\)状态只用考虑关键点,所以状态优化到了\(O(n)\)种。然后转移这一维只对右端点的范围有限制,所以可以用树状数组优化,就变成了\(O(n \log n)\)。注意经过坏集合的情况需要单独地转移,同时注意这个dp的起始位置和最终位置也最好作为关键点。
当我们用记录方案的dp构造出了排列后,就使用线段树或析合树等方法计算排列连续段即可。总时间复杂度\(O(n \log n)\)
代码如下:
#include <bits/stdc++.h>
using namespace std;
const int N = 200005;
template <class T>
void read (T &x) {
int sgn = 1;
char ch;
x = 0;
for (ch = getchar(); (ch < '0' || ch > '9') && ch != '-'; ch = getchar()) ;
if (ch == '-') ch = getchar(), sgn = -1;
for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
x *= sgn;
}
template <class T>
void write (T x) {
if (x < 0) putchar('-'), write(-x);
else if (x < 10) putchar(x + '0');
else write(x / 10), putchar(x % 10 + '0');
}
int n, m, p[N], inv[N], mn[N << 2], tot[N << 2], tag[N << 2], cnt = 0;
bool vis[N];
long long ans = 0ll;
void build (int l, int r, int now) {
int mid = l + r >> 1;
mn[now] = tag[now] = 0;
tot[now] = r - l + 1;
if (l < r) build(l, mid, now << 1), build(mid + 1, r, now << 1 | 1);
}
void cover (int now, int val) {
mn[now] += val, tag[now] += val;
}
void pushdown (int now) {
cover(now << 1, tag[now]), cover(now << 1 | 1, tag[now]);
tag[now] = 0;
}
void pushup (int now) {
mn[now] = min(mn[now << 1], mn[now << 1 | 1]);
tot[now] = 0;
if (mn[now] == mn[now << 1]) tot[now] += tot[now << 1];
if (mn[now] == mn[now << 1 | 1]) tot[now] += tot[now << 1 | 1];
}
void change (int left, int right, int l, int r, int now, int val) {
int mid = l + r >> 1;
if (l == left && r == right) cover(now, val);
else {
pushdown(now);
if (right <= mid) change(left, right, l, mid, now << 1, val);
else if (left > mid) change(left, right, mid + 1, r, now << 1 | 1, val);
else change(left, mid, l, mid, now << 1, val), change(mid + 1, right, mid + 1, r, now << 1 | 1, val);
pushup(now);
}
}
int query (int left, int right, int l, int r, int now) {
int mid = l + r >> 1;
if (l == left && r == right) return mn[now] == 1 ? tot[now] : 0;
else {
pushdown(now);
if (right <= mid) return query(left, right, l, mid, now << 1);
else if (left > mid) return query(left, right, mid + 1, r, now << 1 | 1);
else return query(left, mid, l, mid, now << 1) + query(mid + 1, right, mid + 1, r, now << 1 | 1);
}
}
pair<long long, int> bit[N];
int lowbit (int x) {
return x & -x;
}
void init () {
for (int i = 0; i <= n + 1; i++) bit[i] = make_pair(0ll, -1);
}
void add (int pos, pair<long long, int> pi) {
for (int i = pos; i <= n + 1; i += lowbit(i)) bit[i] = max(bit[i], pi);
}
pair<long long, int> ask (int pos) {
pair<long long, int> res(-1ll, -1);
for (; pos; pos ^= lowbit(pos)) res = max(res, bit[pos]);
return res;
}
struct node {
int l, r;
bool operator < (node rhs) const {
if (l > rhs.l) return true;
if (l < rhs.l) return false;
return r < rhs.r;
}
} ;
node seg (int l, int r) {
node i = {l, r};
return i;
}
int rk[N], le[N], ri[N], blo[N];
long long val1[N], val2[N];
set<node> node_set;
vector<node> node_vec;
long long f[N * 5], trans_val1[N * 5], trans_val2[N * 5], val[N * 5];
int prv[N * 5], trans1[N * 5], trans2[N * 5];
void add_seg1 (int l, int r) {
int bl = blo[l], br = blo[r];
if (bl == br) {
if (bl && le[bl] == l) val1[bl - 1] += ri[bl - 1] - le[bl - 1];
if (br < blo[n + 1] && ri[br] == r) val2[br + 1] += ri[br + 1] - le[br + 1];
}
if (bl < br) node_set.insert(seg(bl + 1, br - 1));
if (bl && le[bl] == l) node_set.insert(seg(bl - 1, br - 1));
if (br < blo[n + 1] && ri[br] == r) node_set.insert(seg(bl + 1, br + 1));
if (bl && le[bl] == l && br < blo[n + 1] && ri[br] == r) node_set.insert(seg(bl - 1, br + 1));
}
void add_seg2 (int l, int r) {
int bl = blo[l], br = blo[r], u, v;
if (bl < br) {
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
val[u]++;
}
if (bl && le[bl] == l) {
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
val[u]++;
}
if (br < blo[n + 1] && ri[br] == r) {
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
val[u]++;
}
if (bl && le[bl] == l && br < blo[n + 1] && ri[br] == r) {
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
val[u]++;
}
if (bl < br && bl && le[bl] == l) {
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
trans1[v] = u, trans_val1[v] += ri[bl - 1] - le[bl - 1];
}
if (bl < br && br < blo[n + 1] && ri[br] == r) {
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
trans2[v] = u, trans_val2[v] += ri[br + 1] - le[br + 1];
}
if (bl && br < blo[n + 1] && le[bl] == l && ri[br] == r) {
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
trans1[v] = u, trans_val1[v] += ri[bl - 1] - le[bl - 1];
u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
trans2[v] = u, trans_val2[v] += ri[br + 1] - le[br + 1];
}
}
bool dir[N];
void seg_init () {
vis[0] = vis[n + 1] = true;
rk[0] = le[0] = blo[0] = 0;
for (int i = 1; i <= n + 1; i++) {
rk[i] = rk[i - 1] + vis[i];
if (vis[i] != vis[i - 1]) {
blo[i] = blo[i - 1] + 1;
le[blo[i]] = ri[blo[i]] = i;
}
else blo[i] = blo[i - 1], ri[blo[i]] = i;
}
for (int i = 0; i <= blo[n + 1]; i++) val1[i] = val2[i] = 0ll;
for (int i = m, l = n + 1, r = 0; i >= 1; i--) {
l = min(l, p[i]), r = max(r, p[i]);
if (rk[r] - rk[l] == m - i) add_seg1(l, r);
}
node_set.insert(seg(1, blo[n + 1] - 1));
for (int i = 1; i <= blo[n + 1]; i += 2) node_set.insert(seg(i, i));
for (set<node> :: iterator it = node_set.begin(); it != node_set.end(); it++) node_vec.push_back(*it);
for (int i = 0; i < node_vec.size(); i++) {
f[i] = val[i] = trans_val1[i] = trans_val2[i] = 0ll;
prv[i] = trans1[i] = trans2[i] = -1;
}
for (int i = 1; i <= blo[n + 1]; i += 2) {
int pos = lower_bound(node_vec.begin(), node_vec.end(), seg(i, i)) - node_vec.begin();
if (val1[i] >= val2[i]) dir[i] = false, val[pos] += val1[i];
else dir[i] = true, val[pos] += val2[i];
}
for (int i = m, l = n + 1, r = 0; i >= 1; i--) {
l = min(l, p[i]), r = max(r, p[i]);
if (rk[r] - rk[l] == m - i) add_seg2(l, r);
}
}
void get_dp () {
init();
for (int i = 0; i < node_vec.size(); i++) {
if (~trans1[i] && f[i] < f[trans1[i]] + trans_val1[i]) {
prv[i] = trans1[i];
f[i] = f[trans1[i]] + trans_val1[i];
}
if (~trans2[i] && f[i] < f[trans2[i]] + trans_val2[i]) {
prv[i] = trans2[i];
f[i] = f[trans2[i]] + trans_val2[i];
}
pair<long long, int> pi = ask(node_vec[i].r);
if (f[i] < pi.first) prv[i] = pi.second, f[i] = pi.first;
f[i] += val[i], add(node_vec[i].r, make_pair(f[i], i));
}
}
void construct () {
vector<int> stk;
for (int i = node_vec.size() - 1; ~i; i = prv[i]) stk.push_back(i);
reverse(stk.begin(), stk.end());
int l = le[node_vec[stk[0]].l], r = ri[node_vec[stk[0]].r];
if (blo[l] < blo[r] || dir[blo[l]]) r = l;
else l = r;
cnt = m, p[++cnt] = l;
for (int i = 0; i < stk.size(); i++) {
int L = le[node_vec[stk[i]].l], R = ri[node_vec[stk[i]].r];
for (; l > L; ) {
if (!vis[--l]) p[++cnt] = l;
}
for (; r < R; ) {
if (!vis[++r]) p[++cnt] = r;
}
}
}
int main () {
read(n), read(m);
for (int i = 1; i <= n; i++) vis[i] = false;
for (int i = 1; i <= m; i++) read(p[i]), vis[p[i]] = true;
if (m < n) {
seg_init();
get_dp();
construct();
}
ans = 0ll;
for (int i = 1; i <= n; i++) inv[p[i]] = i;
build(1, n, 1);
for (int i = 1; i <= n; i++) {
change(1, i, 1, n, 1, 1);
if (p[i] > 1 && inv[p[i] - 1] < i) change(1, inv[p[i] - 1], 1, n, 1, -1);
if (p[i] < n && inv[p[i] + 1] < i) change(1, inv[p[i] + 1], 1, n, 1, -1);
ans += query(1, i, 1, n, 1);
}
write(ans), putchar('\n');
for (int i = 1; i <= n; i++) write(p[i]), putchar(' ');
putchar('\n');
return 0;
}
不会正解,咕掉了。
手机扫一扫
移动阅读更方便
你可能感兴趣的文章