SNOI2020 部分题解
阅读原文时间:2023年07月09日阅读:1

D1T1

画图可以发现,多了一条边过后的图是串并联图。(暂时不确定)

然后我们考虑把问题变成,若生成树包含一条边\(e\),则使生成树权值乘上\(a_e\),否则乘上\(b_e\),求最终的生成树权值之和。我们只需要支持删去度数为\(1\)的点,同时删去和它相连的那条边;删去度数为2的点,把两条边合并为一条边;合并重边三种操作。

对于第一种操作,把答案乘上\(a_e\),并删去即可。对于第二种操作,把和这个点相邻的两条边记作\(e_1,e_2\),其中\(e_1\)连接\(u, v\),\(e_2\)连接\(u, w\)。则删去两条边,连接一条端点为\(v, w\),\(a_e = a_{e_1}a_{e_2}, b_e = a_{e_1}b_{e_2} + a_{e_2}b_{e_1}\)的边。对于第三种操作,把\(e_1, e_2\)合并为\(e\)时,\(a_e = a_{e_1}b_{e_2} + a_{e_2}b_{e_1}, b_e = b_{e_1}b_{e_2}\)。

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int N = 500005;
const long long mod = 998244353ll;

// Reads one optionally signed decimal integer from stdin into x.
// Every character before the first digit or '-' is skipped. If stdin ends
// before a number starts, x is left at 0 instead of looping forever.
template <class T>
void read (T &x) {
    int sgn = 1;
    // getchar() returns int; storing it in an int keeps EOF distinguishable.
    int ch = getchar();
    x = 0;
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') sgn = -1, ch = getchar();
    for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    x *= sgn;
}
// Prints the decimal representation of x to stdout (no trailing newline).
// Negative values are emitted digit by digit without negating x itself, so
// the most negative value of T is printed correctly instead of overflowing.
template <class T>
void write (T x) {
    if (x < 0) {
        putchar('-');
        // x / 10 has a strictly smaller magnitude than x, so negating it is safe.
        if (x <= -10) write(-(x / 10));
        // For negative x, x % 10 lies in [-9, 0] because division truncates toward zero.
        putchar(static_cast<int>('0' - x % 10));
    }
    else if (x < 10) putchar(static_cast<int>(x + '0'));
    else write(x / 10), putchar(static_cast<int>(x % 10 + '0'));
}

// n, m: vertex and edge counts read from input.
// deg[v]: number of distinct neighbours of v among the vertices not yet removed.
int n, m, deg[N];
// Running product of the factors contributed by removed degree-1 vertices.
long long ans = 1ll;
// vis[v] is set once v has been taken out of the graph.
bool vis[N];
// Adjacency lists; entries pointing at removed vertices are filtered through vis.
vector<int> g[N];
// Every (smaller endpoint, larger endpoint) pair that has been created as an edge.
set<pair<int, int> > se;
// Per-edge weights: .first is the factor used when the edge is left out of the
// spanning tree, .second the factor used when it is part of the tree.
map<pair<int, int>, pair<long long, long long> > id;
// Work list of vertices whose degree is at most 2.
queue<int> que;

int main () {
    read(n), read(m);
    for (int i = 1; i <= n; i++) deg[i] = 0;
    for (int i = 0; i < m; i++) {
        int u, v;
        read(u), read(v);
        if (u > v) swap(u, v);
        if (u == v) continue;
        if (se.count(make_pair(u, v))) id[make_pair(u, v)].second = (id[make_pair(u, v)].second + 1ll) % mod;
        else {
            deg[u]++, deg[v]++;
            g[u].push_back(v), g[v].push_back(u);
            se.insert(make_pair(u, v));
            id[make_pair(u, v)].first = id[make_pair(u, v)].second = 1ll;
        }
    }
    for (int i = 1; i <= n; i++) {
        vis[i] = false;
        if (deg[i] <= 2) que.push(i);
    }
    while (!que.empty()) {
        int u = que.front();
        que.pop();
        if (vis[u]) continue;
        vis[u] = true;
        vector<int> adj;
        for (int i = 0; i < g[u].size(); i++) {
            if (!vis[g[u][i]]) adj.push_back(g[u][i]);
        }
        if (adj.size() == 1) {
            deg[u]--;
            if (--deg[adj[0]] <= 2) que.push(adj[0]);
            ans = ans * id[make_pair(min(u, adj[0]), max(u, adj[0]))].second % mod;
        }
        else if (adj.size() == 2) {
            pair<int, int> e0(min(u, adj[0]), max(u, adj[0]));
            pair<int, int> e1(min(u, adj[1]), max(u, adj[1]));
            pair<int, int> e(min(adj[0], adj[1]), max(adj[0], adj[1]));
            pair<long long, long long> pi((id[e0].second * id[e1].first + id[e0].first * id[e1].second) % mod, id[e0].second * id[e1].second % mod);
            deg[u] -= 2;
            if (se.count(e)) {
                if (--deg[adj[0]] <= 2) que.push(adj[0]);
                if (--deg[adj[1]] <= 2) que.push(adj[1]);
                id[e] = make_pair(id[e].first * pi.first % mod, (id[e].first * pi.second + id[e].second * pi.first) % mod);
            }
            else {
                g[adj[0]].push_back(adj[1]);
                g[adj[1]].push_back(adj[0]);
                se.insert(e);
                id[e] = pi;
            }
        }
    }
    write(ans), putchar('\n');
    return 0;
}

D1T2

神奇的找规律题,感觉方向不对就很难找出来。

我们仍然考虑打表,把\(k, n\)比较小的情况打出来(记作\(f_{n, k}\))。然后我们发现固定\(n\),\(f_{n, k}\)的取值很少。再仔细观察,发现取值变化的点恰好是斐波那契数列上的数

这启发我们对于每个数\(i > 1\),找先手第一步最小要取多少,才能保证他获胜。记该最小值为\(a_i\)。然后我们写出这个数列某些前缀:

\(1\)

\(1, 2\)

\(1, 2, 3\)

\(1, 2, 3, 1, 5\)

\(1, 2, 3, 1, 5, 1, 2, 8\)

\(\cdots\)

注意这里我们是找\(1\), \(2\), \(3\), \(5\), \(8\), \(\cdots\)的第一次出现位置为终止的前缀!

然后我们发现第\(i\)个前缀是由第\(i - 1, i - 3, i - 5, \cdots\)个前缀依次拼接,再在末尾接上\(F_i\)得到的。(\(F_0 = F_1 = 1\))

找到了这个规律后,我们就设\(g_{i, j, k}\)表示在第\(i\)个前缀的前\(j\)个数中\(\leq F_k\)的数字的个数。先通过\(dp\)预处理\(j\)就是第\(i\)个前缀的长度的情况,然后每次询问递归下去做即可。(通过类似线段树的证明,可以知道每次询问递归树的点数是\(O(\log^2 N)\)的。)

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int S = 505, T = 100005;

// Reads one optionally signed decimal integer from stdin into x.
// Every character before the first digit or '-' is skipped. If stdin ends
// before a number starts, x is left at 0 instead of looping forever.
template <class T>
void read (T &x) {
    int sgn = 1;
    // getchar() returns int; storing it in an int keeps EOF distinguishable.
    int ch = getchar();
    x = 0;
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') sgn = -1, ch = getchar();
    for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    x *= sgn;
}
// Prints the decimal representation of x to stdout (no trailing newline).
// Negative values are emitted digit by digit without negating x itself, so
// the most negative value of T is printed correctly instead of overflowing.
template <class T>
void write (T x) {
    if (x < 0) {
        putchar('-');
        // x / 10 has a strictly smaller magnitude than x, so negating it is safe.
        if (x <= -10) write(-(x / 10));
        // For negative x, x % 10 lies in [-9, 0] because division truncates toward zero.
        putchar(static_cast<int>('0' - x % 10));
    }
    else if (x < 10) putchar(static_cast<int>(x + '0'));
    else write(x / 10), putchar(static_cast<int>(x % 10 + '0'));
}

// t: number of queries.
// cnt1: index of the first fib[] entry exceeding mx; cnt2: index of the first
// a[] entry exceeding mx (both are set by init()).
int t, cnt1 = 2, cnt2 = 1;
// n[], m[]: the two numbers of every query; mx: largest number over all queries.
// fib[i]: Fibonacci numbers with fib[0] = fib[1] = 1.
// a[i]: 1 plus the sum of a[i - 1], a[i - 3], ... (length of the i-th prefix in the write-up).
// f[i][j]: table built by init() with the same recurrence as a[], where the
// trailing element only counts when i <= j.
long long n[T], m[T], mx, fib[S], a[S], f[S][S];

// Precomputes fib[], a[] and f[][] far enough to cover the largest query value mx.
void init () {
    // Fibonacci numbers up to the first one that exceeds mx.
    fib[0] = fib[1] = 1;
    int idx = 1;
    do {
        ++idx;
        fib[idx] = fib[idx - 1] + fib[idx - 2];
    } while (fib[idx] <= mx);
    cnt1 = idx;

    // Prefix lengths: a[i] = 1 + a[i - 1] + a[i - 3] + ...
    a[1] = 1;
    idx = 1;
    do {
        ++idx;
        long long total = 1;
        for (int prev = idx - 1; prev >= 1; prev -= 2) total += a[prev];
        a[idx] = total;
    } while (a[idx] <= mx);
    cnt2 = idx;

    // Columns are independent of each other, so fill the table column by column.
    for (int col = 1; col <= cnt1; col++) {
        f[1][col] = 1ll;
        for (int row = 2; row <= cnt2; row++) {
            long long total = (row <= col) ? 1ll : 0ll;
            for (int prev = row - 1; prev >= 1; prev -= 2) total += f[prev][col];
            f[row][col] = total;
        }
    }
}

// Evaluates the table entry for the first len elements of the x-th prefix and
// column y. A full prefix (len == a[x]) is answered from f[x][y]; otherwise
// the prefix is split into its parts x - 1, x - 3, ... and each is solved in turn.
long long solve (int x, int y, long long len) {
    if (len == a[x]) return f[x][y];
    long long total = 0ll;
    long long rest = len;
    int part = x - 1;
    while (part >= 1) {
        const long long take = rest < a[part] ? rest : a[part];
        total += solve(part, y, take);
        rest -= a[part];
        if (rest <= 0) return total;
        part -= 2;
    }
    // Anything left after all parts accounts for one more element.
    return rest > 0 ? total + 1 : total;
}

int main () {
    read(t);
    for (int i = 1; i <= t; i++) {
        read(m[i]), read(n[i]);
        mx = max(mx, max(m[i], n[i]));
    }
    init();
    for (int i = 1; i <= t; i++) {
        int cnt = 0;
        for (int j = 1; j <= cnt1; j++) {
            if (fib[j] <= m[i]) cnt = j;
        }
        write(solve(cnt2, cnt, n[i] - 1)), putchar('\n');
    }
    return 0;
}

D1T3

直接暴力线段树维护最大子段和就过了,正解不会。

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int N = 100005;

// Reads one optionally signed decimal integer from stdin into x.
// Every character before the first digit or '-' is skipped. If stdin ends
// before a number starts, x is left at 0 instead of looping forever.
template <class T>
void read (T &x) {
    int sgn = 1;
    // getchar() returns int; storing it in an int keeps EOF distinguishable.
    int ch = getchar();
    x = 0;
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') sgn = -1, ch = getchar();
    for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    x *= sgn;
}
// Prints the decimal representation of x to stdout (no trailing newline).
// Negative values are emitted digit by digit without negating x itself, so
// the most negative value of T is printed correctly instead of overflowing.
template <class T>
void write (T x) {
    if (x < 0) {
        putchar('-');
        // x / 10 has a strictly smaller magnitude than x, so negating it is safe.
        if (x <= -10) write(-(x / 10));
        // For negative x, x % 10 lies in [-9, 0] because division truncates toward zero.
        putchar(static_cast<int>('0' - x % 10));
    }
    else if (x < 10) putchar(static_cast<int>(x + '0'));
    else write(x / 10), putchar(static_cast<int>(x % 10 + '0'));
}

// n: array length; m: number of operations.
int n, m;
// Current array values, 1-indexed.
long long a[N];
// Segment-tree node summarising one interval of a[]:
//   sum - total of the interval
//   pre - best sum of a non-empty prefix
//   suf - best sum of a non-empty suffix
//   val - best sub-segment sum, where the empty segment (sum 0) is allowed
struct node {
    long long sum, pre, suf, val;
} sgt[N << 2];
// Combines the summaries of two adjacent intervals (a on the left, b on the right).
node merge (node a, node b) {
    const long long crossPre = a.sum + b.pre;  // prefix reaching into b
    const long long crossSuf = b.sum + a.suf;  // suffix reaching into a
    const long long bridge = a.suf + b.pre;    // segment spanning the boundary
    const long long inner = a.val < b.val ? b.val : a.val;
    node res;
    res.sum = a.sum + b.sum;
    res.pre = a.pre < crossPre ? crossPre : a.pre;
    res.suf = b.suf < crossSuf ? crossSuf : b.suf;
    res.val = inner < bridge ? bridge : inner;
    return res;
}
// Recomputes node `now` from its two children.
void pushup (int now) {
    const int lc = now << 1;
    const int rc = lc | 1;
    sgt[now] = merge(sgt[lc], sgt[rc]);
}
void build (int l, int r, int now) {
    int mid = l + r >> 1;
    if (l == r) {
        sgt[now].sum = sgt[now].pre = sgt[now].suf = a[mid];
        sgt[now].val = max(a[mid], 0ll);
    }
    else {
        build(l, mid, now << 1), build(mid + 1, r, now << 1 | 1);
        pushup(now);
    }
}

void change (int pos, int l, int r, int now, long long val) {
    int mid = l + r >> 1;
    if (l == r) {
        sgt[now].sum = sgt[now].pre = sgt[now].suf = val;
        sgt[now].val = max(val, 0ll);
    }
    else {
        if (pos <= mid) change(pos, l, mid, now << 1, val);
        else change(pos, mid + 1, r, now << 1 | 1, val);
        pushup(now);
    }
}
// Returns the summary of a[left..right]; [l, r] is the interval covered by `now`.
node query (int left, int right, int l, int r, int now) {
    if (l == left && r == right) return sgt[now];
    const int mid = (l + r) >> 1;
    const int lc = now << 1, rc = now << 1 | 1;
    if (right <= mid) return query(left, right, l, mid, lc);
    if (left > mid) return query(left, right, mid + 1, r, rc);
    // The range straddles mid: split it and combine both halves.
    return merge(query(left, mid, l, mid, lc), query(mid + 1, right, mid + 1, r, rc));
}

int main () {
    read(n), read(m);
    for (int i = 1; i <= n; i++) read(a[i]);
    build(1, n, 1);
    for (int i = 1; i <= m; i++) {
        int ty;
        read(ty);
        if (ty == 0) {
            int l, r, x;
            read(l), read(r), read(x);
            for (int j = l; j <= r; j++) {
                if (a[j] < x) a[j] = x, change(j, 1, n, 1, x);
            }
        }
        else {
            int l, r;
            read(l), read(r);
            node ans = query(l, r, 1, n, 1);
            write(ans.val), putchar('\n');
        }
    }
    return 0;
}

D2T1

我们首先考虑一个\(O(n^2)\)做法,我们把\(A\)和\(B\)的所有长度为\(k\)的串建成trie树。然后相当于在trie树上有\(n - k + 1\)个\(A\)类节点和\(B\)类节点,然后你要将\(A\)类节点和\(B\)类节点匹配,使得匹配的总距离之和除以二最小!

这是一个经典的问题,在深度不一致的情况也一样能做。我们只需考虑一条边\(u,v\)所经过的最小的次数,设\(v\)是\(u\)的儿子,\(v\)的子树中有\(a\)个\(A\)类节点,\(b\)个\(B\)类节点,则容易证明至少经过\(\lvert a - b \rvert\)次。且我们可以归纳地构造出方案。

再考虑如何去优化它。只需把trie树换成把两个字符串拼在一起后构成的后缀树(用sam建立),然后在后缀树上定位\(2(n - k + 1)\)个节点即可。这里定位可以考虑倍增,也可以使用NOI2018你的名字的那种two-pointer的trick。

时间复杂度\(O(n)\)至\(O(n \log n)\)不等,可以获得\(100\)分。

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int N = 150005;

// n: length of each input string; m: the substring length (k in the write-up).
// len[], par[]: longest-string length and suffix link of every automaton state.
// num[]: per-state balance, +1 for each length-m substring of b located there
// and -1 for each one of a; dfs() turns it into subtree sums.
// last: state reached by the whole text inserted so far; cnt: highest state index in use.
int n, m, len[N << 2], par[N << 2], num[N << 2], last = 0, cnt = 0;
// The two input strings.
char a[N], b[N];
// Accumulated answer (halved by main before printing).
long long ans = 0ll;
// Outgoing transitions of every state.
map<char, int> ch[N << 2];

// Appends character c to the suffix automaton (standard online construction).
void extend (char c) {
    int walk = last;
    const int cur = ++cnt;
    len[cur] = len[walk] + 1;
    // Add the new transition along the suffix-link chain while it is missing.
    while (walk != -1 && ch[walk][c] == 0) {
        ch[walk][c] = cur;
        walk = par[walk];
    }
    if (walk < 0) {
        par[cur] = 0;
        last = cur;
        return;
    }
    const int target = ch[walk][c];
    if (len[target] == len[walk] + 1) {
        par[cur] = target;
        last = cur;
        return;
    }
    // target represents strings that are too long: split off a clone.
    const int clone = ++cnt;
    ch[clone] = ch[target];
    len[clone] = len[walk] + 1;
    par[clone] = par[target];
    par[target] = clone;
    par[cur] = clone;
    while (walk != -1 && ch[walk][c] == target) {
        ch[walk][c] = clone;
        walk = par[walk];
    }
    last = cur;
}

vector<int> child[N << 2];
void dfs (int u) {
    for (int i = 0; i < child[u].size(); i++) {
        dfs(child[u][i]);
        num[u] += num[child[u][i]];
    }
    if (u) ans += 1ll * max(num[u], -num[u]) * (min(len[u], m) - min(len[par[u]], m));
}

// Entry point: builds one suffix automaton over reversed b followed by
// reversed a, locates every length-m substring of both strings in it
// (+1 for b, -1 for a), and prints half of the total accumulated by dfs().
int main () {
    scanf("%d%d", &n, &m);
    // "%s" expects char*; the arrays decay to that type, whereas &a / &b have
    // type char (*)[N] and do not match the format.
    scanf("%s%s", a, b);
    len[0] = 0, par[0] = -1;
    for (int i = n - 1; i >= 0; i--) extend(b[i]);
    for (int i = n - 1; i >= 0; i--) extend(a[i]);
    for (int i = 0; i <= cnt; i++) num[i] = 0;
    // Replay the inserted text from the root; tmp is the matched length, capped at m.
    int tmp = 0, now = 0;
    for (int i = n - 1; i >= 0; i--) {
        now = ch[now][b[i]], tmp++;
        if (tmp > m) {
            // Drop the oldest character: move to the parent if it still holds length m.
            if (len[par[now]] >= m) now = par[now];
            tmp = m;
        }
        if (i <= n - m) num[now]++;
    }
    // Continues from the state reached above, matching the order of insertion.
    for (int i = n - 1; i >= 0; i--) {
        now = ch[now][a[i]], tmp++;
        if (tmp > m) {
            if (len[par[now]] >= m) now = par[now];
            tmp = m;
        }
        if (i <= n - m) num[now]--;
    }
    for (int i = 1; i <= cnt; i++) child[par[i]].push_back(i);
    dfs(0);
    ans >>= 1;
    printf("%lld\n", ans);
    return 0;
}

D2T2

这道毒瘤的题目性质太多,细节也太多,我可能难以给出详细的证明和解释,请见谅。

首先我们设\(0, 1, 2, …, n, n + 1\)中已经被填的数的集合为\(A\)(约定\(0, n + 1 \in A\)),没有被填的数的集合为\(B\)。我们把\(0, 1, …, n, n + 1\)按照已填和未填分段,设为\(A_0, B_1, A_1, B_2, …, B_m, A_m\)。若\(1 \leq l \leq r \leq k\),则区间内部的数已经固定,我们无需考虑。我们只需考虑\(k < l \leq r \leq n,l \leq k < r \leq n\)的部分。为了同时让这两部分最大化,我们猜测有如下结论

存在一个最优解满足:

1.对于\(B_1, B_2, …, B_m\),它在排列中一定构成递增或递减的连续的一段

2.对于每个\(k < i \leq n\),\(p_{k + 1}, p_{k + 2}, …, p_i\)必定构成\(B\)的连续的一段。(例如\(B = \{2, 4, 6, 8\}\),则\(4,6,2\)算连续的一段,\(2,4,8\)不算)

有了第二条,我们可以统计未填数的每一种前缀填法对答案的贡献是多少,然后做一个\(O(n^2)\)的区间dp。然而毒瘤的出题人有一个神仙的做法,没有考虑到有这种辣鸡想法,所以就没有给\(O(n^2)\)的部分分

为了优化这两个dp,我们需要考虑性质\(1\)了。若\(\{ p_{k + 1}, p_{k + 2}, …, p_i \}\)满足\(B_1, …, B_m\)要么全在里面,要么全不在,则称这个集合是好的,反之称为坏的集合。结果我们会发现,每一个未填的后缀最多只对4个好的集合有贡献。

再考虑坏的集合必定恰好夹在两个好的集合之间。而每一个未填的后缀又最多只对4个好集合之间的坏集合有贡献。

我们把所有有贡献的点(对应的是\(O(n)\)种好集合)称之为关键点,按左端点从大往小排序,相同则按右端点从小到大排序。注意到我们的\(dp\)状态只用考虑关键点,所以状态优化到了\(O(n)\)种。然后转移这一维只对右端点的范围有限制,所以可以用树状数组优化,就变成了\(O(n \log n)\)。注意经过坏集合的情况需要单独地转移,同时注意这个dp的起始位置和最终位置也最好作为关键点

当我们用记录方案的dp构造出了排列后,就使用线段树或析合树等方法计算排列连续段即可。总时间复杂度\(O(n \log n)\)

代码如下:

#include <bits/stdc++.h>
using namespace std;

const int N = 200005;

// Reads one optionally signed decimal integer from stdin into x.
// Every character before the first digit or '-' is skipped. If stdin ends
// before a number starts, x is left at 0 instead of looping forever.
template <class T>
void read (T &x) {
    int sgn = 1;
    // getchar() returns int; storing it in an int keeps EOF distinguishable.
    int ch = getchar();
    x = 0;
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') sgn = -1, ch = getchar();
    for (; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    x *= sgn;
}
// Prints the decimal representation of x to stdout (no trailing newline).
// Negative values are emitted digit by digit without negating x itself, so
// the most negative value of T is printed correctly instead of overflowing.
template <class T>
void write (T x) {
    if (x < 0) {
        putchar('-');
        // x / 10 has a strictly smaller magnitude than x, so negating it is safe.
        if (x <= -10) write(-(x / 10));
        // For negative x, x % 10 lies in [-9, 0] because division truncates toward zero.
        putchar(static_cast<int>('0' - x % 10));
    }
    else if (x < 10) putchar(static_cast<int>(x + '0'));
    else write(x / 10), putchar(static_cast<int>(x % 10 + '0'));
}

// n: permutation length; m: number of leading entries given in the input.
// p[]: the permutation (positions 1..m are read, the rest is filled by construct()).
// inv[]: position of every value in p.
// mn / tot / tag: segment tree holding the range minimum, how many positions
// attain it, and the pending lazy addition.
// cnt: write cursor into p used by construct().
int n, m, p[N], inv[N], mn[N << 2], tot[N << 2], tag[N << 2], cnt = 0;
// vis[v]: value v is already placed (seg_init() also marks 0 and n + 1).
bool vis[N];
// Total accumulated by main from the segment-tree queries.
long long ans = 0ll;

// Initialises the subtree rooted at `now` for the interval [l, r]: every
// minimum and tag is 0 and every position attains the minimum.
void build (int l, int r, int now) {
    tag[now] = 0;
    mn[now] = 0;
    tot[now] = r - l + 1;
    if (l >= r) return;
    const int mid = (l + r) >> 1;
    build(l, mid, now << 1);
    build(mid + 1, r, now << 1 | 1);
}
// Adds val to the whole interval of node `now` and records it as pending for its children.
void cover (int now, int val) {
    tag[now] += val;
    mn[now] += val;
}
// Hands the pending addition of `now` to both children and clears it.
void pushdown (int now) {
    const int pending = tag[now];
    cover(now << 1, pending);
    cover(now << 1 | 1, pending);
    tag[now] = 0;
}
// Recomputes the minimum of `now` and how many positions attain it from the children.
void pushup (int now) {
    const int lc = now << 1;
    const int rc = lc | 1;
    const int best = min(mn[lc], mn[rc]);
    int ways = 0;
    if (mn[lc] == best) ways += tot[lc];
    if (mn[rc] == best) ways += tot[rc];
    mn[now] = best;
    tot[now] = ways;
}

// Range update: adds val to every position in [left, right]; [l, r] is the
// interval covered by `now`.
void change (int left, int right, int l, int r, int now, int val) {
    if (l == left && r == right) {
        cover(now, val);
        return;
    }
    const int mid = (l + r) >> 1;
    const int lc = now << 1, rc = now << 1 | 1;
    pushdown(now);
    if (right <= mid) {
        change(left, right, l, mid, lc, val);
    }
    else if (left > mid) {
        change(left, right, mid + 1, r, rc, val);
    }
    else {
        change(left, mid, l, mid, lc, val);
        change(mid + 1, right, mid + 1, r, rc, val);
    }
    pushup(now);
}
// Sums, over the tree nodes that exactly tile [left, right], the stored count
// of each node whose minimum equals 1; [l, r] is the interval covered by `now`.
int query (int left, int right, int l, int r, int now) {
    if (l == left && r == right) {
        if (mn[now] != 1) return 0;
        return tot[now];
    }
    const int mid = (l + r) >> 1;
    const int lc = now << 1, rc = now << 1 | 1;
    pushdown(now);
    if (right <= mid) return query(left, right, l, mid, lc);
    if (left > mid) return query(left, right, mid + 1, r, rc);
    return query(left, mid, l, mid, lc) + query(mid + 1, right, mid + 1, r, rc);
}

// Fenwick tree of prefix maxima; each entry is (value, index attached to that value).
pair<long long, int> bit[N];
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit (int x) {
    const int negated = -x;
    return negated & x;
}
void init () {
    for (int i = 0; i <= n + 1; i++) bit[i] = make_pair(0ll, -1);
}
void add (int pos, pair<long long, int> pi) {
    for (int i = pos; i <= n + 1; i += lowbit(i)) bit[i] = max(bit[i], pi);
}
// Returns the largest entry stored for positions 1..pos, or (-1, -1) when pos is 0.
pair<long long, int> ask (int pos) {
    pair<long long, int> best(-1ll, -1);
    while (pos != 0) {
        if (best < bit[pos]) best = bit[pos];
        pos -= lowbit(pos);
    }
    return best;
}

// A pair of indices (l, r), ordered by l descending and then r ascending.
struct node {
    int l, r;
    bool operator < (node rhs) const {
        if (l != rhs.l) return l > rhs.l;
        return r < rhs.r;
    }
} ;
// Convenience constructor for a node holding (l, r).
node seg (int l, int r) {
    node res;
    res.l = l;
    res.r = r;
    return res;
}

// rk[v]: how many values in 1..v are marked in vis (rk[0] = 0).
// blo[v]: index of the maximal run of equally-marked values that contains v;
// le[b] / ri[b]: first and last value of run b.
int rk[N], le[N], ri[N], blo[N];
// val1[b] / val2[b]: two competing gains accumulated per run by add_seg1();
// seg_init() keeps the larger one and records the choice in dir[b].
long long val1[N], val2[N];
// Candidate DP states: collected without duplicates in node_set, then copied
// in sorted order into node_vec.
set<node> node_set;
vector<node> node_vec;
// The arrays below are indexed by position in node_vec.
// f: best DP value of the state; val: gain added when the state is taken;
// trans_val1 / trans_val2: extra gain of the two special transitions.
long long f[N * 5], trans_val1[N * 5], trans_val2[N * 5], val[N * 5];
// prv: predecessor chosen by get_dp(); trans1 / trans2: source state of the
// two special transitions. -1 means "none" in all three.
int prv[N * 5], trans1[N * 5], trans2[N * 5];
// First pass for one value range [l, r] supplied by seg_init(): registers the
// DP states this range can contribute to, and updates val1 / val2 when both
// ends of the range lie in the same run.
void add_seg1 (int l, int r) {
    int bl = blo[l], br = blo[r];
    if (bl == br) {
        // The range starts exactly at its run: credit the run just below via val1.
        if (bl && le[bl] == l) val1[bl - 1] += ri[bl - 1] - le[bl - 1];
        // The range ends exactly at its run: credit the run just above via val2.
        if (br < blo[n + 1] && ri[br] == r) val2[br + 1] += ri[br + 1] - le[br + 1];
    }
    // Runs strictly between the two ends.
    if (bl < br) node_set.insert(seg(bl + 1, br - 1));
    // The same state extended by one run on the left, on the right, or on both
    // sides, each only when the range touches that end of its run.
    if (bl && le[bl] == l) node_set.insert(seg(bl - 1, br - 1));
    if (br < blo[n + 1] && ri[br] == r) node_set.insert(seg(bl + 1, br + 1));
    if (bl && le[bl] == l && br < blo[n + 1] && ri[br] == r) node_set.insert(seg(bl - 1, br + 1));
}
// Second pass for one value range [l, r], run once node_vec is final: looks up
// the states registered by add_seg1(), adds 1 to the gain of each, and records
// the special transitions between neighbouring states.
void add_seg2 (int l, int r) {
    int bl = blo[l], br = blo[r], u, v;
    // Each of the (up to four) states this range maps to gains one.
    if (bl < br) {
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
        val[u]++;
    }
    if (bl && le[bl] == l) {
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
        val[u]++;
    }
    if (br < blo[n + 1] && ri[br] == r) {
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
        val[u]++;
    }
    if (bl && le[bl] == l && br < blo[n + 1] && ri[br] == r) {
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
        val[u]++;
    }
    // Transition that extends the inner state by the run on its left.
    if (bl < br && bl && le[bl] == l) {
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
        v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
        trans1[v] = u, trans_val1[v] += ri[bl - 1] - le[bl - 1];
    }
    // Transition that extends the inner state by the run on its right.
    if (bl < br && br < blo[n + 1] && ri[br] == r) {
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br - 1)) - node_vec.begin();
        v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
        trans2[v] = u, trans_val2[v] += ri[br + 1] - le[br + 1];
    }
    // Both ends touch: the doubly extended state can be reached from either
    // singly extended state.
    if (bl && br < blo[n + 1] && le[bl] == l && ri[br] == r) {
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl + 1, br + 1)) - node_vec.begin();
        v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
        trans1[v] = u, trans_val1[v] += ri[bl - 1] - le[bl - 1];
        u = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br - 1)) - node_vec.begin();
        v = lower_bound(node_vec.begin(), node_vec.end(), seg(bl - 1, br + 1)) - node_vec.begin();
        trans2[v] = u, trans_val2[v] += ri[br + 1] - le[br + 1];
    }
}

// dir[b]: which of val1[b] / val2[b] won for run b (false = val1, true = val2).
bool dir[N];
// Splits 0..n+1 into maximal runs of placed / unplaced values, collects the DP
// states, and fills their gains and special transitions.
void seg_init () {
    // Sentinels: 0 and n + 1 count as placed.
    vis[0] = vis[n + 1] = true;
    rk[0] = le[0] = blo[0] = 0;
    for (int i = 1; i <= n + 1; i++) {
        rk[i] = rk[i - 1] + vis[i];
        if (vis[i] != vis[i - 1]) {
            // The marking flips here, so a new run starts at i.
            blo[i] = blo[i - 1] + 1;
            le[blo[i]] = ri[blo[i]] = i;
        }
        else blo[i] = blo[i - 1], ri[blo[i]] = i;
    }

    for (int i = 0; i <= blo[n + 1]; i++) val1[i] = val2[i] = 0ll;
    // Walk the suffixes p[i..m] of the given entries; [l, r] is their value range.
    for (int i = m, l = n + 1, r = 0; i >= 1; i--) {
        l = min(l, p[i]), r = max(r, p[i]);
        // Only ranges whose placed values are exactly this suffix are used.
        if (rk[r] - rk[l] == m - i) add_seg1(l, r);
    }
    // Always present: the state spanning all runs from 1 to the last but one,
    // and one state per odd-numbered run.
    node_set.insert(seg(1, blo[n + 1] - 1));
    for (int i = 1; i <= blo[n + 1]; i += 2) node_set.insert(seg(i, i));
    for (set<node> :: iterator it = node_set.begin(); it != node_set.end(); it++) node_vec.push_back(*it);
    for (int i = 0; i < node_vec.size(); i++) {
        f[i] = val[i] = trans_val1[i] = trans_val2[i] = 0ll;
        prv[i] = trans1[i] = trans2[i] = -1;
    }
    // Single-run states start with the better of the two gains of their run.
    for (int i = 1; i <= blo[n + 1]; i += 2) {
        int pos = lower_bound(node_vec.begin(), node_vec.end(), seg(i, i)) - node_vec.begin();
        if (val1[i] >= val2[i]) dir[i] = false, val[pos] += val1[i];
        else dir[i] = true, val[pos] += val2[i];
    }
    // Same suffix walk again, now that every state has an index in node_vec.
    for (int i = m, l = n + 1, r = 0; i >= 1; i--) {
        l = min(l, p[i]), r = max(r, p[i]);
        if (rk[r] - rk[l] == m - i) add_seg2(l, r);
    }
}
// Runs the DP over the states in node_vec order. Each state takes the best of
// its two special transitions and the best earlier state whose r is not larger
// (found through the Fenwick tree), then adds its own gain.
void get_dp () {
    init();
    for (int i = 0; i < node_vec.size(); i++) {
        if (~trans1[i] && f[i] < f[trans1[i]] + trans_val1[i]) {
            prv[i] = trans1[i];
            f[i] = f[trans1[i]] + trans_val1[i];
        }
        if (~trans2[i] && f[i] < f[trans2[i]] + trans_val2[i]) {
            prv[i] = trans2[i];
            f[i] = f[trans2[i]] + trans_val2[i];
        }
        // Ordinary transition: best value among states already inserted with r <= this r.
        pair<long long, int> pi = ask(node_vec[i].r);
        if (f[i] < pi.first) prv[i] = pi.second, f[i] = pi.first;
        // Publish this state so that later states can extend it.
        f[i] += val[i], add(node_vec[i].r, make_pair(f[i], i));
    }
}
// Rebuilds the chosen chain of states from prv[] and writes the unplaced
// values into p[m + 1..] in the order the chain covers them.
void construct () {
    // Follow prv from the last state back to the start, then reverse the chain.
    vector<int> stk;
    for (int i = node_vec.size() - 1; ~i; i = prv[i]) stk.push_back(i);
    reverse(stk.begin(), stk.end());
    // Pick the starting value inside the first state: its left end if the state
    // spans several runs or dir says so, otherwise its right end.
    int l = le[node_vec[stk[0]].l], r = ri[node_vec[stk[0]].r];
    if (blo[l] < blo[r] || dir[blo[l]]) r = l;
    else l = r;
    cnt = m, p[++cnt] = l;
    for (int i = 0; i < stk.size(); i++) {
        // Grow [l, r] to the value range of the next state, first to the left
        // and then to the right, emitting every value that is not yet placed.
        int L = le[node_vec[stk[i]].l], R = ri[node_vec[stk[i]].r];
        for (; l > L; ) {
            if (!vis[--l]) p[++cnt] = l;
        }
        for (; r < R; ) {
            if (!vis[++r]) p[++cnt] = r;
        }
    }
}

// Entry point: reads the given entries, completes the permutation when some
// are missing, accumulates the segment-tree counts over every right end, and
// prints that total followed by the permutation.
int main () {
    read(n), read(m);
    for (int v = 1; v <= n; v++) vis[v] = false;
    for (int pos = 1; pos <= m; pos++) {
        read(p[pos]);
        vis[p[pos]] = true;
    }
    if (m < n) {
        seg_init();
        get_dp();
        construct();
    }
    ans = 0ll;
    for (int pos = 1; pos <= n; pos++) inv[p[pos]] = pos;
    build(1, n, 1);
    for (int right = 1; right <= n; right++) {
        const int value = p[right];
        change(1, right, 1, n, 1, 1);
        // A neighbouring value that already appeared lowers the count for
        // every left end up to its position.
        if (value > 1) {
            const int where = inv[value - 1];
            if (where < right) change(1, where, 1, n, 1, -1);
        }
        if (value < n) {
            const int where = inv[value + 1];
            if (where < right) change(1, where, 1, n, 1, -1);
        }
        ans += query(1, right, 1, n, 1);
    }
    write(ans);
    putchar('\n');
    for (int pos = 1; pos <= n; pos++) {
        write(p[pos]);
        putchar(' ');
    }
    putchar('\n');
    return 0;
}

D2T3

不会正解,咕掉了。