【AtCoder AGC023F】01 on Tree(贪心)
阅读原文时间:2023年07月08日阅读:1

Description

给定一颗 \(n\) 个结点的树,每个点有一个点权 \(v\)。点权只可能为 \(0\) 或 \(1\)。

现有一个空数列,每次可以向数列尾部添加一个点 \(i\) 的点权 \(v_i\),但必须保证此时 \(i\) 没有父结点。添加后将 \(i\) 删除。

这样可以一个长为 \(n\) 的数列 \(x\)。求 \(x\) 中逆序对数的最小值。

Hint

  • \(1\le n\le 2\times 10^5\)
  • \(v_i \in \{0, 1\}\)

Solution

由于一个结点的父结点尚未被删除,那么现在该结点则无法被加入数列。可见题目要求我们 从树根自顶向下 删除。

但显然我们不会这样做——我们 将所有结点视作独立,向父亲方向合并


我们不妨先考虑这样一个问题:对于一个根结点为 \(x\) 的树,其子结点为 \(y_1, y_2, \cdots y_k\)。假设子树 \(y_1, y_2, \cdots y_k\) 都已经合并好了,那么我们只要将这些子树合并答案,向上传答案即可。

首先,由题意得,结点 \(x\) 的点权必须排在最前面。接下来就需要合理安排顺序,使得 跨越子树的逆序对 数量最小。由于子树内在前期早已统计完毕,此处无需再做讨论。

为方便讨论,在这里我们还需维护子树中 \(0, 1\) 的个数,分别记为 \(\text{cnt}(\cdots, 0), \text{cnt}(\cdots, 1)\)。

若要使逆序对尽可能小,而权值就只有 \(0, 1\),第一直觉就是 贪心地把 \(0\) 尽量排前面

但直觉是很模糊的,我们需要一个明确的标准。

对于两个子树 \(y_i, y_j\),如果 \(y_i\) 排在前面,那么会产生 \(\text{cnt}(y_i, 1)\times \text{cnt}(y_j, 0)\) 个逆序对,反正则会产生 \(\text{cnt}(y_j, 1)\times \text{cnt}(y_i, 0)\) 个。

显然我们应选择结果较少的策略——优先选取 \(\dfrac{\text{cnt}(y, 0)}{\text{cnt}(y, 1)}\) 较小的。为避免除以零造成 RE,需要化除为乘。


但此题不能直接递归处理,需要全局一起算,即上文中“将所有结点视作独立,向父亲方向合并”的思路。

那么子树的 \(\text{cnt}\) 值就变成了 连通块 的 \(\text{cnt}\) 值,容易发现上面的贪心思路于此仍然有效。

此处涉及连通块整块信息的维护,不难想到 并查集。连通块的有序维护,可以使用

在每个点向上合并后,父亲方向结点需要删去,这对于堆来说就不太方便(当然可以考虑 multiset 或 可删堆

但其实不用这么麻烦:直接根据 \(\text{cnt}\) 值判断是否已经被合并然后选择性跳过即可。

最后做到 1 号点就不用重新插入堆中了。

Code

/*
 * Author : _Wallace_
 * Source : https://www.cnblogs.com/-Wallace-/
 * Problem : AtCoder AGC023F 01 on Tree
 */
#include <algorithm>
#include <iostream>
#include <queue>

using namespace std;
const int N = 2e5 + 5;

int n, fa[N], dsu[N];
int cnt[N][2];

struct item {
    int c0, c1, idx;
    bool operator < (const item& t) const {
        return c0 * 1ll * t.c1 < c1 * 1ll * t.c0;
    }
};
priority_queue<item> pq;

int find(int x) {
    return x == dsu[x] ? x : dsu[x] = find(dsu[x]);
}

signed main() {
    ios::sync_with_stdio(false);

    cin >> n;
    for (register int i = 2; i <= n; i++)
        cin >> fa[i];
    for (register int i = 1, val; i <= n; i++)
        cin >> val, cnt[i][val]++;
    for (register int i = 1; i <= n; i++)
        dsu[i] = i;

    long long ans = 0;
    for (register int i = 2; i <= n; i++)
        pq.push({cnt[i][0], cnt[i][1], i});

    while (!pq.empty()) {
        item cur = pq.top(); pq.pop();
        int x = find(cur.idx), c0 = cur.c0, c1 = cur.c1;

        if (cnt[x][0] != c0 || cnt[x][1] != c1)
            continue;

        int y = find(fa[x]);
        ans += cnt[y][1] * 1ll * cnt[x][0];
        cnt[y][0] += cnt[x][0];
        cnt[y][1] += cnt[x][1];

        dsu[x] = y;
        if (y > 1) pq.push({cnt[y][0], cnt[y][1], y});
    }

    cout << ans << endl;
    return 0;
}