题意
\(n\) 个点的一棵树,要求删除尽量少的点,使得删点之后还是一棵树,并且直径不超过 \(k\),求删除点的数量
分析
补题之前的一些错误想法:
开始讲正解:
考虑点分治,假设当前处理的根是 \(rt\),那么很容易算出根的答案,如何计算子树内的?假设 \(rt\) 有一颗子树 \(u\) ,在计算 \(u\) 中的点 \(x\) 的答案时,不需要考虑 \(u\) 中除 \(x\) 外的其他点对它的贡献,这一部分会继续分治下去求解。现在只需要计算 \(rt\) 除去 \(u\) 的点对 \(x\) 的贡献即可。
由于树边没有权值,我们可以用深度 \(dep\) 来表示树上路径距离,在求 \(x\) 的答案时,假设直径为 \(k\) (先假设\(k\)是偶数), 那么也就是求除去 \(u\) 这颗子树,其他深度大于 \((k/2) - dep[x]\) 的点的个数。可以先预处理出来 \(rt\) 的深度数组,然后求解 \(u\) 时,在预处理出来一个深度数组 \(c\),两者做差即可。
设子树 \(u\) 旁边的子树 \(v\) 上的一点 \(y\) , 那么当 \(dep[y] + dep[x] > k/2\), 才需要累计 \(y\) 对 \(x\)的贡献
当 \(k\) 为奇数时,树的直径中点在边上,我们考虑 \(x\) 到 \(anc[x]\) 这条边,答案应该是深度大于 \((k-1)/2-dep[u]+1\)的点的个数。但是如果 \(anc[x]\) 是 \(rt\),要注意还要累计\(x\) 所在子树 \(u\) 上的贡献。因为该边并不会递归下去分治求解。
const int N = 300000 + 5;
typedef pair<int,int> pii;
vector<pii> g[N];
int n, k;
int anc[N], sz[N], dep[N], allcount[N], o[N];
int cnt_node[N], cnt_edge[N];
bool del[N];
int rt, min_size, all, max_dep, node_len, edge_len;
void getsz(int x, int fa) {
anc[x] = fa;
sz[x] = 1;
int maxpart = 0;
for(auto t : g[x]) {
if(t.first == fa || del[t.first]) continue;
getsz(t.first, x);
sz[x] += sz[t.first];
maxpart = max(maxpart, sz[t.first]);
}
maxpart = max(maxpart, all - sz[x]);
if(maxpart < min_size) {
min_size = maxpart; rt = x;
}
}
// 获得点分治当前处理树的dep数组,与后缀和 allcount
void travel(int u) {
static int o[N];// BFS队列
int l = 0, r = -1;
o[++r] = u;
while(l <= r) {
int u = o[l++];
allcount[dep[u]] ++;
max_dep = max(max_dep, dep[u]);
for(auto &v : g[u]){
if(v.first == anc[u] || del[v.first]) continue;
anc[v.first] = u;
dep[v.first] = dep[u] + 1;
o[++r] = v.first;
}
}
for(int i=max_dep-1;i>=0;i--) allcount[i] += allcount[i+1];
}
void travel2(int u) {
static int c[N], o[N];
int l = 0, r = -1;
o[++r] = u;
while(l <= r) {
int u = o[l++];
c[dep[u]] ++;
for(auto &v:g[u]) {
if(del[v.first] || v.first == anc[u]) continue;
o[++r] = v.first;
}
}
int Max = dep[o[r]];
for(int i=Max - 1; i >= 0; i--) c[i] += c[i+1];
for(int i=0;i<=r;i++){
int u = o[i];
int d = max(0, node_len - dep[u] + 1);
cnt_node[u] += allcount[d] - c[d];
// 下面是处理树直径中心在边上的情况
d = max(0, edge_len - dep[u] + 2);
for(auto &v:g[u]) {
if(v.first != anc[u]) continue;
cnt_edge[v.second] += allcount[d] - c[d];
if(anc[u] == rt) { // 如果是与rt相连的边,要统计 c 的答案
cnt_edge[v.second] += c[edge_len + 2];
}
}
}
// 最后要记得清空
for(int i=0;i<=Max;i++) c[i] = 0;
}
void get(int u) {
del[u] = 1;
dep[u] = 0;
max_dep = 0;
travel(rt);
cnt_node[u] += allcount[node_len + 1];
for(auto &v : g[u]) {
if(del[v.first]) continue;
travel2(v.first);
}
_rep(i,0,max_dep) allcount[i] = 0;
for(auto &v : g[u]) {
if(del[v.first]) continue;
all = sz[v.first]; min_size = all;
getsz(v.first, 0);
getsz(rt, 0);
get(rt);
}
}
int main(){
int T; scanf("%d", &T);
while(T--){
scanf("%d%d", &n, &k);
node_len = k / 2;
edge_len = (k - 1) / 2;
_rep(i, 1, n) {
g[i].clear();
cnt_edge[i] = cnt_node[i] = del[i] = 0;
}
_rep(i, 1, n-1) {
int u, v; scanf("%d%d", &u,&v);
g[u].push_back(pii(v, i));
g[v].push_back(pii(u, i));
}
all = n; min_size = n;
getsz(1, 0);
getsz(rt, 0);
get(rt);
int res = n;
_rep(i, 1, n) res = min(res, cnt_node[i]);
_rep(i, 1, n-1) res = min(res, cnt_edge[i]);
printf("%d\n", res);
}
return 0;
}
手机扫一扫
移动阅读更方便
你可能感兴趣的文章