[NOIP2018提高组] 保卫王国 (树链剖分+动态DP)
阅读原文时间:2023年07月09日阅读:2

题面

题目链接—Luogu

题目链接—Loj(要加Freopen)

题解

什么是动态DP?

OneInDark:你不需要知道这么多,你只需要知道是利用了广义矩阵乘法就够了!


广义矩阵乘法,简单来说,就是把基本的 乘法加法 运算符改成其它运算符,同时这两种运算要满足 前者对后者有分配律,如:加法最大或最小值按位与异或 等。因为,我们会发现,乘法加法 组成的传统矩阵乘法之所以有哪些性质,其根本原因就在于乘法对加法的分配律。

举个例子,有这么个 DP 转移:

d

p

[

i

]

[

0

]

=

min

(

d

p

[

i

1

]

[

0

]

+

A

,

d

p

[

i

1

]

[

1

]

+

B

)

d

p

[

i

]

[

1

]

=

min

(

d

p

[

i

1

]

[

0

]

+

C

,

d

p

[

i

1

]

[

1

]

+

D

)

dp[i][0]=\min(dp[i-1][0]+A,dp[i-1][1]+B)\\ dp[i][1]=\min(dp[i-1][0]+C,dp[i-1][1]+D)

dp[i][0]=min(dp[i−1][0]+A,dp[i−1][1]+B)dp[i][1]=min(dp[i−1][0]+C,dp[i−1][1]+D)

那么它转化为 加法最小值 的广义矩阵乘法后,从

(

d

p

[

i

1

]

[

0

]

,

d

p

[

i

1

]

[

1

]

)

\big(dp[i-1][0],dp[i-1][1]\big)

(dp[i−1][0],dp[i−1][1]) 到

(

d

p

[

i

]

[

0

]

,

d

p

[

i

]

[

1

]

)

\big(dp[i][0],dp[i][1]\big)

(dp[i][0],dp[i][1]) 的转移矩阵就是

(

A

C

B

D

)

\left( \begin{matrix} A&C\\ B&D\\ \end{matrix} \right)

(AB​CD​)

顺便提一下此时的单位矩阵:

(

0

+

+

0

)

\left( \begin{matrix} 0&+\infty\\ +\infty&0\\ \end{matrix} \right)

(0+∞​+∞0​)

0

0

0 表示直接转移过来,不改变原值的基础上参与最小值运算,与传统矩乘的

1

1

1 的作用相同。

+

+\infty

+∞ 在求最小值前提下,对结果没有影响,与传统矩乘的

0

0

0 作用相同。

广义矩阵乘法的实现很简单,只需要把传统矩乘的模板

c

o

p

y

\tt copy

copy 过来,然后把原来计算

C

i

,

j

=

k

A

i

,

k

B

k

,

j

C_{i,j}=\sum_{k}A_{i,k}\cdot B_{k,j}

Ci,j​=k∑​Ai,k​⋅Bk,j​

的部分:for(...k...) C[i][j]=C[i][j]+A[i][k]*B[k][j](以加法和最小值为例)改成:

C

i

,

j

=

min

k

{

A

i

,

k

+

B

k

,

j

}

C_{i,j}=\min_{k}\{A_{i,k}+B_{k,j}\}

Ci,j​=kmin​{Ai,k​+Bk,j​}

for(...k...) C[i][j]=min(C[i][j],A[i][k]*B[k][j]),就完了。


其实这道题严格意义上没有真正的修改,所以打树剖难免有点大材小用。

不过,这都是为了配合动态 DP 啊!

我们令

g

(

i

,

0

)

g(i,0)

g(i,0) 为不在

**

i

i

**

i 结点驻兵时,

i

i

i 的非重儿子子树的最小花费,

g

(

i

,

1

)

g(i,1)

g(i,1) 为

**

i

i

**

i 结点驻兵时,

i

i

i 的非重儿子子树的最小花费。在进行一次修改时,只会变动最多

log

n

\log n

logn 个点的

g

g

g 值,因此可以直接更新它。

同时令

f

(

i

,

0

)

f(i,0)

f(i,0) 为不在

i

i

i 结点驻兵时,

i

i

i 的子树(包括

i

i

i)的最小花费,令

f

(

i

,

1

)

f(i,1)

f(i,1) 为在

i

i

i 结点驻兵时,

i

i

i 的子树(包括

i

i

i)的最小花费。可以得到转移方程:

g

(

i

,

0

)

=

y

i

f

(

y

,

1

)

g

(

i

,

1

)

=

y

i

min

{

f

(

y

,

0

)

,

f

(

y

,

1

)

}

f

(

i

,

0

)

=

f

(

S

O

N

i

,

1

)

+

g

(

i

,

0

)

f

(

i

,

1

)

=

min

{

f

(

S

O

N

i

,

0

)

,

f

(

S

O

N

i

,

1

)

}

+

g

(

i

,

1

)

+

p

i

g(i,0)=\sum_{y\in i\,的轻儿子} f(y,1)~~~~~\\~\\ g(i,1)=\sum_{y\in i\,的轻儿子} \min\{f(y,0),f(y,1)\}\\~\\ f(i,0)=f(SON_i,1)+g(i,0)~~~~~~~~~~\\~\\ f(i,1)=\min\{f(SON_i,0),f(SON_i,1)\}+g(i,1)+p_i

g(i,0)=y∈i的轻儿子∑​f(y,1)                        g(i,1)=y∈i的轻儿子∑​min{f(y,0),f(y,1)} f(i,0)=f(SONi​,1)+g(i,0)                                         f(i,1)=min{f(SONi​,0),f(SONi​,1)}+g(i,1)+pi​

我们可以就此得出

(

f

(

S

O

N

i

,

0

)

,

f

(

S

O

N

i

,

1

)

)

\big(f(SON_i,0),f(SON_i,1)\big)

(f(SONi​,0),f(SONi​,1)) 到

(

f

(

i

,

0

)

,

f

(

i

,

1

)

)

\big(f(i,0),f(i,1)\big)

(f(i,0),f(i,1)) 的转移矩阵,用 加法最小值 的广义矩乘:

(

+

g

(

i

,

1

)

+

p

i

g

(

i

,

0

)

g

(

i

,

1

)

+

p

i

)

\left( \begin{matrix} +\infty&g(i,1)+p_i\\ g(i,0)&g(i,1)+p_i \end{matrix} \right)

(+∞g(i,0)​g(i,1)+pi​g(i,1)+pi​​)

那么,我们就可以用线段树维护区间矩乘了。把每个点独特的转移矩阵都经树链剖分放到某条重链上,一个点

i

i

i 的向量矩阵

(

f

(

i

,

0

)

,

f

(

i

,

1

)

)

\big(f(i,0),f(i,1)\big)

(f(i,0),f(i,1)) 就等于

(

0

,

0

)

×

Q

u

e

r

y

(

d

f

n

i

,

)

\big(0,0\big)\times {\rm Query}({\rm dfn}_i~,~重链底)

(0,0)×Query(dfni​ , 重链底)

修改一个点的时候,只会对 到根的途径上所有重链顶的父亲

g

g

g 值以及转移矩阵有改变。对

g

g

g 的改变没有必要重新把轻儿子枚举一边,只需要把变动的那个重链顶的贡献变化量加上去。为此,有必要把每个重链顶先前的

f

f

f 值都存下来。

针对 必须驻兵 和 必须不驻兵 两种限制,可以分别把节点的

p

i

p_i

pi​ 加上

-\infty

−∞ 和

+

+\infty

+∞ 来解决。这都是老套路了。

时间复杂度

O

(

n

log

2

n

)

\rm O(n\log^{_2}n)

O(nlog2​n) 。

CODE

#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 100005
#define ENDL putchar('\n')
#define LL long long
#define DB double
#define lowbit(x) ((-x) & (x))
LL read() {
    LL f = 1,x = 0;char s = getchar();
    while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
    while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
    return f * x;
}
int n,m,i,j,s,o,k;
struct mat{
    int n,m;
    LL s[2][2];
    mat(){n=m=0;}
    void set(int N,int M) {
        n=N;m=M;
        for(int i = 0;i < n;i ++) {
            for(int j = 0;j < m;j ++) {
                s[i][j] = 1e18;
            }
        }
    }
}A,B;
mat operator * (mat a,mat b) {
    mat c; c.set(a.n,b.m);
    for(int i = 0;i < a.n;i ++) {
        for(int k = 0;k < a.m;k ++) {
            for(int j = 0;j < b.m;j ++) {
                c.s[i][j] = min(c.s[i][j],a.s[i][k] + b.s[k][j]);
            }
        }
    }return c;
}
mat qkpow(mat a,LL b) {
    mat res; res.set(a.n,a.m);
    for(int i = 0;i < a.n;i ++) res.s[i][i] = 0;
    while(b > 0) {
        if(b & 1) res = res * a;
        a = a * a; b >>= 1;
    }
    return res;
}
mat tre[MAXN<<2];
int M;
void maketree(int n) {
    M = 1;while(M < n+2) M <<= 1;
    for(int i = 1;i <= M*2;i ++) tre[i].set(2,2);
}
void addtree(int x,mat y) {
    int s = M + x; tre[s] = y; s >>= 1;
    while(s) tre[s] = tre[s<<1|1] * tre[s<<1],s >>= 1;
}
mat findtree(int l,int r) {
    int s = M+l-1,t = M+r+1;
    mat ls,rs;ls.set(2,2);
    ls.s[0][0] = ls.s[1][1] = 0;
    rs = ls;
    while(s || t) {
        if((s>>1) ^ (t>>1)) {
            if(!(s&1)) ls = tre[s^1] * ls;
            if(t & 1) rs = rs * tre[t^1];
        }else break;
        s >>= 1;t >>= 1;
    }
    return rs * ls;
}
//--------------------------------------------------------------------
vector<int> g[MAXN];
LL pi[MAXN];
int fa[MAXN];
int d[MAXN],siz[MAXN],son[MAXN];
int dfn[MAXN],tp[MAXN],tim,tl[MAXN];
LL G[MAXN][2];
void dfs0(int x,int ff) { // fa[], d[], siz[], son[],
    d[x] = d[fa[x] = ff] + 1;
    siz[x] = 1; son[x] = 0; tim = 0;
    for(int i = 0,le=(int)g[x].size();i < le;i ++) {
        int y = g[x][i];
        if(y != ff) {
            dfs0(y,x);
            siz[x] += siz[y];
            if(siz[y] > siz[son[x]]) son[x] = y;
        }
    }
    return ;
}
LL tf[MAXN][2];
mat tran[MAXN];
void dfs(int x,int ff) { // dfn[], tp[], tl[]
    if(son[ff] == x) tp[x] = tp[ff];
    else tp[x] = x;
    dfn[x] = ++ tim;
    tl[tp[x]] = max(tl[tp[x]],dfn[x]);
    G[x][0] = G[x][1] = 0;
    tf[x][0] = 0;tf[x][1] = pi[x];

    if(son[x]) dfs(son[x],x);
    tf[x][0] += tf[son[x]][1];
    tf[x][1] += min(tf[son[x]][0],tf[son[x]][1]);
    for(int i = 0,le=(int)g[x].size();i < le;i ++) {
        int y = g[x][i];
        if(y != ff && y != son[x]) {
            dfs(y,x);
            G[x][0] += tf[y][1];
            G[x][1] += min(tf[y][1],tf[y][0]);
        }
    }
    tf[x][0] += G[x][0]; tf[x][1] += G[x][1];
    mat tm; tm.set(2,2);
    tm.s[0][0] = 1e18; tm.s[0][1] = G[x][1]+pi[x];
    tm.s[1][0] = G[x][0]; tm.s[1][1] = G[x][1]+pi[x];
    addtree(dfn[x],tran[x] = tm);
    return ;
}
void addline(int a,LL pa) {
    pi[a] = pa;
    while(a) {
        mat tm = tran[a];
        tm.s[0][1] = tm.s[1][1] = G[a][1]+pi[a];
        tm.s[1][0] = G[a][0];
        addtree(dfn[a],tm);
        int hd = tp[a];
        mat tmp;tmp.set(1,2);tmp.s[0][0]=tmp.s[0][1]=0;
        tmp = tmp * findtree(dfn[hd],tl[hd]);
        LL F[2] = {tmp.s[0][0],tmp.s[0][1]};
        a = fa[hd];
        if(a) {
            G[a][0] += F[1] - tf[hd][1];
            G[a][1] += min(F[0],F[1]) - min(tf[hd][0],tf[hd][1]);
        }
        tf[hd][0] = F[0]; tf[hd][1] = F[1];
    }return ;
}
int main() {
    freopen("defense.in","r",stdin);
    freopen("defense.out","w",stdout);
    n = read(); m = read(); read();
    for(int i = 1;i <= n;i ++) {
        pi[i] = read();
    }
    for(int i = 1;i < n;i ++) {
        s = read();o = read();
        g[s].push_back(o);
        g[o].push_back(s);
    }
    dfs0(1,0);
    maketree(n);
    dfs(1,0);
    while(m --) {
        int a = read(),xx = read(),b = read(),yy = read();
        int pa = pi[a],pb = pi[b];
        LL nma = pa + (xx ? -1e13:1e13);
        LL nmb = pb + (yy ? -1e13:1e13);
        addline(a,nma); addline(b,nmb);
        LL as = min(tf[1][0],tf[1][1]);
        if(xx) as += 1e13;
        if(yy) as += 1e13;
        if(as > (LL)1e10) printf("-1\n");
        else printf("%lld\n",as);
        addline(a,pa); addline(b,pb);
    }
    return 0;
}
/*
[f(i,0),f(i,1)] --- [0,0]

[oo       ,g(i,1)+pi]
[g(i,0),g(i,1)+pi]
*/

养成习惯,写矩乘时从 0 开始编号。节省空间能快很多。如果过程中越界可能性大,那就算了吧