HDU 5877 2016大连网络赛 Weak Pair(树状数组,线段树,动态开点,启发式合并,可持久化线段树)
阅读原文时间:2024年09月16日阅读:1

Weak Pair

**Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)**

**Total Submission(s): 1468    Accepted Submission(s): 472**

Problem Description

You are given a rooted tree of $N$ nodes, labeled from $1$ to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u,v)$ is said to be weak if

  (1) $u$ is an ancestor of $v$ (note: in this problem a node $u$ is not considered an ancestor of itself);

  (2) $a_u \times a_v \le k$.

Can you find the number of weak pairs in the tree?

Input

There are multiple cases in the data set.

  The first line of input contains an integer $T$ denoting the number of test cases.

  For each case, the first line contains two space-separated integers, N and k,
respectively.

  The second line contains N space-separated
integers, denoting a1 to aN.

  Each of the subsequent $N-1$ lines contains two space-separated integers $u$ and $v$ defining an edge, where node $u$ is the parent of node $v$.

  Constraints:

  $1 \le N \le 10^5$

  $0 \le a_i \le 10^9$

  $0 \le k \le 10^{18}$

Output

For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.

Sample Input

1
2 3
1 2
1 2

Sample Output

1

这是一道很好的数据结构的题目:

可以用很多方法写

首先思路是:dfs这棵树,每到一个节点,都计算这个节点的祖先中满足条件的有几个。

而计算这个就需要维护一个序列,并且高效的得出多少个祖先满足条件。

即在序列中找到小于等于k/a[i](整数除法)的数有多少个(a[i]=0时所有祖先都满足条件),很容易想到用树状数组和线段树。

权值最大为1e9(k/a[i]最大可达1e18),需要离散化。

树状数组:

#include
#include
#include
#include
#include
#include
#include

using namespace std;
const int maxn=1e5;
typedef long long int LL;
int n;                 // number of nodes in the current test case
LL k;                  // weak-pair threshold
LL a[maxn+5];          // node values a_1..a_n
struct Node
{
    int value;         // child endpoint of the edge
    int next;          // index of the next edge from the same parent, -1 terminates
}edge[maxn*2+5];
int head[maxn+5];      // first edge index of each node, -1 when it has none
int vis[maxn+5];       // visited marks used by dfs
int tot;               // number of edges stored so far
int c[maxn*2+5];       // Fenwick tree over compressed ranks 1..2n
LL b[maxn+5];          // per-node query bound (k / a[i]), filled by main
LL e[maxn*2+5];        // scratch array of all values that get compressed
map<LL,int> m;         // value -> compressed rank (1-based)

// Append the directed edge x -> y to the adjacency list.
void add(int x,int y)
{
    edge[tot].value=y;
    edge[tot].next=head[x];
    head[x]=tot++;
}
// Lowest set bit of x.
int lowbit(int x)
{
    return x&(-x);
}
// Add num at position x of the Fenwick tree (valid positions are 1..2n).
// x <= 0 is ignored: lowbit(0) == 0, so the loop below would never advance, and
// a missing key looked up through m[...] yields rank 0.
void update(int x,int num)
{
    if(x<=0) return;
    while(x<=n*2)
    {
        c[x]+=num;
        x+=lowbit(x);
    }
}
// Prefix sum over positions 1..x (0 for x <= 0).
int sum(int x)
{
    int _sum=0;
    while(x>0)
    {
        _sum+=c[x];
        x-=lowbit(x);
    }
    return _sum;
}
LL ans;
// Walk the tree below `root`. The Fenwick tree holds the values on the current
// root-to-node path (main inserts the root before calling dfs). For each unvisited
// child: add the number of path values whose rank is <= rank(b[child]), then keep
// the child's own value inserted only while its subtree is being explored.
void dfs(int root)
{
    vis[root]=1;
    int i=head[root];
    while(i!=-1)
    {
        const int child=edge[i].value;
        if(vis[child]==0)
        {
            ans+=sum(m[b[child]]);
            const int pos=m[a[child]];
            update(pos,1);
            dfs(child);
            update(pos,-1);
        }
        i=edge[i].next;
    }
}

// Reset the per-test-case state: edge counter, adjacency heads, visited marks
// and Fenwick counts.
void init()
{
    tot=0;
    memset(head,-1,sizeof(head));
    memset(vis,0,sizeof(vis));
    memset(c,0,sizeof(c));
}
int tag[maxn+5];
int main()
{
int t;
scanf("%d",&t);
int x,y;
while(t--)
{
scanf("%d%lld",&n,&k);
init();
int cnt=n;
m.clear();
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
e[i]=a[i];
if(a[i]==0)
m[a[i]]=2*n;
else
{
b[i]=k/a[i];
e[++cnt]=b[i];
}
}
sort(e+1,e+cnt+1);
int tot=1;
for(int i=1;i<=cnt;i++)
{
if(!m.count(e[i]))
m[e[i]]=tot++;
}
memset(tag,0,sizeof(tag));
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
tag[y]++;
}
int root;
for(int i=1;i<=n;i++)
{
if(tag[i]==0)
root=i;
}
ans=0;
update(m[a[root]],1);
dfs(root);
printf("%lld\n",ans);
}
return 0;
}

线段树:

#include   
#include   
#include   
#include   
#include   
#include   
#include   
#include   
#include 

using namespace std;
typedef long long int LL;
const int maxn=1e5;
vector<int> v[maxn+5];   // children lists (v[root] is iterated by dfs)
int sum[maxn*8+5];       // segment-tree counts over compressed ranks
int n;
LL k;
LL a[maxn+5];
LL b[maxn+5];
LL e[maxn*2+5];
int aa[maxn+5];
int bb[maxn+6];
map<LL,int> m;           // value -> compressed rank

// Recompute an internal node from its two children.
void PushUp(int node)
{
    sum[node]=sum[node<<1]+sum[node<<1|1];
}
// Add num at leaf position ind in the tree rooted at node covering [begin,end].
void update(int node,int begin,int end,int ind,int num)
{
    if(begin==end)
    {
        sum[node]+=num;   // a leaf covers exactly one position
        return;
    }
    int mid=(begin+end)>>1;
    if(ind<=mid)
        update(node<<1,begin,mid,ind,num);
    else
        update(node<<1|1,mid+1,end,ind,num);
    PushUp(node);
}
// Total count of positions in [left,right] inside the tree rooted at node
// covering [begin,end]. Read-only: nothing is recomputed here.
LL Query(int node,int begin,int end,int left,int right)
{
    if(left<=begin&&end<=right)
        return sum[node];
    int mid=(begin+end)>>1;
    LL ret=0;
    if(left<=mid)
        ret+=Query(node<<1,begin,mid,left,right);
    if(right>mid)
        ret+=Query(node<<1|1,mid+1,end,left,right);
    return ret;
}
// NOTE(review): the remainder of this solution (the body of dfs() and all of
// main()) was lost when the article was extracted; only the opening lines below
// survive. dfs() presumably mirrors the Fenwick-tree version (count ancestors whose
// rank is <= rank(b[child]), insert the child, recurse, remove it) -- confirm
// against the original post before relying on it.
int tag\[maxn+5\];  
LL ans;  
void dfs(int root)  
{  
    int len=v\[root\].size();  
    for(int i=0;i

#include   
#include
#include
#include
#include
#include
#include
#include

using namespace std;
const int maxn=1e5;
// The dynamic segment tree covers the value domain [0, len]. Query bounds k / a_i
// can reach 1e18, so values are indexed directly instead of being compressed.
const long long int len=1e18;
typedef long long int LL;
LL a[maxn+5];            // node values
LL b[maxn+5];            // per-node query bound, filled by main
int n;
LL k;
vector<int> v[maxn+5];   // v[x] = children of node x
// Node of the dynamically allocated segment tree. lch/rch are indices into tr;
// -1 means the child has not been allocated (an empty subtree).
struct Node
{
    int lch,rch;
    LL sum;   // number of stored values inside this node's range
    Node() {}
    Node(int lch,int rch,LL sum)
    {
        this->lch=lch;
        this->rch=rch;
        this->sum=sum;
    }
}tr[maxn*100+5];
int p;   // index of the most recently allocated node
// Recompute a node's count from its two children (both must be allocated).
void PushUp(int node)
{
    tr[node].sum=tr[tr[node].lch].sum+tr[tr[node].rch].sum;
}

// Allocate a fresh, empty node and return its index.
int newnode()
{
    tr[++p]=Node(-1,-1,0);
    return p;
}
// Add num to the count of value ind inside subtree `node` covering [begin,end].
// Both children are allocated before recursing, so PushUp never reads index -1.
void update(int node,LL begin,LL end,LL ind,int num)
{
    if(begin==end)
    {
        tr[node].sum+=num;
        return;
    }
    LL m=(begin+end)>>1;
    if(tr[node].lch==-1) tr[node].lch=newnode();
    if(tr[node].rch==-1) tr[node].rch=newnode();
    if(ind<=m) update(tr[node].lch,begin,m,ind,num);
    else update(tr[node].rch,m+1,end,ind,num);
    PushUp(node);
}
// Number of stored values in [left,right] inside subtree `node` covering [begin,end].
// Deliberately read-only: a partially covered node may still have unallocated (-1)
// children, so calling PushUp here would read tr[-1] and overwrite the node's count.
LL query(int node,LL begin,LL end,LL left,LL right)
{
    if(node==-1) return 0;
    if(left<=begin&&end<=right) return tr[node].sum;
    LL m=(begin+end)>>1;
    LL ret=0;
    if(left<=m) ret+=query(tr[node].lch,begin,m,left,right);
    if(right>m) ret+=query(tr[node].rch,m+1,end,left,right);
    return ret;
}
int tag[maxn+5];
LL ans;
void dfs(int root)
{
int len1=v[root].size();
for(int i=0;i<len1;i++)
{
int w=v[root][i];
ans+=query(1,0,len,0,b[w]);
update(1,0,len,a[w],1);
dfs(w);
update(1,0,len,a[w],-1);
}
}
// Reset per-test state: clear the parent counters and restart the node pool with
// a single empty root, which newnode() places at index 1.
void init()
{
    p=0;
    memset(tag,0,sizeof(tag));
    newnode();
}
int main()
{
    int t;
    scanf("%d",&t);
    int x,y;
    while(t--)
    {
        scanf("%d%lld",&n,&k);
        for(int i=1;i<=n;i++)
        {
            scanf("%lld",&a[i]);
            // For a_i > 0: a_u * a_i <= k  <=>  a_u <= k / a_i. A node with value 0
            // pairs weakly with every ancestor, so it queries the whole domain
            // instead of dividing by zero.
            b[i]=a[i]?k/a[i]:len;
            v[i].clear();
        }
        init();
        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            tag[y]++;
        }
        // The root is the only node that never appears as a child.
        int root=1;
        for(int i=1;i<=n;i++)
        {
            if(!tag[i])
                root=i;
        }
        ans=0;
        update(1,0,len,a[root],1);
        dfs(root);
        printf("%lld\n",ans);
    }
    return 0;
}

还可以自下而上,用线段树的启发式合并,计算每一个节点的所有子孙节点对它的贡献。

关于线段树的启发式合并,有必要再写一篇博客总结一下

#include
#include
#include
#include
#include
#include

using namespace std;
const int maxn=1e5;
typedef long long int LL;
// Node pool shared by all mergeable segment trees; index 0 is the empty tree.
int rt[maxn*100+5];   // rt[u] = root of node u's tree
int ls[maxn*100+5];
int rs[maxn*100+5];
LL sum[maxn*100+5];
int a[maxn+5];        // node values
LL k;
int n;
int p;                // next free pool index (main starts it at 1)
int l,r;              // value domain [l,r] shared by every tree
// Allocate a zeroed node and return its index.
int newnode()
{
    sum[p]=ls[p]=rs[p]=0;
    return p++;
}
// Insert one occurrence of val into tree `node` over [begin,end]. Each visited
// node's count is set to 1, so this is only correct on an empty tree.
void build(int &node,int begin,int end,LL val)
{
    if(!node) node=newnode();
    sum[node]=1;
    if(begin==end) return;
    LL mid=(begin+end)>>1;
    if(val<=mid) build(ls[node],begin,mid,val);
    else build(rs[node],mid+1,end,val);
}
// Number of stored values <= val in tree `node` over [begin,end].
LL Query(int node,int begin,int end,LL val)
{
    if(!node||val<begin) return 0;
    if(end<=val) return sum[node];
    LL mid=(begin+end)>>1;
    if(val<=mid) return Query(ls[node],begin,mid,val);
    return sum[ls[node]]+Query(rs[node],mid+1,end,val);
}
// Merge tree y into tree x over [begin,end]; x reuses y's nodes where x is empty.
void mergge(int &x,int y,int begin,int end)
{
    if(!x||!y) { x=x^y; return; }
    sum[x]+=sum[y];
    if(begin==end) return;
    LL mid=(begin+end)>>1;
    mergge(ls[x],ls[y],begin,mid);
    mergge(rs[x],rs[y],mid+1,end);
}
struct Node
{
    int value;   // child endpoint of the edge
    int next;    // next edge from the same parent, -1 terminates
}edge[maxn*2+5];
int head[maxn+5];
int tot;
// Append the directed edge x -> y to the adjacency list.
void add(int x,int y)
{
    edge[tot].value=y;
    edge[tot].next=head[x];
    head[x]=tot++;
}
LL ans;
void dfs(int root)
{
for(int i=head[root];i!=-1;i=edge[i].next)
{
int w=edge[i].value;
dfs(w);
mergge(rt[root],rt[w],l,r);
}
ans+=Query(rt[root],l,r,k/a[root]);
if(k>=1ll*a[root]*a[root])
ans--;
}
int tag[maxn+5];   // tag[i] = number of parents of node i; the root has none
int main()
{
    int t;
    scanf("%d",&t);
    int x,y;
    while(t--)
    {
        scanf("%d%lld",&n,&k);
        p=1;                          // pool index 0 is reserved for the empty tree
        memset(tag,0,sizeof(tag));
        memset(head,-1,sizeof(head));
        tot=0;
        // Value domain [l,r] = [min a_i, max a_i].
        l=1e9;r=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            l=min(l,a[i]);r=max(r,a[i]);
        }
        // Every node starts with a one-element tree holding its own value.
        for(int i=1;i<=n;i++)
        {
            rt[i]=0;
            build(rt[i],l,r,a[i]);
        }
        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d",&x,&y);
            add(x,y);
            tag[y]++;
        }
        int root=1;
        for(int i=1;i<=n;i++)
            if(tag[i]==0) root=i;
        ans=0;
        dfs(root);
        printf("%lld\n",ans);
    }
    return 0;
}

也可以用拓扑排序,自下而上进行启发式合并,

#include
#include
#include
#include
#include
#include
#include

using namespace std;
const int maxn=1e5;
typedef long long int LL;
// Node pool shared by all mergeable segment trees; index 0 is the empty tree.
int rt[maxn*100+5];   // rt[u] = root of node u's tree
int ls[maxn*100+5];
int rs[maxn*100+5];
LL sum[maxn*100+5];
int a[maxn+5];        // node values
int f[maxn+5];        // f[u] = parent of u, filled by main from the edge list
LL k;
int n;
int p;                // next free pool index (main starts it at 1)
int l,r;              // value domain [l,r] shared by every tree
queue<int> q;         // nodes ready to be processed (all children already merged)
// Allocate a zeroed node and return its index.
int newnode()
{
    sum[p]=ls[p]=rs[p]=0;
    return p++;
}
// Insert one occurrence of val into tree `node` over [begin,end]. Each visited
// node's count is set to 1, so this is only correct on an empty tree.
void build(int &node,int begin,int end,LL val)
{
    if(!node) node=newnode();
    sum[node]=1;
    if(begin==end) return;
    LL mid=(begin+end)>>1;
    if(val<=mid) build(ls[node],begin,mid,val);
    else build(rs[node],mid+1,end,val);
}
// Number of stored values <= val in tree `node` over [begin,end].
LL Query(int node,int begin,int end,LL val)
{
    if(!node||val<begin) return 0;
    if(end<=val) return sum[node];
    LL mid=(begin+end)>>1;
    if(val<=mid) return Query(ls[node],begin,mid,val);
    return sum[ls[node]]+Query(rs[node],mid+1,end,val);
}
// Merge tree y into tree x over [begin,end]; x reuses y's nodes where x is empty.
void mergge(int &x,int y,int begin,int end)
{
    if(!x||!y) { x=x^y; return; }
    sum[x]+=sum[y];
    if(begin==end) return;
    LL mid=(begin+end)>>1;
    mergge(ls[x],ls[y],begin,mid);
    mergge(rs[x],rs[y],mid+1,end);
}
LL ans;
int tag[maxn+5];
int main()
{
int t;
scanf("%d",&t);
int x,y;
while(t--)
{
scanf("%d%lld",&n,&k);
p=1;
memset(tag,0,sizeof(tag));
l=1e9;r=0;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
l=min(l,a[i]);r=max(r,a[i]);
}
for(int i=1;i<=n;i++)
build(rt[i]=0,l,r,a[i]);

    for(int i=1;i<=n-1;i++)  
    {  
        scanf("%d%d",&x,&y);  
        tag\[x\]++;  
        f\[y\]=x;  
    }  
    for(int i=1;i<=n;i++)  
    {  
       if(tag\[i\]==0)  
           q.push(i);  
    }  
    ans=0;  
    while(!q.empty())  
    {  
        int x=q.front();q.pop();  
        if(1LL\*a\[x\]\*a\[x\]<=k) ans--;  
        ans+=Query(rt\[x\],l,r,k/a\[x\]);  
        mergge(rt\[f\[x\]\],rt\[x\],l,r);  
        if(!--tag\[f\[x\]\]) q.push(f\[x\]);

    }  
    printf("%lld\\n",ans);  
}  
return 0;  

}

最后写一种可持久化线段树的解法。首先将树形转成线形(DFS序),然后逐个点插入;求一个根节点的子树对根节点的贡献,就是求DFS序列中一段区间内

小于等于k/a[i]的数有多少个。可持久化线段树利用类似前缀和的原理,tree[r]-tree[l-1]就是l到r这一段区间所有点的线段树。

#include
#include
#include
#include
#include
#include
#include

using namespace std;
const int maxn=1e5;
typedef long long int LL;
// Persistent segment tree pool; index 0 is the shared empty node.
int rt[maxn*100+5];   // rt[u] = tree version right after node u was inserted
int ls[maxn*100+5];
int rs[maxn*100+5];
LL sum[maxn*100+5];
int p;                // next free pool index (main starts it at 1)
int n;
LL k;
int l,r;              // value domain [l,r]
// Persistent insert of val: `node` is replaced by fresh copies along the path to
// val's leaf, so every earlier version stays unchanged.
void update(int &node,int l,int r,int val)
{
    ls[p]=ls[node];rs[p]=rs[node];
    sum[p]=sum[node];node=p;
    p++;
    sum[node]++;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(val<=mid) update(ls[node],l,mid,val);
    else update(rs[node],mid+1,r,val);
}
// Number of values <= val stored in version `node` over [l,r].
LL query(int node,int l,int r,LL val)
{
    if(!node||val<l) return 0;
    if(r<=val) return sum[node];
    int mid=(l+r)>>1;
    if(val<=mid) return query(ls[node],l,mid,val);
    return sum[ls[node]]+query(rs[node],mid+1,r,val);
}
struct Node
{
    int value;   // child endpoint of the edge
    int next;    // next edge from the same parent, -1 terminates
}edge[maxn*2+5];
int head[maxn+5];
int tot;
// Append the directed edge x -> y to the adjacency list.
void add(int x,int y)
{
    edge[tot].value=y;
    edge[tot].next=head[x];
    head[x]=tot++;
}
int res[maxn*2];   // Euler tour: every node appears once on entry and once on exit
int a[maxn+5];     // node values
int cot;           // number of entries written to res
// Append the Euler tour of the subtree at `root` to res.
void dfs(int root)
{
    res[cot++]=root;
    for(int i=head[root];i!=-1;i=edge[i].next)
        dfs(edge[i].value);
    res[cot++]=root;
}
int tag[maxn+5];
int flag[maxn+5];
int main()
{
int t;
scanf("%d",&t);
int x,y;
while(t--)
{
scanf("%d%lld",&n,&k);
l=1e9;r=0;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
l=min(l,a[i]);r=max(r,a[i]);
}

    memset(head,-1,sizeof(head));  
    memset(tag,0,sizeof(tag));  
    tot=0;  
    p=1;  
    for(int i=1;i<=n-1;i++)  
    {  
        scanf("%d%d",&x,&y);  
        add(x,y);  
        tag\[y\]++;  
    }  
    int root;  
    for(int i=1;i<=n;i++)  
    {  
        if(!tag\[i\])  
            root=i;  
    }  
    cot=0;  
    dfs(root);  
    memset(flag,0,sizeof(flag));  
    update(rt\[res\[0\]\],l,r,a\[res\[0\]\]);  
    flag\[res\[0\]\]=1;  
    LL ans=0;  
    int now=0;  
    for(int i=1;i<cot;i++)  
    {  
        if(flag\[res\[i\]\]==1)  
        {  
            LL ans1=query(rt\[res\[now\]\],l,r,k/a\[res\[i\]\]);  
            LL ans2=query(rt\[res\[i\]\],l,r,k/a\[res\[i\]\]);  
            //cout<<ans1<<" "<<ans2<<endl;  
            ans+=ans1-ans2;  
            continue;  
        }  
        flag\[res\[i\]\]=1;  
        update(rt\[res\[i\]\]=rt\[res\[now\]\],l,r,a\[res\[i\]\]);  
        now=i;  
    }  
    printf("%lld\\n",ans);  
}  
return 0;  

}