给一棵树 \(T(|T|=n)\),每个点有个权值 \(w_i\),从中选出一个子点集 \(P=\{x\in {\rm node}|x\in T\}\),使得 \(\forall u,v\in P,v\in{u{\rm 's\ subtree}}\) 满足 \(w_v\ge w_u\),求 \(|P|_{\max}\)。
数据范围:\(1\le n\le 2\cdot 10^5\),\(0<w_i\le 10^9\)。
这里是线段树合并维护 \(\rm dp\) 的做法,\(\Theta(n\log n)\) 但是跑得最慢,觉得这个做法屑的可以在评论区里怒斥 \(\rm George1123\) 逊然后点踩走人,这个低佬一等的蒟蒻需要您的打击。
先对 \(w_i\) 离散化,然后想 \(n^2\) 的 \(\rm dp\):考虑到每个节点要选,它的子树中就只能选 \(\ge\) 它权值的。所以令 \(f(u,x)\) 表示在 \(u\) 的子树中,选的节点权重都 \(\ge x\) 的最大满足题目要求子集大小。
对于一个节点 \(u\),考虑它选不选,可以如下转移 \(dp\):
\[x\in [1,n],f(u,x)\max=\sum_{v\in{u{\rm 's\ subtree}}}f(v,x)
\]
\[x\in [1,w_u],f(u,x)\max=\sum_{v\in{u{\rm 's\ subtree}}}f(v,w_u)
\]
考虑到要区间转移,可以线段树维护;考虑到要子树合并转移,可以线段树合并。
写线段树合并的时候要维护三个操作:权值求和 \(\rm merge\)、区间与 \(x\) 取 \(\max\)、单点求值。
然后就可以做了,但是因为线段树合并上区间修改很困难,所以您旁边的巨佬 zhoukangyang 会怒斥着告诉您一个更好的做法:差分。
很明显 \(\forall x\in[1,n-1],u\in T,f(u,x)\ge f(u,x+1)\),所以令 \(g(u,x)=f(u,x)-f(u,x+1)\),然后维护。
三个操作对应的新做法为:继续权值求和 \(\rm merge\)、线段树上二分找分界点然后差分单点修改、区间求和。
然后就很好写了,线段树上二分找分界点是 \(\Theta(\log n)\) 的,具体看代码。
从 \(0\) 开始的下标证明蒟蒻是个 \(\rm crab\)。
#include <bits/stdc++.h>
using namespace std;
//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair((a),(b))
#define x first
#define y second
#define be(a) (a).begin()
#define en(a) (a).end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
#define R(i,a,b) for(int i=(a),I=(b);i<I;i++)
#define L(i,a,b) for(int i=(b)-1,I=(a)-1;i>I;i--)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=2e5;
int n,m,a[N],d[N];
vector<int> e[N];
void adde(int u,int v){e[u].pb(v),e[v].pb(u);}
//SegmentTree
const int T=2e7;
#define mid ((l+r)>>1)
int tn,ls[T],rs[T],mx[T];
int newn(){ls[tn]=rs[tn]=-1;return tn++;}
void pushup(int k){
// mx[k]=mx[ls[k]]+mx[rs[k]]; 这里要考虑到空节点
mx[k]=0;
if(~ls[k]) mx[k]+=mx[ls[k]];
if(~rs[k]) mx[k]+=mx[rs[k]];
}
void add(int&k,int x,int v,int l=0,int r=m){
if(r<=x||x+1<=l) return; if(!~k) k=newn();
if(r-l==1) return mx[k]+=v,void();
add(ls[k],x,v,l,mid),add(rs[k],x,v,mid,r),pushup(k);
}
int find(int k,int v,int l=0,int r=m){
if(!~k) return -1; if(v>mx[k]) return -1; if(r-l==1) return l;
int t; if(~(t=find(rs[k],v,mid,r))) return t;
else return find(ls[k],v-(~rs[k]?mx[rs[k]]:0),l,mid); //这里虽然递归了两条路径,但可以证明至少有一条是 O(1) 的
}
int sum(int k,int x,int l=0,int r=m){
if(r<=x) return 0; if(!~k) return 0; if(x<=l) return mx[k];
return sum(ls[k],x,l,mid)+sum(rs[k],x,mid,r);
}
int merge(int k,int p,int l=0,int r=m){
if(!~k||!~p) return ~k?k:p; if(r-l==1) return mx[k]+=mx[p],k;
ls[k]=merge(ls[k],ls[p],l,mid),rs[k]=merge(rs[k],rs[p],mid,r);
return pushup(k),k;
}
//Tree
int rt[N];
void dfs(int u=0,int fa=-1){
for(int v:e[u])if(v!=fa)
dfs(v,u),rt[u]=merge(rt[u],rt[v]);
int t=sum(rt[u],a[u]),p=find(rt[u],t+1); //批量转移 dp
add(rt[u],a[u],1),add(rt[u],p,-1);
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n;
R(i,0,n) cin>>a[i],d[i]=--a[i];
sort(d,d+n),m=unique(d,d+n)-d;
R(i,0,n) a[i]=lower_bound(d,d+m,a[i])-d; // 是 d+m 不是 d+n,lower_bound 玄学函数要写准
R(i,1,n){
int fa; cin>>fa,--fa;
adde(fa,i);
}
fill(rt,rt+n,-1),dfs();
cout<<sum(rt[0],0)<<'\n';
return 0;
}
祝大家学习愉快!
手机扫一扫
移动阅读更方便
你可能感兴趣的文章