题解-洛谷P4859 已经没有什么好害怕的了
阅读原文时间:2021年10月28日阅读:1

洛谷P4859 已经没有什么好害怕的了

给定 \(n\) 和 \(k\),\(n\) 个糖果能量 \(a_i\) 和 \(n\) 个药片能量 \(b_i\),每个 \(a_i\) 和 \(b_i\) 互不相等。将糖果和药片一一对应,求 糖果能量大于药片 比 药片能量大于糖果 多 \(k\) 组的方案数。

数据范围:\(1\le n\le 2000\),\(0\le k\le n\)。


萌新初学二项式反演,这是第一道完全自己做出来的题,所以写篇题解庆祝并提升理解。


有 \(\frac{n+k}{2}\) 组糖果能量大于药片,\(\frac{n-k}{2}\) 组药片能量大于糖果

如果 \(n+k\) 是奇数,直接答案为 \(0\) 特判掉。

\(f(i)\) 表示 \(i\) 组糖果能量大于药片,\(n-i\) 组药片能量大于糖果的方案数。

\(g(i)\) 表示 \(i\) 组糖果能量大于药片,\(n-i\) 组随意的方案数。

二项式反演必然有 \(f(i)\) 和 \(g(i)\),往往前者表示 \(i\) 个符合条件 \(a\) 剩下符合另条件 \(b\),后者表示 \(i\) 个符合条件 \(a\) 剩下随意。


先考虑 \(g(i)\) 怎么独立地求,蒟蒻想到了 \(\tt dp\)。

将 \(a_i\) 和 \(b_i\) 排序,现在 \(a_i<a_{i+1}\),\(b_i<b_{i+1}\)。

比如 \(b_i<a_1<b_{i+1}\),\(b_j<a_2<b_{j+1}(i<j)\)。

所以 \(a_1\) 可以对应 \(b_1\sim b_i\),\(a_2\) 可以对应 \(b_1\sim b_j\)。

因为 \(a_1\) 对于的 \(b_x\) 满足 \(x<i<j\),所以必然占了一个 \(a_2\) 可以对应的位。

所以有 \(i(j-1)\) 种对应法。

设 \(F_{i,j}\) 表示看了 \(a_1\sim a_i\),对应了 \(j\) 组的方案数。

令 \(p(i)\) 表示 \(b_{p(i)}<a_i<b_{p(i)+1}\)。

同理,所以 \(F(0,0)=1\)。

\[F(i,j)=F(i-1,j)+F(i-1,j-1)\cdot (p(i)-(j-1))
\]

\[g(i)=F(n,i)\cdot(n-i)!
\]


二项式反演来了:

\[g(i)=\sum_{x=i}^n{x\choose i}f(x) \Longleftrightarrow f(i)=\sum_{x=i}^n(-1)^{x-i}{x\choose i}g(x)
\]

答案是 \(f(\frac{n+k}{2})\),带进去算就好了。


时间复杂度 \(\Theta(n^2)\),空间复杂度 \(\Theta(n^2)\)。


  • 代码

下标从 \(0\) 开始的…巨佬们琢磨琢磨吧。\(\tt /kel\)。

#include <bits/stdc++.h>
using namespace std;

//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair(a,b)
#define x first
#define y second
#define be(a) a.begin()
#define en(a) a.end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
const int inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;

//Data
const int mod=1e9+9;

//Main
int main(){
    cin.tie(0);
    int n,k; cin>>n>>k;
    vector<int> a(n),b(n);
    for(int&ai:a) cin>>ai;
    for(int&bi:b) cin>>bi;
    if((n-k)&1) return cout<<0<<'\n',0;
    sort(be(a),en(a));
    sort(be(b),en(b));
    vector<vector<int>> f(n+1,vector<int>(n+1,0));
    f[0][0]=1;
    for(int i=0,p=-1;i<n;i++){
        while(p+1<n&&b[p+1]<a[i]) p++;
        for(int j=0;j<n+1;j++) f[i+1][j]=f[i][j];
        for(int j=0;j<n;j++) (f[i+1][j+1]+=(ll)f[i][j]*(p-j+1)%mod)%=mod;
    }
    for(int j=n,s=1;j>=0;j--) f[n][j]=(ll)f[n][j]*s%mod,s=(ll)s*(n-j+1)%mod;
    vector<vector<int>> c(n+1,vector<int>(n+1,0));
    for(int i=0;i<n+1;i++){
        c[i][0]=c[i][i]=1;
        for(int j=1;j<i;j++) c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
    }
    int ans=0,t=(n+k)>>1;
    for(int i=t;i<n+1;i++){
        int sum=(ll)f[n][i]*c[i][t]%mod;
        if((i-t)&1) (ans+=-sum+mod)%=mod;
        else (ans+=sum)%=mod;
    }
    cout<<ans<<'\n';
    return 0;
}

祝大家学习愉快!