CF165D Beard Graph(dfs序+树状数组)
阅读原文时间:2023年07月11日阅读:1

题面

题解

乍一看,单点修改,单链查询,用树链剖分维护每条链上白边的数量就完了,

还是……得写树链剖分吗?……3e5,乘两个log会T吗……

(双手颤抖)

(纠结)

不!绝不写树链剖分!

这题如果能维护每个点到根节点路径上的白边数量,就可以用lca直接算,怎么维护呢

把点按dfs序排序,每个点存它到根节点路径上白边数量,当边的颜色变化时,就把以该边下端点为根的子树内的值整体加一或减一,也就是在按dfs序排序后的序列上做区间修改,然后单点查询

把单点修改、区间查询变成区间修改、单点查询了耶!

然后也可以不用线段树,用差分树状数组,三行解决

把区间修改、单点查询又变成单点修改、区间查询了耶!(滑稽

CODE

#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
#include<algorithm>
#define MAXN 300005
#define MAXM 300005
#define ENDL putchar('\n')
#define LL long long
#define DB double
#define lowbit(x) ((-x)&(x))
//#define int LL
//#pragma GCC optimize(2)
using namespace std;
inline LL read() {
    LL f = 1,x = 0;char s = getchar();
    while(s < '0' || s > '9') {if(s == '-')f = -1;s = getchar();}
    while(s >= '0' && s <= '9') {x = x * 10 + (s - '0');s = getchar();}
    return x * f;
}
const int jzm = 1000000007;
int n,m,i,j,s,o,k;
int u[MAXN],v[MAXN],cl[MAXN];
struct it{
    int v,w;
    it(){v = w = 0;}
    it(int V,int W){v = V;w = W;}
};
vector<int> g[MAXN];
int d[MAXN],dfn[MAXN],rr[MAXN],cnt;
int f[MAXN][20];
int c[MAXN];
void addt(int x,int y) {while(x<=n) c[x] += y,x += lowbit(x);}
int sum(int x) {int as=0;while(x>0) as += c[x],x -= lowbit(x);return as;}
void dfs(int x,int fa) {
    d[x] = d[fa] + 1;
    dfn[x] = ++ cnt;
    f[x][0] = fa;
    for(int i = 1;i <= 18;i ++) f[x][i] = f[f[x][i-1]][i-1];
    for(int i = 0;i < g[x].size();i ++) {
        if(g[x][i] != fa) {
            dfs(g[x][i],x);
        }
    }
    rr[x] = cnt;
    return ;
}
int lca(int a,int b) {
    if(d[b] > d[a]) swap(a,b);
    if(d[a] > d[b]) {
        for(int i = 18;i >= 0;i --) {
            if(d[f[a][i]] >= d[b]) {
                a = f[a][i];
            }
        }
    }
    if(a == b) return a;
    for(int i = 18;i >= 0;i --) {
        if(f[a][i] != f[b][i]) {
            a = f[a][i],b = f[b][i];
        }
    }
    return f[a][0];
}
int main() {
    n = read();
    for(int i = 1;i < n;i ++) {
        s = u[i] = read();
        o = v[i] = read();
        g[s].push_back(o);
        g[o].push_back(s);
    }
    dfs(1,1);
    m = read();
    for(int i = 1;i <= m;i ++) {
        k = read();
        if(k == 1) {
            s = read();
            int p = (d[u[s]] > d[v[s]] ? u[s] : v[s]);
            if(cl[s]) addt(dfn[p],-1),addt(rr[p]+1,1),cl[s] = 0;
        }
        else if(k == 2) {
            s = read();
            int p = (d[u[s]] > d[v[s]] ? u[s] : v[s]);
            if(!cl[s]) addt(dfn[p],1),addt(rr[p]+1,-1),cl[s] = 1;
        }
        else {
            s = read();o = read();
            int lc = lca(s,o);
            if(sum(dfn[s]) + sum(dfn[o]) - 2*sum(dfn[lc])) {
                printf("-1\n");
            }
            else printf("%d\n",d[s] + d[o] - d[lc] * 2);
        }
    }
    return 0;
}