爱之箭发射(las)
阅读原文时间:2023年07月12日阅读:1

目描述

小海是弓道部的成员,非常擅长射箭(Love Arrow Shoot)。今天弓道部的练习是要射一棵树。一棵树是一个nn个点n−1n−1条边的无向图,且这棵树的第ii个点有一个值wiwi,wi∈[1,m]wi∈[1,m]。每一次小海会射中树的一条边,并将这条边移除。此外,小海定义一棵树的las值为∑vi∗i∑vi∗i,vivi为这棵树中第ii小的wiwi。现在小海会告诉你她射中的边的顺序,你需要回答每一次她射中的边所在的树的las值,之后被射中的边会被移除。答案mod998244353mod998244353

输入

第一行两个数n,mn,m

第二行nn个数wiwi

接下来n−1n−1行每行两个数ai,biai,bi,表示初始的树第ii条边连接aiai和bibi。

接下来n−1n−1行每行一个数表示射中的边。

输出

n-1行每行一个数表示射中的边的树的las值

样例输入

5 4396
2 3 1 4 5
1 2
1 3
2 4
2 5
4
1
2
3

样例输出

55
30
5
11

提示

数据范围

前20% n<=1e3

另外20% m<=10

另外20% 保证第i条边连接i和i+1

另外20% n<=1e5

100% n<=5e5 wi<=m<=1e4

下发一个样例满足第一个部分分。

来源

noip2018模拟-robinliu


solution

把操作倒过来,看成加边。

对于每个连通块可以用一棵值域线段树维护每个值出现的次数,值的和,还有la值。

合并时,新的la值可以由值的和*之前的值出现的次数得到。

比如 1*1+2*2+3*3  +2*(1+2+3) -> 3*1+4*2+5*3

线段树合并维护即可

注意x,y均没有左右儿子的特殊情况

#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 500005
#define mod 998244353
#define ll long long
using namespace std;
int n,m,st[maxn],ed[maxn],w[maxn],id[maxn];
int fa[maxn],root[maxn*18],tot;
int ls[maxn*18],rs[maxn*18],num[maxn*18];
ll ans[maxn],sum[maxn*18],s[maxn*18];
void build(int &k,int l,int r,int pl){
    if(!k)k=++tot;
    if(l==r){
        num[k]=1;sum[k]=s[k]=pl;
        return;
    }
    int mid=l+r>>1;
    if(pl<=mid)build(ls[k],l,mid,pl);
    else build(rs[k],mid+1,r,pl);
    num[k]=num[ls[k]]+num[rs[k]];
    sum[k]=sum[ls[k]]+sum[rs[k]];
    s[k]=s[ls[k]]+s[rs[k]];
}
int getf(int k){
    if(fa[k]==k)return k;
    fa[k]=getf(fa[k]);return fa[k];
}
int merge(int x,int y,int la){

    if(!x||!y){
        return x+y;
    }
    ls[x]=merge(ls[x],ls[y],la);
    rs[x]=merge(rs[x],rs[y],la+num[ls[x]]);
    //printf("x:%d y:%d sumx:%d sumy:%d la:%d\n ls:%d  %d  %d rs:%d %d %d\n",x,y,sum[x],sum[y],la,sum[ls[x]],num[ls[x]],s[ls[x]],sum[rs[x]],num[rs[x]],s[rs[x]]);
    if(!ls[x]&&!rs[x]){
        sum[x]=sum[x]+sum[y]+(num[x]*s[y])%mod;sum[x]%=mod;
        num[x]=num[x]+num[y];s[x]=s[x]+s[y];
        return x;
    }
    sum[x]=sum[ls[x]]+sum[rs[x]]+(num[ls[x]]*s[rs[x]])%mod;sum[x]%=mod;
    num[x]=num[ls[x]]+num[rs[x]];
    s[x]=s[ls[x]]+s[rs[x]];
    //cout<<sum[x]<<' '<<num[x]<<endl;
    return x;
}
int main(){
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        scanf("%d",&w[i]);
        fa[i]=i;build(root[i],1,m,w[i]);
    }
    for(int i=1;i<n;i++){
        scanf("%d%d",&st[i],&ed[i]);
    }
    for(int i=1;i<n;i++)scanf("%d",&id[i]);
    for(int i=n-1;i>=1;i--){
        int x=getf(st[id[i]]),y=getf(ed[id[i]]);

        //cout<<"---------\n";
        //cout<<x<<' '<<y<<endl;
        root[x]=merge(root[x],root[y],0);
        ans[i]=sum[root[x]];
        fa[y]=x;
    }
    for(int i=1;i<n;i++)printf("%lld\n",ans[i]);
    return 0;
}
/*
5 5
2 2 1 4 5
1 2
1 3
2 4
2 5
4
1
2
3
*/