【LGP5439】【XR-2】永恒
阅读原文时间:2023年07月11日阅读:1

题目

是个傻题

显然枚举每一条路径经过了多少次,如果\(u,v\)在树上不是祖先关系的话经过\((u,v)\)这条路径的路径条数就是\(sum_u\times sum_v\)

于是我们子树大小映射到\(\rm Trie\)上去,树形\(\rm dp\)一下就可以求出所有点对产生的贡献了

但是这样祖先关系的节点就算错了,我们发现这也非常好算,\(\rm dfs\)的时候拿\(\rm LCT\)维护一下就好了

代码

#include<bits/stdc++.h>
#define re register
inline int read() {
    char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=3e5+5;
const int mod=998244353;
struct E{int v,nxt;}e[maxn];
inline int qm(int x) {return x>=mod?x-mod:x;}
inline int dqm(int x) {return x<0?x+mod:x;}
int n,m,num,rt,ans,sm[maxn],head[maxn],d[maxn];
char S[maxn];
struct Trie {
    E e[maxn<<1];
    int head[maxn],num,v[maxn],deep[maxn];
    inline void add(int x,int y) {
        e[++num].v=y;e[num].nxt=head[x];head[x]=num;
    }
    void pdfs(int x) {
        for(re int i=head[x];i;i=e[i].nxt) deep[e[i].v]=deep[x]+1,pdfs(e[i].v);
    }
    void dfs(int x,int dep) {
        for(re int i=head[x];i;i=e[i].nxt) {
            dfs(e[i].v,dep+1);
            ans=qm(ans+1ll*dep*v[x]%mod*v[e[i].v]%mod);
            v[x]=qm(v[x]+v[e[i].v]);
        }
    }
}T;
struct LinkCutTree {
    int fa[maxn],ch[maxn][2],rev[maxn],tag[maxn],st[maxn],top,sum[maxn],a[maxn],sz[maxn];
    inline int nrt(int x) {return ch[fa[x]][1]==x||ch[fa[x]][0]==x;}
    inline void pushup(int x) {
        sz[x]=1+sz[ch[x][0]]+sz[ch[x][1]];sum[x]=a[x];
        if(ch[x][0]) sum[x]=qm(sum[x]+sum[ch[x][0]]);
        if(ch[x][1]) sum[x]=qm(sum[x]+sum[ch[x][1]]);
    }
    inline void work(int x,int v) {
        a[x]=qm(a[x]+v);tag[x]=qm(tag[x]+v);
        sum[x]=qm(sum[x]+1ll*sz[x]*v%mod);
    }
    inline void pushdown(int x) {
        if(tag[x]) {
            if(ch[x][0]) work(ch[x][0],tag[x]);
            if(ch[x][1]) work(ch[x][1],tag[x]);
            tag[x]=0;
        }
        if(rev[x]) {
            rev[x]=0;rev[ch[x][0]]^=1;rev[ch[x][1]]^=1;
            std::swap(ch[ch[x][0]][0],ch[ch[x][0]][1]);
            std::swap(ch[ch[x][1]][0],ch[ch[x][1]][1]);
        }
    }
    inline void rotate(int x) {
        int y=fa[x],z=fa[y],w=ch[y][1]==x,k=ch[x][w^1];
        if(nrt(y)) ch[z][ch[z][1]==y]=x;
        ch[x][w^1]=y,ch[y][w]=k;
        pushup(y),pushup(x);fa[k]=y,fa[y]=x,fa[x]=z;
    }
    inline void splay(int x) {
        int y=x;top=0;st[++top]=x;
        while(nrt(y)) y=fa[y],st[++top]=y;
        while(top) pushdown(st[top--]);
        while(nrt(x)) {
            int y=fa[x];
            if(nrt(y)) rotate((ch[fa[y]][1]==y)^(ch[y][1]==x)?x:y);
            rotate(x);
        }
    }
    inline void access(int x) {
        for(re int y=0;x;y=x,x=fa[x])
            splay(x),ch[x][1]=y,pushup(x);
    }
    inline void mrt(int x) {
        access(x);splay(x);rev[x]^=1;std::swap(ch[x][0],ch[x][1]);
    }
    inline void link(int x,int y) {
        mrt(x);fa[x]=y;T.add(x,y);
    }
    inline void split(int x,int y) {
        mrt(x);access(y);splay(y);
    }
    inline void ins(int x,int y,int v) {
        split(x,y);v=dqm(v);work(y,v);
    }
    inline int query(int x,int y) {
        split(x,y);
        return dqm(sum[y]-a[y]);
    }
}lct;
inline void add(int x,int y) {
    e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void dfs1(int x) {
    sm[x]=1;
    for(re int i=head[x];i;i=e[i].nxt) dfs1(e[i].v),sm[x]+=sm[e[i].v];
}
void dfs2(int x) {
    ans=qm(ans+1ll*sm[x]*lct.query(d[x],1)%mod);
    for(re int i=head[x];i;i=e[i].nxt) {
        lct.ins(1,d[x],n-sm[e[i].v]-sm[x]);
        dfs2(e[i].v);
        lct.ins(1,d[x],sm[x]+sm[e[i].v]-n);
    }
}
int main() {
    n=read(),m=read();
    for(re int x,i=1;i<=n;i++) {
        x=read();if(x) add(x,i);else rt=i;
    }
    for(re int x,i=1;i<=m;i++) {
        x=read();if(x) lct.link(x,i);
    }
    dfs1(rt);scanf("%s",S+1);T.pdfs(1);
    for(re int i=1;i<=n;i++) {
        d[i]=read();
        ans=qm(ans+1ll*sm[i]*T.deep[d[i]]%mod*T.v[d[i]]%mod);
        T.v[d[i]]=qm(T.v[d[i]]+sm[i]);
    }
    T.dfs(1,0);dfs2(rt);printf("%d\n",ans);
    return 0;
}