看到换根果断lct啊,然而其实我板子还没有打熟,还不会维护子树信息,于是就挂掉了……
然而正解并不是lct。
其实好像很久很久以前将lca的时候好像讲到过一道换根的题,当时没有听懂。
直接说正解吧:
把dfs序搞出来用线段树维护。
用一个变量记录当前根节点,操作一直接改就行了。
然后是操作三:
分情况讨论,设当前根节点为root,询问的点为a。
如果root不在a的子树内,那么root不会影响a的子树,仍然输出1为根时的子树和就行了。
如果在子树内,
如图,如果要查询1的子树和,那么找到1与root这条链上靠近1的点,整体的和减去这个点的子树和就是了。(感性理解一下。)
那么这个点怎么求呢?只需要对于每个点把与他直接相连的儿子的dfs序塞到一个vector里,upper_bound然后-1就可以了。
操作二:
首先求出以1为根是a,b两点的lca。
与操作三类似,如果root不在lca的子树内,那么root不会影响lca,直接加就行。
如果在子树内,那么找a,b与root的lca中深度较大的那个。
之后就和操作三一样了。
#include
#include
#include
#include
#include
#define LL long long
#define re register
#define co const
#define rec re co
#define inline __attribute((always_inline))
const int LLL=<<|;
char buffer[LLL],*S,*TT;
#define getchar() ((S==TT&&(TT=(S=buffer)+fread(buffer,1,LLL,stdin),S==TT))?EOF:*S++)
using namespace std;
struct edge
{
int u,v,nxt;
#define u(x) ed[x].u
#define v(x) ed[x].v
#define n(x) ed[x].nxt
}ed[];
int first[],num_e;
#define f(x) first[x]
int n,q,w[];
bool pd2=,pd3=;
vector
LL fa[][],dep[];
int dfn[],id[],cnt,L[],R[];LL res;
int siz[],son[],top[];
void dfs1(int x)
{
siz[x]=;
for(int i=f(x);i;i=n(i))
if(v(i)!=fa[x][])
{
fa[v(i)][]=x;dep[v(i)]=dep[x]+;
dfs1(v(i));siz[x]+=siz[v(i)];
if(siz[v(i)]>siz[son[x]])son[x]=v(i);
}
}
void dfs2(int x,int t)
{
top[x]=t;dfn[x]=++cnt;id[cnt]=x;L[x]=cnt;
if(son[x])dfs2(son[x],t),inc[x].push_back(dfn[son[x]]);
for(int i=f(x);i;i=n(i))
if(v(i)!=fa[x][]&&v(i)!=son[x])
dfs2(v(i),v(i)),inc[x].push_back(dfn[v(i)]);
R[x]=cnt;
}
inline int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])swap(x,y);
y=fa[top[y]][];
}
if(dep[x]>dep[y])swap(x,y);
return x;
}
struct xjs_Tree
{
struct tree
{
int l,r;LL sum,la;
#define l(x) tr[x].l
#define r(x) tr[x].r
#define la(x) tr[x].la
#define sum(x) tr[x].sum
#define ls(x) ((x)<<1)
#define rs(x) (ls(x)+1)
}tr[];
void build(rec int x,rec int l,rec int r)
{
l(x)=l,r(x)=r,la(x)=;
if(l==r){sum(x)=w[id[l]];return;}
int mid=(l+r)>>;
build(ls(x),l,mid);
build(rs(x),mid+,r);
sum(x)=sum(ls(x))+sum(rs(x));
}
inline void down(rec int x)
{
if(l(x)==r(x))return;
if(!la(x))return;
sum(ls(x))+=(r(ls(x))-l(ls(x))+)*la(x);
sum(rs(x))+=(r(rs(x))-l(rs(x))+)*la(x);
la(ls(x))+=la(x);la(rs(x))+=la(x);la(x)=;
}
void add(rec int x,rec int l,rec int r,rec LL y)
{
if(l>r)return;
down(x);
if(l(x)>=l&&r(x)<=r)
{
sum(x)+=(r(x)-l(x)+)*y;
la(x)+=y;return;
}
int mid=(l(x)+r(x))>>;
if(l<=mid)add(ls(x),l,r,y);
if(r> mid)add(rs(x),l,r,y);
sum(x)=sum(ls(x))+sum(rs(x));
}
LL ask(rec int x,rec int l,rec int r)
{
if(l>r)return ;
down(x);
if(l(x)>=l&&r(x)<=r)return sum(x);
int mid=(l(x)+r(x))>>;LL ans=;
if(l<=mid)ans+=ask(ls(x),l,r);
if(r> mid)ans+=ask(rs(x),l,r);
return ans;
}
}T;
inline int read();
inline void add(rec int u,rec int v);
signed main()
{
// freopen("S1_1.in","r",stdin);
// freopen("in.txt","r",stdin);
// freopen("1.out","w",stdout);
n=read(),q=read();int tu,tv;
for(re int i=;i<=n;i++)w\[i\]=read();
for(re int i=;i<n;i++)tu=read(),tv=read(),add(tu,tv),add(tv,tu);
re int root=;
dep\[\]=;dfs1(),dfs2(,);T.build(,,n);
re int opt,a,b,c;
for(re int i=;i<=q;i++)
{
opt=read();
if(opt==)a=read(),root=a;
if(opt==)
{
a=read(),b=read(),c=read();
int lca=LCA(a,b);
if(lca==root)T.add(,,n,c);
else if(dfn\[root\]<L\[lca\]||dfn\[root\]>R\[lca\])T.add(,L\[lca\],R\[lca\],c);
else
{
int t1=LCA(a,root),t2=LCA(b,root);
if(dep\[t1\]>dep\[t2\])lca=t1;
else lca=t2;
if(lca==root)T.add(,,n,c);
else if(dfn\[root\]<L\[lca\]||dfn\[root\]>R\[lca\])T.add(,L\[lca\],R\[lca\],c);
else
{
int te=upper\_bound(inc\[lca\].begin(),inc\[lca\].end(),dfn\[root\])-inc\[lca\].begin()-;
lca=inc\[lca\]\[te\];lca=id\[lca\];
T.add(,,L\[lca\]-,c),T.add(,R\[lca\]+,n,c);
}
}
}
if(opt==)
{
a=read();
if(a==root)printf("%lld\\n",T.ask(,,n));
else if(dfn\[root\]<L\[a\]||dfn\[root\]>R\[a\])printf("%lld\\n",T.ask(,L\[a\],R\[a\]));
else
{
int te=upper\_bound(inc\[a\].begin(),inc\[a\].end(),dfn\[root\])-inc\[a\].begin()-;
te=inc\[a\]\[te\];te=id\[te\];
printf("%lld\\n",T.ask(,,L\[te\]-)+T.ask(,R\[te\]+,n));
}
}
}
return ;
}
inline int read()
{
int s=,f=;char a=getchar();
while(a<''||a>''){if(a=='-')f=-;a=getchar();}
while(a>=''&&a<=''){s=s*+a-'';a=getchar();}
return s*f;
}
inline void add(rec int u,rec int v)
{
++num_e;
u(num_e)=u;
v(num_e)=v;
n(num_e)=f(u);
f(u)=num_e;
}
手机扫一扫
移动阅读更方便
你可能感兴趣的文章