一道良好的矩阵乘法优化\(dp\)的题。
首先,一个比较\(naive\)的想法。
我们定义\(dp[i][j]\)表示已经走了\(i\)步,当前在点\(j\)的方案数。
由于题目中限制了不能立即走之前走过来的那个点,所以这个状态并不能优秀的转移。
尝试重新定义\(dp\)状态。
令\(dp[i][j]\)表示已经走了\(i\)步,当前在\(j\)这条边的终点的那个点。
假设\(to[j]=p\)
那么\(dp[i][j]\)可以转移到\(dp[i+1][out[p]] 其中\ (out[p]不为j的反向边)\)
其中\(out[p]\)表示p的出边(我们把题目中的每条无向拆成两个有向边)
最后求\(ans\)的时候,只需要枚举哪些边的终点是目标点,然后加起来即可
通过具体的边的限制,我们就能满足题目中的那个要求。
qwq但是我们发现,如果暴力转移的话,时间复杂度是不能够接受的。
考虑到每次只从\(i\)转移到\(i+1\)。
所以可以构造一个转移矩阵。
对于一个状态\(dp[x][i]\),然后在如果他能对编号为\(j\)的边产生贡献,那么我们把构造矩阵\(a[i][j]\)++
for (int i=1;i<=cnt;i++)
{
int to = y[i];
for (int j=0;j<out[to].size();j++)
{
int now = out[to][j];
if((i+1)==((now+1)^1)) continue;
b.a[i][now]++;
}
}
注意不能通过具体的点来判断,而要判断是否为反向边。
然后我们强行令初始矩阵为dp[1][i]的值,就是强行走一步,然后快速幂出来\(k-1\)次方的值,二者相乘,最后求解即可。
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define mk make_pair
#define pb push_back
#define ll long long
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 150;
const int maxm = 1e5+1e2;
const int mod = 45989;
struct Ju{
int x,y;
int a[maxn][maxn];
Ju operator * (Ju b)
{
Ju ans;
memset(ans.a,0,sizeof(ans.a));
ans.x=x;
ans.y=b.y;
for (register int i=1;i<=ans.x;++i)
for (register int j=1;j<=ans.y;++j)
for (register int k=1;k<=y;++k)
ans.a[i][j]=(ans.a[i][j]+a[i][k]*b.a[k][j]%mod)%mod;
return ans;
}
};
Ju qsm(Ju i,int j)
{
Ju ans;
memset(ans.a,0,sizeof(ans.a));
ans.x=i.x;
ans.y=i.y;
for (int p=1;p<=i.x;p++) ans.a[p][p]=1;
while(j)
{
if (j&1) ans=ans*i;
i=i*i;
j>>=1;
}
return ans;
};
Ju a,b;
int n,m,k,s,t;
int x[maxm],y[maxm],w[maxm];
int cnt=0;
vector<int> in[maxn],out[maxn];
signed main()
{
n=read();m=read(),k=read(),s=read(),t=read();
s++;
t++;
for (int i=1;i<=m;i++)
{
int u=read(),v=read();
u++;
v++;
++cnt;
x[cnt]=u,y[cnt]=v;
++cnt;
x[cnt]=v,y[cnt]=u;
}
for (int i=1;i<=cnt;i++)
{
out[x[i]].pb(i);
in[y[i]].pb(i);
}
for (int i=1;i<=cnt;i++)
{
int to = y[i];
for (int j=0;j<out[to].size();j++)
{
int now = out[to][j];
if((i+1)==((now+1)^1)) continue;
b.a[i][now]++;
}
}
//for (int i=1;i<=cnt;i++)
// {
// for (int j=1;j<=cnt;j++) cout<<b.a[i][j]<<" ";
// cout<<endl;
//}
for (int i=0;i<out[s].size();i++)
{
a.a[1][out[s][i]]++;
//cout<<out[s][i]<<" "<<endl;
}
//cout<<"******************"<<endl;
//for (int i=1;i<=cnt;i++) cout<<a.a[1][i]<<" ";
//cout<<endl;
a.x=1;
a.y=cnt;
b.x=cnt;
b.y=cnt;
b=qsm(b,k-1);
a=a*b;
int ans = 0;
for (int i=1;i<=cnt;i++)
{
if (y[i]==t) ans=(ans+a.a[1][i])%mod;
//cout<<ans<<endl;
}
cout<<ans;
return 0;
}
手机扫一扫
移动阅读更方便
你可能感兴趣的文章