P6329 【模板】点分树 | 震波
阅读原文时间:2023年07月08日阅读:1

点分树就是将点分治过程中的重心连成一棵虚树

对点分树子树信息的记录,就是点分治处理每个重心时需要的信息

这样就可以留下点分治的过程,支持多次修改和查询

点分树树高 \(O(log n)\) 且 \(\sum size_x = O(n \log n)\)

可以使用很多暴力的手段

但要注意:点分树和原树唯一的联系是点分树中两点的 \(LCA\) 在原树两点的路径上

\(LCA\) 的祖先和这两点的路径再无关系,容斥时要思考清楚

所以统计路径长一切行为以原树为准

因为要大量求 \(LCA\),所以用欧拉序转 \(RMQ\)

两点的 \(LCA\) 就是 \([\min(first_u,first_v),\max(first_u,first_v)]\) 中 \(dep\) 最小的点

注意:

欧拉序有两种:一个点入栈和出栈时记录,序列长 \(2n\)

一个点入栈记一次,每次回溯都记一次,考虑边数得序列长 \(2n-1\)

求 \(LCA\) 时用第二种

又:\(\text{vector}\) 的 \(\text{size()}\) 返回值为 \(\text{unsigned int}\),比较时将参与比较的元素强转 \(\text{unsigned int}\)

所以用负数比较会挂,这点让我懵逼了很久

#include <cstdio>
#include <iostream>
#include <vector>
#define IN inline
using namespace std;

const int N = 1e5 + 5;
int n, m, h[N], tot, a[N];
struct edge{int to, nxt;}e[N * 2];
IN void add(int x, int y) {e[++tot] = edge{y, h[x]}, h[x] = tot;}

int dep[N], rt, size, used[N], son[N], sz[N], Rt, fa[N];
struct BIT {
    vector <int> c;
    IN void build(int n) {c.resize(n);}
    IN int lowbit(int x) {return x & (-x);}
    IN void add(int x, int v) {for(; x < c.size(); x += lowbit(x)) c[x] += v;}
    IN int query(int x) {
        if (x >= (int)c.size()) x = c.size() - 1;
        int s = 0; for(; x > 0; x -= lowbit(x)) s += c[x]; return s;
    }
}tr[N][2];

int rev[N * 2], st[N], dfc, lg[N * 2], mn[N * 2][21];
void dfs(int x, int dad) {
    st[x] = ++dfc, rev[dfc] = x;
    for(int i = h[x], v; i; i = e[i].nxt) {
        if ((v = e[i].to) == dad) continue;
        dep[v] = dep[x] + 1, dfs(v, x), rev[++dfc] = x;
    }
}
IN int LCA(int x, int y) {
    x = st[x], y = st[y]; if (x > y) swap(x, y);
    int k = lg[y - x + 1];
    if (dep[mn[x][k]] < dep[mn[y - (1 << k) + 1][k]]) return mn[x][k];
    return mn[y - (1 << k) + 1][k];
}
IN int Dis(int x, int y) {return dep[x] + dep[y] - dep[LCA(x, y)] * 2;}

void getrt(int x, int dad) {
    sz[x] = 1, son[x] = 0;
    for(int i = h[x], v; i; i = e[i].nxt) {
        if ((v = e[i].to) == dad || used[v]) continue;
        getrt(v, x), sz[x] += sz[v], son[x] = max(son[x], sz[v]);
    }
    son[x] = max(son[x], size - sz[x]);
    if (son[rt] > son[x]) rt = x;
}
void divide(int x) {
    used[x] = 1, tr[x][0].build(size + 1), tr[x][1].build(size + 2);
    for(int i = h[x], v; i; i = e[i].nxt) {
        if (used[v = e[i].to]) continue;
        rt = 0, size = sz[v], getrt(v, x), fa[rt] = x, divide(rt);
    }
}
void obtain() {
    lg[0] = -1;
    for(int i = 1; i <= dfc; i++) mn[i][0] = rev[i], lg[i] = lg[i >> 1] + 1;
    for(int i = 1; i <= lg[dfc]; i++) {
        for(int j = 1; j + (1 << i) - 1 <= dfc; j++)
            if (dep[mn[j][i - 1]] < dep[mn[j + (1 << i - 1)][i - 1]])
                mn[j][i] = mn[j][i - 1]; else mn[j][i] = mn[j + (1 << i - 1)][i - 1];
    }
    for(int i = 1; i <= n; i++)
        for(int j = i; j; j = fa[j]) {
            tr[j][0].add(Dis(j, i) + 1, a[i]);
            if (fa[j]) tr[j][1].add(Dis(fa[j], i) + 1, a[i]);
        }
}

IN void read(int &x) {
    x = 0; char ch = getchar(); int f = 1;
    for(; !isdigit(ch); f = (ch == '-' ? -1 : f), ch = getchar());
    for(; isdigit(ch); x = (x<<3)+(x<<1)+(ch^48), ch = getchar());
    x *= f;
}
IN int Query(int x, int k) {
    int ans = 0;
    for(int i = x; i; i = fa[i]) {
        ans += tr[i][0].query(k - Dis(i, x) + 1);
        if (fa[i]) ans -= tr[i][1].query(k - Dis(fa[i], x) + 1);
    }
    return ans;
}

int main() {
    read(n), read(m);
    for(int i = 1; i <= n; i++) read(a[i]);
    for(int i = 1, u, v; i < n; i++) read(u), read(v), add(u, v), add(v, u);
    rt = 0, size = n, son[0] = 2e9, getrt(1, 0), Rt = rt, divide(rt), dfs(Rt, 0), obtain();
    for(int op, x, y, lst = 0; m; --m) {
        read(op), read(x), read(y), x ^= lst, y ^= lst;
        if (op) {
            for(int i = x; i; i = fa[i]) {
                tr[i][0].add(Dis(x, i) + 1, y - a[x]);
                if (fa[i]) tr[i][1].add(Dis(x, fa[i]) + 1, y - a[x]);
            }
            a[x] = y;
        }
        else printf("%d\n", lst = Query(x, y));
    }
}