Atcoder Grand Contest 021 F - Trinity(dp+NTT)
阅读原文时间:2023年07月09日阅读:1

Atcoder 题面传送门 & 洛谷题面传送门

首先我们考虑设 \(dp_{i,j}\) 表示对于一个 \(i\times j\) 的网格,其每行都至少有一个黑格的合法的三元组 \((A,B,C)\) 的个数,那么对于原来的 \(n\times m\) 的网格,如果其存在黑格的行的集合不同,那么三元组 \((A,B,C)\) 肯定不同,因此我们可以直接枚举有多少行存在黑格来计算答案,即 \(ans=\sum\limits_{i=0}^n\dbinom{n}{i}dp_{i,m}\),因此我们只需求出 \(dp_{i,j}\) 即可求出答案。

接下来考虑怎样求出 \(dp_{i,j}\),首先边界条件肯定是 \(dp_{i,1}=1\),因为必须每格都染黑。接下来考虑转移,转移 \(dp_{i,j}\) 时,我们就枚举这个 \(i\times j\) 的矩阵前 \(j-1\) 列组成的子矩阵中,有多少行存在黑格,我们记这个数为 \(k\),那么可以这样从 \(dp_{k,j-1}\) 转移到 \(dp_{i,j}\):

  • 如果 \(k=i\),由于这个情况比较特殊,故我们单独将其拎出来。由于前 \(j-1\) 列组成的子矩阵中每行都存在黑格,故每行的 \(A_i\) 都是不为 \(0\) 的,加入第 \(j\) 列后 \(A\) 序列肯定不会发生变化,因此我们只需考虑 \(B,C\) 序列的变化即可。显然所有满足 \(x\le y\) 可能的二元组 \((x,y)\) 及 \((i+1,0)\)(这一列没有黑格的情况)都可以成为合法的 \((B_j,C_j)\),总个数为 \(1+j+\dbinom{j}{2}\),因此我们有 \(dp_{i,j}\leftarrow dp_{i,j-1}(1+j+\dbinom{j}{2})\)
  • 如果 \(k\lt i\),那么显然会出现一些本来没有黑格的行出现了黑格,对于这些行而言,其 \(A\) 值肯定就是 \(j\),而那些本来就存在黑格的行的 \(A\) 值肯定不会变,因此我们还是不用考虑 \(A\) 值的变化,考虑哪些行新出现了黑格及 可能的 \(B_j,C_j\) 即可,一个非常 naive 的想法是直接用组合数计算新出现黑格的行可能的情况,即 \(dp_{i,j}\leftarrow dp_{k,j-1}\dbinom{i}{i-k}\),但是这显然是错误的,因为它没有考虑到 \(B,C\) 的变化,比方说对于一个 \(5\times 2\) 的矩阵,第一列 \(1,3\) 为黑格,第二列 \(2,4,5\) 行为黑格的情况,与第一列 \(1,3\) 为黑格,第二列 \(1,2,4,5\) 行为黑格的情况,虽然对于第二列而言,新多出黑格的行的集合是相同的,但是显然在两种情况中 \(B_2\) 的值是不同的,因此它们也应被视为不同的情况。怎么解决呢?我们考虑最上方的黑格正上方的白格,即 \(B_j-1\);以及最下方的黑格正下方的白格,即 \(C_j+1\),这二者与新多出的 \(i-k\) 行会形成 \(i-k+2\) 个数 \(p_1,p_2,\cdots,p_{i-k+2}\),显然这 \(i-k+2\) 个数都在 \([0,i+1]\) 中,它们两两互不相同,且 \(B_j-1\) 肯定是这 \(i-k+2\) 中的最小值,\(C_j+1\) 肯定是这 \(i-k+2\) 中的最大值。又对于两种第 \(j\) 列的染色方案,它们被视作不同的染色方案当且仅当这 \(i-k+2\) 个数组成的集合 \(\{p_1,p_2,\cdots,p_{i-k+2}\}\) 不同,因此我们只需统计有多少个合法的集合 \(\{p_1,p_2,\cdots,p_{i-k+2}\}\) 即可,这个方案数显然是 \(\dbinom{i+2}{i-k+2}\),因此我们有转移 \(dp_{i,j}\leftarrow dp_{k,j-1}\dbinom{i+2}{i-k+2}\)

故我们有 \(dp_{i,j}=dp_{i,j-1}(1+j+\dbinom{j}{2})+\sum\limits_{k=0}^{i-1}dp_{k,j-1}\dbinom{i+2}{i-k+2}\),这样暴力转移是 \(\mathcal O(n^2m)\) 的,无法通过,不过注意到后面一个 \(\sum\) 可以变形为 \(\sum\limits_{k=0}^{i-1}dp_{k,j-1}\dfrac{(i+2)!}{k!(i-k+2)!}=(i+2)!\sum\limits_{k=0}^{i-1}\dfrac{dp_{k,j-1}}{k!}\dfrac{1}{(i-k+2)!}\),这显然是一个卷积的形式,因此 NTT 优化即可,TC 即可降到 \(\mathcal O(nm\log n)\)。

const int MAXN=8000;
const int MAXM=200;
const int MAXP=1<<14;
const int MOD=998244353;
const int pr=3;
const int ipr=(MOD+1)/3;
int qpow(int x,int e){
    int ret=1;
    for(;e;e>>=1,x=1ll*x*x%MOD) if(e&1) ret=1ll*ret*x%MOD;
    return ret;
}
int n,m,dp[MAXN+5][MAXM+5],ans=0;
int fac[MAXN+5],ifac[MAXN+5];
void init_fac(int n){
    for(int i=(fac[0]=ifac[0]=ifac[1]=1)+1;i<=n;i++) ifac[i]=1ll*ifac[MOD%i]*(MOD-MOD/i)%MOD;
    for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%MOD,ifac[i]=1ll*ifac[i-1]*ifac[i]%MOD;
}
int binom(int n,int k){return 1ll*fac[n]*ifac[k]%MOD*ifac[n-k]%MOD;}
int rev[MAXP+5];
void NTT(vector<int> &a,int len,int type){
    int lg=31-__builtin_clz(len);
    for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
    for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int i=2;i<=len;i<<=1){
        int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
        for(int j=0;j<len;j+=i){
            int w=1;
            for(int k=0;k<(i>>1);k++,w=1ll*w*W%MOD){
                int X=a[j+k],Y=1ll*a[(i>>1)+j+k]*w%MOD;
                a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
            }
        }
    }
    if(type==-1){
        int iv=qpow(len,MOD-2);
        for(int i=0;i<len;i++) a[i]=1ll*a[i]*iv%MOD;
    }
}
vector<int> conv(vector<int> a,vector<int> b){
    int LEN=1;while(LEN<a.size()+b.size()) LEN<<=1;
    a.resize(LEN,0);b.resize(LEN,0);NTT(a,LEN,1);NTT(b,LEN,1);
    for(int i=0;i<LEN;i++) a[i]=1ll*a[i]*b[i]%MOD;
    NTT(a,LEN,-1);return a;
}
int main(){
    scanf("%d%d",&n,&m);init_fac(n+2);
    for(int i=0;i<=n;i++) dp[i][1]=1;
    for(int j=2;j<=m;j++){
        vector<int> a(n+1),b(n+1);
        for(int i=1;i<=n;i++) a[i]=ifac[i+2];
        for(int i=0;i<=n;i++) b[i]=1ll*ifac[i]*dp[i][j-1]%MOD;
        a=conv(a,b);
        for(int i=0;i<=n;i++){
            dp[i][j]=1ll*(1+i+binom(i,2))*dp[i][j-1]%MOD;
            dp[i][j]=(dp[i][j]+1ll*fac[i+2]*a[i])%MOD;
        }
    } for(int i=0;i<=n;i++) ans=(ans+1ll*binom(n,i)*dp[i][m])%MOD;
    printf("%d\n",ans);
    return 0;
}