广义后缀自动机小结
阅读原文时间:2021年04月25日阅读:1

定义:

广义后缀自动机是建在 T r i e Trie Trie 树上的后缀自动机,和加特殊字符拼接相比好像就是空间上的优化? 实现上就是每次加入新的模式串的时候,将 l a s t last last 结点重置为 r o o t root root。好像也没啥特殊的。

一波题解:

sol: 寻找在各种置换下本质不同的子串个数。
由于字符集只有 a b c abc abc,可行的置换只有 3 ! 3! 3!种。对原始串在这6种置换下的表示建广义后缀自动机。发现对于只有一种字符的子串,每个重复了3种,这种子串的个数可以通过求最长连续同色子串的长度得到,如 a a a aaa aaa, c c c c c c c ccccccc ccccccc。其余的子串则重复了 3 ! 3! 3!次。则最终的答案为 ( 所 有 子 串 个 数 + 同 字 符 子 串 个 数 ∗ 3 ) / 6 (所有子串个数 + 同字符子串个数 * 3)/6 (所有子串个数+同字符子串个数∗3)/6.

code:

#include<bits/stdc++.h>

using namespace std;

typedef long long ll;

const int maxn = 8e5+5;
const int s_sz = 4;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;

#define fi first
#define se second
#define MP make_pair
#define pii pair<int,int>

int pos[maxn],sum[maxn],tmp[maxn];
int maxx,cur;

string ss[10];
char s1[maxn],s2[maxn];

void init(){
    cur = 0;
    ss[0] = "abc";
    do{
        ss[++cur] = ss[0];
    }while(next_permutation(ss[0].begin(),ss[0].end()));
}

struct SAM{
    int ch[maxn][s_sz];
    int rt,sz,last;
    int len[maxn],suf[maxn],r[maxn];
    ll val[maxn],pre[maxn];

    void init(){
        memset(ch,0,sizeof(ch[0]) * (sz+1));
        memset(suf,0,sizeof(int)*(sz+1));
        memset(r,0,sizeof(int)*(sz+1));
        rt = sz = last = 1;
    }

    inline void add(int x,int c){
        int p = last,np = ++sz;
        last = np;
        len[np] = x;
        while(p && !ch[p][c]){
            ch[p][c] = np;
            p = suf[p];
        }
        if(!p){
            suf[np] = rt;
            return;
        }
        int q = ch[p][c];
        if(len[q] == len[p] + 1) suf[np] = q;
        else{
            int nq = ++ sz;
            len[nq] = len[p] + 1;
            memcpy(ch[nq],ch[q],sizeof(ch[q]));
            suf[nq] = suf[q];
            suf[np] = suf[q] = nq;
            while(ch[p][c] == q){
                ch[p][c] = nq;
                p = suf[p];
            }
        }
    }

    inline int idx(char c){
        return c - 'a';
    }

    inline void build(char* s){
        last = rt;
        int n = strlen(s);
        for(int i = 0;i<n;i++){
            add(i+1,idx(s[i]));
        }
    }

    inline void Topsort(int n){
        memset(sum,0,sizeof(int)*(n+1));
        for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
        for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
        for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
    }

    inline void get_right(){
        for(int i = sz;i;i--){
            int u = tmp[i];
            if(suf[u]) r[suf[u]] += r[u];
        }
    }

    inline ll Query(){
        ll ret = 0;
        for(int i = rt + 1;i<=sz;i++){
            ret += len[i] - len[suf[i]];
        }
        return ret;
    }
}sam;

int main(){
    int n;
    init();
    while(~scanf("%d",&n)){
        scanf("%s",s1);
        s2[n] = '\0';
        sam.init();
        for(int i = 1;i<=cur;i++){
            for(int j = 0;j<n;j++){
                int c = s1[j] - 'a';
                s2[j] = ss[i][c];
            }
            sam.build(s2);
        }
        sam.Topsort(n);
        sam.get_right();
        ll ans = sam.Query();
        int res = 1;
        int maxx = 1;
        for(int i = 1;i<n;i++){
            if(s1[i] == s1[i-1]) res++;
            else res = 1;
            maxx = max(maxx,res);
        }
        ans += maxx * 3;
        ans /= 6;
        printf("%lld\n",ans);
    }
    return 0;
}

sol:求有n个大串和m个询问,每次给出一个字符串s询问在多少个大串中出现过
广义自动机裸题。。。 记录每个状态在几个大串中出现即可。

code:

#include<bits/stdc++.h>

using namespace std;

typedef long long ll;

const int maxn = 4e5+5;
const int s_sz = 26;
const int inf = 0x3f3f3f3f;

#define fi first
#define se second
#define MP make_pair
#define pii pair<int,int>

int sum[maxn],tmp[maxn],pos[maxn];
char str[maxn];
char s1[maxn];
ll f[maxn];

struct SAM{
    int ch[maxn][s_sz];
    int rt,sz,last;
    int len[maxn],suf[maxn],r[maxn];
    int cnt[maxn],pre[maxn];

    void init(){
        memset(ch,0,sizeof(ch[0]) * (sz+1));
        memset(suf,0,sizeof(int)*(sz+1));
        memset(r,0,sizeof(int)*(sz+1));
        rt = sz = last = 1;
    }

    inline void add(int x,int c){
        int p = last,np = ++sz;
        last = np;
        len[np] = x;
        while(p && !ch[p][c]){
            ch[p][c] = np;
            p = suf[p];
        }
        if(!p){
            suf[np] = rt;
            return;
        }
        int q = ch[p][c];
        if(len[q] == len[p] + 1) suf[np] = q;
        else{
            int nq = ++ sz;
            len[nq] = len[p] + 1;
            memcpy(ch[nq],ch[q],sizeof(ch[q]));
            suf[nq] = suf[q];
            suf[np] = suf[q] = nq;
            while(ch[p][c] == q){
                ch[p][c] = nq;
                p = suf[p];
            }
        }
    }

    inline int idx(char c){
        return c - 'a';
    }

    inline void build(char* s){
        last = rt;
        int n = strlen(s);
        for(int i = 0;i<n;i++){
            add(i+1,idx(s[i]));
        }
    }

    inline void Topsort(int n){
        memset(sum,0,sizeof(int)*(n+1));
        for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
        for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
        for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
    }

    inline void work(char* s,int cur){
        int n = strlen(s);
        int u = rt;
        for(int i = 0;i<n;i++){
            u = ch[u][idx(s[i])];
            int fp = u;
            while(fp && pre[fp] != cur){
                cnt[fp] ++;
                pre[fp] = cur;
                fp = suf[fp];
            }
        }
    }

    inline ll sol(char* s){
        ll ret = 0;
        int n = strlen(s);
        int u = rt;
        for(int i = 0;i<n;i++){
            u = ch[u][idx(s[i])];
            if(!u) return 0;
        }
        return cnt[u];
    }
}sam;

int main(){
    int n,m;
    scanf("%d%d",&n,&m);
    int Last = 0;
    pos[0] = Last;
    for(int i = 1;i<=n;i++){
        scanf("%s",str+Last);
        int len = strlen(str+Last);
        Last += len;
        str[Last] = '\0';
        Last++;
        pos[i+1] = Last;
    }
    sam.init();
    for(int i = 1;i<=n;i++) {
        sam.build(str+pos[i]);
    }
    for(int i = 1;i<=n;i++) sam.work(str+pos[i],i);
    while(m--){
        scanf("%s",s1);
        ll ans = sam.sol(s1);
        printf("%lld\n",ans);
    }
    return 0;
}

sol:给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?
对每个状态维护 c n t cnt cnt 表示是几个串的子串和 p r e pre pre 表示上个更新的串是哪个。每插入一个新串,对所有经过的节点,沿 p a r e n t parent parent树暴力向上更新,直到遇到第一个已经被当前串更新的结点为止。 这个复杂度不是很懂怎么算啊。
建完广义后缀自动机后,对每个结点 s s s 沿拓扑序计算 s s s 和 s s s 所有祖先的贡献。对每个串,在自动机上扫一遍,把贡献累加起来就行。
upd: 这个题还有一个启发式合并的做法。每个节点维护一个 s e t set set存包含这个子串的模式串的标号 i d ∈ [ 1 , n ] id\in[1,n] id∈[1,n],建好 p a r e n t parent parent树之后从根开始 d f s dfs dfs,启发式合并即可。由于每个节点只有一个父亲,我们只需要 s e t set set集合的大小(用一个数组另外存),而每个节点的 s e t set set在被父节点调用时信息是完整的,所以统计的答案是正确的。

code:

#include<bits/stdc++.h>

using namespace std;

typedef long long ll;

const int maxn = 4e5+5;
const int s_sz = 26;
const int inf = 0x3f3f3f3f;

#define fi first
#define se second
#define MP make_pair
#define pii pair<int,int>

int sum[maxn],tmp[maxn],pos[maxn];
char str[maxn];
ll f[maxn];

struct SAM{
    int ch[maxn][s_sz];
    int rt,sz,last;
    int len[maxn],suf[maxn],r[maxn];
    int cnt[maxn],pre[maxn];

    void init(){
        memset(ch,0,sizeof(ch[0]) * (sz+1));
        memset(suf,0,sizeof(int)*(sz+1));
        memset(r,0,sizeof(int)*(sz+1));
        rt = sz = last = 1;
    }

    inline void add(int x,int c){
        int p = last,np = ++sz;
        last = np;
        len[np] = x;
        while(p && !ch[p][c]){
            ch[p][c] = np;
            p = suf[p];
        }
        if(!p){
            suf[np] = rt;
            return;
        }
        int q = ch[p][c];
        if(len[q] == len[p] + 1) suf[np] = q;
        else{
            int nq = ++ sz;
            len[nq] = len[p] + 1;
            memcpy(ch[nq],ch[q],sizeof(ch[q]));
            suf[nq] = suf[q];
            suf[np] = suf[q] = nq;
            while(ch[p][c] == q){
                ch[p][c] = nq;
                p = suf[p];
            }
        }
    }

    inline int idx(char c){
        return c - 'a';
    }

    inline void build(char* s){
        last = rt;
        int n = strlen(s);
        for(int i = 0;i<n;i++){
            add(i+1,idx(s[i]));
        }
    }

    inline void Topsort(int n){
        memset(sum,0,sizeof(int)*(n+1));
        for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
        for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
        for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
    }

    inline void work(char* s,int cur){
        int n = strlen(s);
        int u = rt;
        for(int i = 0;i<n;i++){
            u = ch[u][idx(s[i])];
            int fp = u;
            while(fp && pre[fp] != cur){
                cnt[fp] ++;
                pre[fp] = cur;
                fp = suf[fp];
            }
        }
    }

    inline void get_f(int k){
        memset(f,0,sizeof(int)*(sz+1));
        Topsort(1e5+5);
        f[rt] = 0;
        for(int i = 1;i<=sz;i++){
            int u = tmp[i];
            if(cnt[u]>=k) f[u] = len[u] - len[suf[u]];
            f[u] += f[suf[u]];
        }
    }

    inline ll sol(char* s,int k){
        ll ret = 0;
        int n = strlen(s);
        int u = rt;
        for(int i = 0;i<n;i++){
            u = ch[u][idx(s[i])];
            ret += f[u];
        }
        return ret;
    }
}sam;

int main(){
    int n,k;
    scanf("%d%d",&n,&k);
    int Last = 0;
    pos[0] = Last;
    for(int i = 1;i<=n;i++){
        scanf("%s",str+Last);
        int len = strlen(str+Last);
        Last += len;
        str[Last] = '\0';
        Last++;
        pos[i+1] = Last;
    }
    sam.init();
    for(int i = 1;i<=n;i++) {
        sam.build(str+pos[i]);
    }
    for(int i = 1;i<=n;i++) sam.work(str+pos[i],i);
    sam.get_f(k);
    for(int i = 1;i<=n;i++) {
        if(i>1) printf(" ");
        printf("%lld",sam.sol(str+pos[i],k));
    }
    return 0;
}

sol: 给定一颗树,树上每个结点对应一个字符,每个路径对应一个字符串。问所有可能的子串有多少种。
叶子结点最多只有20个。考虑从叶子为起点开始进行爆搜,将路径上所有的子串插入到后缀自动机中。显然不可能将子串全部生成后再插入,考虑在搜索的同时维护上个生成的串对应的结点,那么新的结点就直接在上个结点的基础上扩展。建出后就是模板操作,直接统计即可。

code:

#include<bits/stdc++.h>

using namespace std;

typedef long long ll;

const int maxn = 4e6+5;
const int s_sz = 10;
const int inf = 0x3f3f3f3f;

#define fi first
#define se second
#define MP make_pair
#define pii pair<int,int>

int sum[maxn],tmp[maxn];
int sta[maxn],top;

struct SAM{
    int ch[maxn][s_sz];
    int rt,sz,last;
    int len[maxn],suf[maxn],r[maxn];

    void init(){
        memset(ch,0,sizeof(ch[0]) * (sz+1));
        memset(suf,0,sizeof(int)*(sz+1));
        memset(r,0,sizeof(int)*(sz+1));
        rt = sz = last = 1;
    }

    inline int add(int pre,int x,int c){
        int p = pre,np = ++sz;
        last = np;
        len[np] = x;
        while(p && !ch[p][c]){
            ch[p][c] = np;
            p = suf[p];
        }
        if(!p){
            suf[np] = rt;
            return last;
        }
        int q = ch[p][c];
        if(len[q] == len[p] + 1) suf[np] = q;
        else{
            int nq = ++ sz;
            len[nq] = len[p] + 1;
            memcpy(ch[nq],ch[q],sizeof(ch[q]));
            suf[nq] = suf[q];
            suf[np] = suf[q] = nq;
            while(ch[p][c] == q){
                ch[p][c] = nq;
                p = suf[p];
            }
        }
        return last;
    }

    inline int idx(char c){
        return c - 'A';
    }

    inline void Topsort(int n){
        memset(sum,0,sizeof(int)*(n+1));
        for(int i = 1;i<=sz;i++) sum[len[i]] ++ ;
        for(int i = 1;i<=n;i++) sum[i] += sum[i-1];
        for(int i = 1;i<=sz;i++) tmp[sum[len[i]]--] = i;
    }

    inline void get_right(char* s){
        int u = rt;
        int n = strlen(s);
        for(int i = 0;i<n;i++){
            u = ch[u][idx(s[i])];
            r[u] = 1;
        }
        for(int i = sz;i;i--){
            int u = tmp[i];
            r[suf[u]] += r[u];
        }
    }

    inline ll work(){
        ll ret = 0;
        for(int i = rt+1;i<=sz;i++){
            ret += len[i] - len[suf[i]];
        }
        return ret;
    }
}sam;

int col[maxn];
vector<int> G[maxn];

void DFS(int sta,int u,int fa,int L){
    // cout<<sta<<' '<<u<<' '<<fa<<' '<<L<<endl;
    int newsta = sam.add(sta,L,col[u]);
    for(int i = 0;i<G[u].size();i++){
        int v = G[u][i];
        if(v == fa) continue;
        DFS(newsta,v,u,L+1);
    }
}

int main(){
    int n;
    int coll;
    scanf("%d%d",&n,&coll);
    for(int i = 1;i<=n;i++) scanf("%d",&col[i]);
    for(int i = 1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    sam.init();
    for(int i = 1;i<=n;i++){
        if(G[i].size()==1) {
            DFS(sam.rt,i,0,1);
        }
    }
    ll ans = sam.work();
    printf("%lld\n",ans);
    return 0;
}

sol: 给一堆模式串和对应的权值。一个子串的权值是它在其中出现过的模式串权值的积,问不超过L的串长的权值的期望。
用类似上上个题的做法维护每个状态对应的权值。注意到一个状态 s t a sta sta 表示的子串是以 r i g h t right right集合为最后一个字符,长度在 ( l e n [ s u f [ s t a ] ] , l e n [ s t a ] ] ( len[suf[sta]],len[sta] ] (len[suf[sta]],len[sta]]之间的后缀。考虑用树状数组去维护差分数组,就可以快速地求出长度为 L L L的子串权值之和。求出每个长度恰好为 L L L的子串贡献之后,再正向递推求出小于等于 L L L的子串的权值之和。长度不超过 L L L的子串一共有 ∑ i = 1 L 2 6 i \sum_{i=1}^{L} 26^i ∑i=1L​26i,等比数列求和搞一下就行。

upd: 直接线性维护差分数组即可,因为只有最后一次查询,求两遍前缀和即可。用树状数组的我怕是石乐志。

code:

#include<bits/stdc++.h>

using namespace std;

typedef long long ll;

const int maxn = 4e5+5;
const int s_sz = 26;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;

#define fi first
#define se second
#define MP make_pair
#define pii pair<int,int>

int pos[maxn];
char str[maxn];
ll h[maxn],A[maxn];
ll f[1000000+10];
int maxx;

void Mul(ll& x,ll y){
    x *= y;
    if(x>=mod) x%=mod;
}

void Add(ll& x,ll y){
    x += y;
    if(x>=mod) x%=mod;
}

inline int lowbit(int x){ return x&-x; }

void Modify(int x,ll p){
    for(int j = x;j<=maxx;j+=lowbit(j)){
        Add(A[j],p);
    }
}

ll Query(int x){
    ll ret = 0;
    for(int j = x;j;j-=lowbit(j)){
        Add(ret,A[j]);
    }
    return ret;
}

struct SAM{
    int ch[maxn][s_sz];
    int rt,sz,last;
    int len[maxn],suf[maxn],r[maxn];
    ll val[maxn],pre[maxn];

    void init(){
        memset(ch,0,sizeof(ch[0]) * (sz+1));
        memset(suf,0,sizeof(int)*(sz+1));
        memset(r,0,sizeof(int)*(sz+1));
        rt = sz = last = 1;
    }

    inline void add(int x,int c){
        int p = last,np = ++sz;
        last = np;
        len[np] = x;
        while(p && !ch[p][c]){
            ch[p][c] = np;
            p = suf[p];
        }
        if(!p){
            suf[np] = rt;
            return;
        }
        int q = ch[p][c];
        if(len[q] == len[p] + 1) suf[np] = q;
        else{
            int nq = ++ sz;
            len[nq] = len[p] + 1;
            memcpy(ch[nq],ch[q],sizeof(ch[q]));
            suf[nq] = suf[q];
            suf[np] = suf[q] = nq;
            while(ch[p][c] == q){
                ch[p][c] = nq;
                p = suf[p];
            }
        }
    }

    inline int idx(char c){
        return c - 'a';
    }

    inline void build(char* s){
        last = rt;
        int n = strlen(s);
        for(int i = 0;i<n;i++){
            add(i+1,idx(s[i]));
        }
    }

    inline void work(char* s,int cur){
        int n = strlen(s);
        int u = rt;
        for(int i = 0;i<n;i++){
            u = ch[u][idx(s[i])];
            int fp = u;
            while(fp && pre[fp] != cur){
                Mul(val[fp],h[cur]);
                pre[fp] = cur;
                fp = suf[fp];
            }
        }
    }

    inline void sol(){
        for(int i = rt+1;i<=sz;i++){
            ll tmp = val[i];
            Modify(len[suf[i]]+1,tmp);
            Modify(len[i]+1,(mod - tmp) % mod);
        }
        for(int i = 1;i<=maxx;i++){
            f[i] = Query(i);
            Add(f[i],f[i-1]);
        }
    }
}sam;

ll qpow(ll a,ll b){
    ll ret = 1;
    while(b){
        if(b&1) Mul(ret,a);
        Mul(a,a);
        b>>=1;
    }
    return ret;
}

ll Inv(ll n){
    return qpow(n,mod-2);
}

int main(){
    int n;
    scanf("%d",&n);
    int Last = 0;
    pos[0] = Last;
    maxx = 0;
    for(int i = 1;i<=n;i++){
        scanf("%s",str+Last);
        int len = strlen(str+Last);
        Last += len;
        str[Last] = '\0';
        Last++;
        pos[i+1] = Last;
        maxx = max(Last - pos[i] + 5,maxx);
    }
    for(int i = 1;i<=n;i++) scanf("%lld",&h[i]);    
    sam.init();
    for(int i = 1;i<=n;i++) {
        sam.build(str+pos[i]);
    }
    for(int i = 1;i<=sam.sz;i++) {
        sam.val[i] = 1;
    }
    for(int i = 1;i<=n;i++) {
        sam.work(str+pos[i],i);
    }
    sam.sol();
    int m;
    scanf("%d",&m);
    int invv = Inv(25);
    while(m--){
        int L;
        scanf("%d",&L);
        ll ans =(qpow(26,L)-1) * 26 % mod;
        Mul(ans,invv);
        ans = Inv(ans);
        Mul(ans,f[L]);
        printf("%lld\n",ans);
    }
    return 0;
}