小海是弓道部的成员,非常擅长射箭(Love Arrow Shoot)。今天弓道部的练习是要射一棵树。一棵树是一个nn个点n−1n−1条边的无向图,且这棵树的第ii个点有一个值wiwi,wi∈[1,m]wi∈[1,m]。每一次小海会射中树的一条边,并将这条边移除。此外,小海定义一棵树的las值为∑vi∗i∑vi∗i,vivi为这棵树中第ii小的wiwi。现在小海会告诉你她射中的边的顺序,你需要回答每一次她射中的边所在的树的las值,之后被射中的边会被移除。答案mod998244353mod998244353
第一行两个数n,mn,m
第二行nn个数wiwi
接下来n−1n−1行每行两个数ai,biai,bi,表示初始的树第ii条边连接aiai和bibi。
接下来n−1n−1行每行一个数表示射中的边。
n-1行每行一个数表示射中的边的树的las值
5 4396
2 3 1 4 5
1 2
1 3
2 4
2 5
4
1
2
3
55
30
5
11
数据范围:
前20% n<=1e3
另外20% m<=10
另外20% 保证第i条边连接i和i+1
另外20% n<=1e5
100% n<=5e5 wi<=m<=1e4
下发一个样例满足第一个部分分。
solution
把操作倒过来,看成加边。
对于每个连通块可以用一棵值域线段树维护每个值出现的次数,值的和,还有la值。
合并时,新的la值可以由值的和*之前的值出现的次数得到。
比如 1*1+2*2+3*3 +2*(1+2+3) -> 3*1+4*2+5*3
线段树合并维护即可
注意x,y均没有左右儿子的特殊情况
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 500005
#define mod 998244353
#define ll long long
using namespace std;
int n,m,st[maxn],ed[maxn],w[maxn],id[maxn];
int fa[maxn],root[maxn*18],tot;
int ls[maxn*18],rs[maxn*18],num[maxn*18];
ll ans[maxn],sum[maxn*18],s[maxn*18];
void build(int &k,int l,int r,int pl){
if(!k)k=++tot;
if(l==r){
num[k]=1;sum[k]=s[k]=pl;
return;
}
int mid=l+r>>1;
if(pl<=mid)build(ls[k],l,mid,pl);
else build(rs[k],mid+1,r,pl);
num[k]=num[ls[k]]+num[rs[k]];
sum[k]=sum[ls[k]]+sum[rs[k]];
s[k]=s[ls[k]]+s[rs[k]];
}
int getf(int k){
if(fa[k]==k)return k;
fa[k]=getf(fa[k]);return fa[k];
}
int merge(int x,int y,int la){
if(!x||!y){
return x+y;
}
ls[x]=merge(ls[x],ls[y],la);
rs[x]=merge(rs[x],rs[y],la+num[ls[x]]);
//printf("x:%d y:%d sumx:%d sumy:%d la:%d\n ls:%d %d %d rs:%d %d %d\n",x,y,sum[x],sum[y],la,sum[ls[x]],num[ls[x]],s[ls[x]],sum[rs[x]],num[rs[x]],s[rs[x]]);
if(!ls[x]&&!rs[x]){
sum[x]=sum[x]+sum[y]+(num[x]*s[y])%mod;sum[x]%=mod;
num[x]=num[x]+num[y];s[x]=s[x]+s[y];
return x;
}
sum[x]=sum[ls[x]]+sum[rs[x]]+(num[ls[x]]*s[rs[x]])%mod;sum[x]%=mod;
num[x]=num[ls[x]]+num[rs[x]];
s[x]=s[ls[x]]+s[rs[x]];
//cout<<sum[x]<<' '<<num[x]<<endl;
return x;
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
scanf("%d",&w[i]);
fa[i]=i;build(root[i],1,m,w[i]);
}
for(int i=1;i<n;i++){
scanf("%d%d",&st[i],&ed[i]);
}
for(int i=1;i<n;i++)scanf("%d",&id[i]);
for(int i=n-1;i>=1;i--){
int x=getf(st[id[i]]),y=getf(ed[id[i]]);
//cout<<"---------\n";
//cout<<x<<' '<<y<<endl;
root[x]=merge(root[x],root[y],0);
ans[i]=sum[root[x]];
fa[y]=x;
}
for(int i=1;i<n;i++)printf("%lld\n",ans[i]);
return 0;
}
/*
5 5
2 2 1 4 5
1 2
1 3
2 4
2 5
4
1
2
3
*/
手机扫一扫
移动阅读更方便
你可能感兴趣的文章