首先不考虑强制要求的话是一个经典问题,令 \(f_{i, 0 / 1}\) 为 \(i\) 选或不选时以 \(i\) 为根的子树的最优答案。那么就有转移 \(f_{u, 0} = \sum f_{v, 1}, f_{u, 1} = \sum \min(f_{v, 0}, f_{v, 1})\).每次查询重新暴力 \(dp\) 一遍整棵树就可以获得 \(44pts\).
先考虑一个部分分,\(B1\) 的情况,可以发现我们每次强制一个点选或不选能影响到的是其到根一条路径上的 \(dp\) 值,那么如果我们每次询问就只需要修改一个点到根路径上的 \(dp\) 值即可,但是如果暴力修改复杂度还是不正确的,如果是菊花图就没了,但同时我们需要知道这样一个事情,通过上面的转移方程是可以发现这个 \(dp\) 值是满足可加性的,换句话说如果我们去除掉某个子树对答案的影响,那么其 \(dp\) 求出来的将会是剩余部分的最优解,也就是说我们的 \(dp\) 是可以分开考虑然后将 \(dp\) 值简单相加。于是我们可以记录 \(g_{i, 0 / 1}\) 表示 \(i\) 选或不选时去除掉 \(i\) 所在子树其父亲的最优解,那么这样我们每次往上重新 \(dp\) 的时候复杂度就是 \(O(dep)\) 了。
可以发现这个过程是可以使用倍增优化的,每次我们往父亲跳的过程可以直接倍增地跳,那么这样就可以快速重新 \(dp\) 出答案了。具体的我们可以令 \(dp_{i, j, 0 / 1}\) 表示在以 \(i\) 的 \(2 ^ j\) 祖先为根的子树中去除掉以 \(i\) 为根的子树部分 \(i\) 选或不选的答案,但是这样我们可以发现一个问题,每次往上跳之后我们并不知道当前跳到节点的状态,因此我们需要再添加一维状态令 \(dp_{i, j, 0 / 1, 0 / 1}\) 表示以 \(i\) 的 \(2 ^ j\) 祖先为根的子树中去除掉以 \(i\) 为根的子树部分 \(i\) 选或不选,\(i\) 的 \(2 ^ j\) 祖先选或不选的答案,那么这样我们就可以一路 \(dp\) 上去了,实现代码的时候最后两个点的 \(lca\) 往根节点跳的过程可以提前预处理出来这样就可以省掉大量的代码(令 \(g_{i, 0 / 1}\) 表示整棵树去除掉以 \(i\) 为根的子树 \(i\) 选或不选的答案),细节比较多,一定要想清楚细节再开始写代码。
#include<bits/stdc++.h>
using namespace std;
#define N 100000 + 5
#define M 20
#define inf 10000000000000000
#define rep(i, l, r) for(register int i = l; i <= r; ++i)
#define dep(i, l, r) for(register int i = r; i >= l; --i)
#define Next(i, u) for(register int i = h[u]; i; i = e[i].next)
typedef long long ll;
struct edge{
int v, next;
}e[N << 1];
char type[M];
ll f[N][2], g[N][2], dp[N][M][2][2];
int n, m, u, v, a, x, b, y, tot, h[N], p[N], dep[N], fa[N][M];
inline int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
inline void add(int u, int v){
e[++tot].v = v, e[tot].next = h[u], h[u] = tot;
}
inline void dfs1(int u, int Fa){
fa[u][0] = Fa, f[u][1] += p[u], dep[u] = dep[Fa] + 1;
Next(i, u){
int v = e[i].v; if(v == Fa) continue;
dfs1(v, u);
f[u][0] += f[v][1], f[u][1] += min(f[v][0], f[v][1]);
}
}
inline void dfs2(int u, int fa){
if(u != 1){
g[u][0] = g[fa][1] + f[fa][1] - min(f[u][0], f[u][1]);
g[u][1] = min(g[fa][1] + f[fa][1] - min(f[u][0], f[u][1]), g[fa][0] + f[fa][0] - f[u][1]);
dp[u][0][0][0] = inf, dp[u][0][0][1] = f[fa][1] - min(f[u][0], f[u][1]);
dp[u][0][1][0] = f[fa][0] - f[u][1], dp[u][0][1][1] = f[fa][1] - min(f[u][0], f[u][1]);
}
Next(i, u) if(e[i].v != fa) dfs2(e[i].v, u);
}
inline ll solve(int x, int a, int y, int b){
if(dep[x] < dep[y]) swap(x, y), swap(a, b);
Next(i, y) if(e[i].v == x && !a && !b) return -1;
int fx = x, fy = y;
ll l[2] = {0, 0}, r[2] = {0, 0}, ans = 0, l0, l1, r0, r1;
l[!a] = inf, r[!b] = inf;
dep(i, 0, 17) if(dep[fa[x][i]] >= dep[y]){
l0 = l[0], l1 = l[1];
l[0] = min(l0 + dp[x][i][0][0], l1 + dp[x][i][1][0]);
l[1] = min(l0 + dp[x][i][0][1], l1 + dp[x][i][1][1]);
x = fa[x][i];
}
if(x != y){
dep(i, 0, 17) if(fa[x][i] != fa[y][i]){
l0 = l[0], l1 = l[1], r0 = r[0], r1 = r[1];
l[0] = min(l0 + dp[x][i][0][0], l1 + dp[x][i][1][0]);
l[1] = min(l0 + dp[x][i][0][1], l1 + dp[x][i][1][1]);
r[0] = min(r0 + dp[y][i][0][0], r1 + dp[y][i][1][0]);
r[1] = min(r0 + dp[y][i][0][1], r1 + dp[y][i][1][1]);
x = fa[x][i], y = fa[y][i];
}
l0 = l[0], l1 = l[1], r0 = r[0], r1 = r[1];
ll l0r0, l0r1, l1r0, l1r1, lca = fa[x][0];
l0r0 = (dp[x][0][0][0] + dp[y][0][0][0] - f[y][1] - f[x][1]) / 2 + g[lca][0];
l0r0 = min(l0r0, (dp[x][0][0][1] + dp[y][0][0][1] - min(f[x][0], f[x][1]) - min(f[y][0], f[y][1])) / 2 + g[lca][1]);
l0r0 += l0 + r0;
l0r1 = (dp[x][0][0][0] + dp[y][0][1][0] - f[y][1] - f[x][1]) / 2 + g[lca][0];
l0r1 = min(l0r1, (dp[x][0][0][1] + dp[y][0][1][1] - min(f[x][0], f[x][1]) - min(f[y][0], f[y][1])) / 2 + g[lca][1]);
l0r1 += l0 + r1;
l1r0 = (dp[x][0][1][0] + dp[y][0][0][0] - f[y][1] - f[x][1]) / 2 + g[lca][0];
l1r0 = min(l1r0, (dp[x][0][1][1] + dp[y][0][0][1] - min(f[x][0], f[x][1]) - min(f[y][0], f[y][1])) / 2 + g[lca][1]);
l1r0 += l1 + r0;
l1r1 = (dp[x][0][1][0] + dp[y][0][1][0] - f[y][1] - f[x][1]) / 2 + g[lca][0];
l1r1 = min(l1r1, (dp[x][0][1][1] + dp[y][0][1][1] - min(f[x][0], f[x][1]) - min(f[y][0], f[y][1])) / 2 + g[lca][1]);
l1r1 += l1 + r1;
x = fa[x][0], y = fa[y][0];
ans += min(min(min(l0r0, l0r1), l1r0), l1r1) + f[fx][a] + f[fy][b];
}
else{
ans += f[fx][a];
l[!b] = inf;
ans += min(l[0] + g[y][0], l[1] + g[y][1]);
}
return ans;
}
signed main(){
n = read(), m = read(), scanf("%s", type + 1);
rep(i, 1, n) p[i] = read();
rep(i, 1, n - 1) u = read(), v = read(), add(u, v), add(v, u);
dfs1(1, 0), dfs2(1, 0);
rep(j, 1, 17) rep(i, 1, n){
int Fa = fa[i][j - 1];
fa[i][j] = fa[Fa][j - 1];
dp[i][j][0][0] = min(dp[i][j - 1][0][0] + dp[Fa][j - 1][0][0], dp[i][j - 1][0][1] + dp[Fa][j - 1][1][0]);
dp[i][j][0][1] = min(dp[i][j - 1][0][0] + dp[Fa][j - 1][0][1], dp[i][j - 1][0][1] + dp[Fa][j - 1][1][1]);
dp[i][j][1][0] = min(dp[i][j - 1][1][0] + dp[Fa][j - 1][0][0], dp[i][j - 1][1][1] + dp[Fa][j - 1][1][0]);
dp[i][j][1][1] = min(dp[i][j - 1][1][0] + dp[Fa][j - 1][0][1], dp[i][j - 1][1][1] + dp[Fa][j - 1][1][1]);
}
while(m--) a = read(), x = read(), b = read(), y = read(), printf("%lld\n", solve(a, x, b, y));
return 0;
}
手机扫一扫
移动阅读更方便
你可能感兴趣的文章