Weak Pair
**Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)**
**Total Submission(s): 1468    Accepted Submission(s): 472**
Problem Description
You are given a rooted tree of $N$ nodes, labeled from 1 to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u, v)$ is said to be weak if
(1) $u$ is an ancestor of $v$ (note: in this problem a node $u$ is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer $T$ denoting the number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
这是一道很好的数据结构题目，可以用很多方法写。
首先思路是：DFS 这棵树，每到一个节点，都计算这个节点的祖先中满足条件的有几个。
而计算这个就需要维护当前从根到该节点路径上的祖先权值序列，并且高效地得出有多少个祖先满足条件。
即在序列中找到小于等于 ⌊k/a[i]⌋ 的数有多少个（当 a[i]=0 时，所有祖先都满足条件），很容易想到用树状数组或线段树。
权值最大为 1e9，需要离散化。
树状数组:
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <map>
using namespace std;
const int maxn=1e5;       // maximum number of nodes N
typedef long long int LL;
int n;                    // number of nodes in the current test case
LL k;                     // product bound: (u,v) is weak iff a[u]*a[v] <= k
LL a[maxn+5];             // a[i]: value of node i (1-based)
// Edge in a static adjacency-list pool: `value` is the target node,
// `next` is the index of the next edge out of the same node (-1 ends the list).
struct Node
{
    int value;
    int next;
}edge[maxn*2+5];
int head[maxn+5];         // head[u]: index of u's first edge, or -1 if none
int vis[maxn+5];          // visited flags used by dfs
int tot;                  // number of edges currently stored in `edge`
int c[maxn*2+5];          // Fenwick tree over compressed ranks 1..2n
LL b[maxn+5];             // b[i]: query key for node i (k/a[i] when a[i] > 0)
LL e[maxn*2+5];           // scratch buffer holding a/b values for coordinate compression
// Value -> compressed 1-based rank. The template arguments were lost in the
// original (`map m;` does not compile); keys are LL values, ranks are BIT indices.
map<LL,int> m;
void add(int x,int y)
{
    // Prepend a directed edge x -> y to x's adjacency list in the static edge pool.
    const int id=tot++;
    edge[id].value=y;
    edge[id].next=head[x];
    head[x]=id;
}
int lowbit(int x)
{
    // Lowest set bit of x: in two's complement, -x keeps that bit and flips all bits above it.
    const int neg=-x;
    return neg&x;
}
void update(int x,int num)
{
    // Fenwick point update: add num at index x and every node covering it, up to
    // the tree size 2n. Callers must pass x >= 1 (lowbit(0) == 0 never advances).
    for(int pos=x;pos<=n*2;pos+=lowbit(pos))
        c[pos]+=num;
}
int sum(int x)
{
    // Fenwick prefix query: total count stored at indices 1..x (0 when x <= 0).
    int total=0;
    for(int pos=x;pos>0;pos-=lowbit(pos))
        total+=c[pos];
    return total;
}
LL ans;
void dfs(int root)
{
vis[root]=1;
for(int i=head[root];i!=-1;i=edge[i].next)
{
int v=edge[i].value;
if(!vis[v])
{
ans+=sum(m[b[v]]);
update(m[a[v]],1);
dfs(v);
update(m[a[v]],-1);
}
}
}
void init()
{
    // Reset per-test-case state: empty the edge pool and adjacency lists
    // (-1 heads), clear visited flags, and zero the Fenwick tree.
    tot=0;
    memset(head,-1,sizeof head);
    memset(vis,0,sizeof vis);
    memset(c,0,sizeof c);
}
int tag[maxn+5];
int main()
{
int t;
scanf("%d",&t);
int x,y;
while(t--)
{
scanf("%d%lld",&n,&k);
init();
int cnt=n;
m.clear();
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
e[i]=a[i];
if(a[i]==0)
m[a[i]]=2*n;
else
{
b[i]=k/a[i];
e[++cnt]=b[i];
}
}
sort(e+1,e+cnt+1);
int tot=1;
for(int i=1;i<=cnt;i++)
{
if(!m.count(e[i]))
m[e[i]]=tot++;
}
memset(tag,0,sizeof(tag));
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
tag[y]++;
}
int root;
for(int i=1;i<=n;i++)
{
if(tag[i]==0)
root=i;
}
ans=0;
update(m[a[root]],1);
dfs(root);
printf("%lld\n",ans);
}
return 0;
}
线段树:
#include
#include
#include
#include
#include
#include
#include
#include