Codeforces 587F - Duff is Mad(根号分治+AC 自动机+树状数组)
阅读原文时间:2022年03月25日阅读:1

题面传送门

第一眼看成了 CF547E……

话说 CF587F 和 CF547E 出题人一样欸……还有另一道 AC 自动机的题 CF696D 也是这位名叫 PrinceOfPersia 的出题人出的……似乎这位神仙很擅长字符串?

跑题了跑题了……

令 \(m=\sum|s_i|\)

u1s1 我做这题的时候一直在考虑按照 CF547E 的套路对询问进行差分处理,然鹅并没有什么卵用。因为这题相当于每次加入一个字符串时将这个字符串结尾位置 fail 树的子树中的所有节点访问次数 \(+1\),并询问某个节点 \(x\) 在 trie 树上到根节点路径上所有点总共被标记了多少次,然鹅这个东西是无法用 DS 直接维护的,trie 树与 fail 树之间也无法通过 DFS 序什么的建立直接的联系,故该算法是没有前途的。

考虑换一种思路。首先我们可以想到一个最最暴力的做法,将 trie 树上 \(s_k\) 结尾位置到根节点的路径上所有点访问次数 \(+1\),并枚举所有字符串 \(i\in [l,r]\),统计 \(s_i\) 在 fail 树的子树中所有点的标记次数之和(这也就是 \(s_i\) 在 \(s_k\) 中出现的次数)并累加进答案中。其次我们还可以想到一个刚才那个算法的暴力版本,还是将询问差分处理,拆成两个形如“前 \(x\) 个字符串在 \(s_y\) 中总共出现了多少次”并将这样的询问挂在 \(x\) 上。考虑动态插入这些字符串,当我们插入某个字符串 \(s_i\) 的时候就用树状数组+DFS 序令 \(s_i\) 在 fail 树的子树中所有点的标记次数 \(+1\),然而对于查询就没有什么好方法了,只能暴力跳 \(s_k\) 在 trie 树上的祖先并将该路径上所有点的标记次数累加入答案。

还是考虑将两个暴力结合一下,不难发现第一个暴力对每个询问都暴力求了一遍,在 \(q\) 比较小的时候能体现出较大优势,第二个暴力插入某个字符串时操作效率较高,但查询出现次数时复杂度与 \(|s_k|\) 的长度有关,在 \(|s_i|\) 比较小的时候能体现出较大优势。按照套路根分,设一个阈值 \(T\),对于长度 \(>T\) 的字符串 \(s_k\) 采用暴力一求出所有字符串在 \(s_k\) 中出现的次数,查询的时候直接前缀和相减即可,由于这样的字符串最多 \(\dfrac{m}{T}\) 个,故复杂度 \(\dfrac{m^2}{T}\)。对于长度 \(\leq T\) 的字符串就采用暴力二,复杂度 \(qT\log m\),总复杂度 \(\dfrac{m^2}{T}+qT\log m\),根据均值不等式可知 \(T=\dfrac{m}{\sqrt{q\log m}}\) 时复杂度最优。

当然如果使用区间修改 \(\mathcal O(\sqrt{m})\),单点查询 \(\mathcal O(1)\) 的分块可使暴力二部分的复杂度降到 \(n\sqrt{m}+qT\),但由于直接树状数组能过就没写了。

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;}
template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;}
typedef pair<int,int> pii;
typedef long long ll;
template<typename T> void read(T &x){
    x=0;char c=getchar();T neg=1;
    while(!isdigit(c)){if(c=='-') neg=-1;c=getchar();}
    while(isdigit(c)) x=x*10+c-'0',c=getchar();
    x*=neg;
}
const int MAXN=1e5;
const int ALPHA=26;
int n,qu,m=0,blk;string s[MAXN+5];
int ch[MAXN+5][ALPHA+2],ncnt=0,fail[MAXN+5],ed[MAXN+5];
void insert(string s,int id){
    int cur=0;
    for(int i=0;i<s.size();i++){
        if(!ch[cur][s[i]-'a']) ch[cur][s[i]-'a']=++ncnt;
        cur=ch[cur][s[i]-'a'];
    } ed[id]=cur;
}
void getfail(){
    queue<int> q;
    for(int i=0;i<ALPHA;i++) if(ch[0][i]) q.push(ch[0][i]);
    while(!q.empty()){
        int x=q.front();q.pop();
        for(int i=0;i<ALPHA;i++){
            if(ch[x][i]){fail[ch[x][i]]=ch[fail[x]][i];q.push(ch[x][i]);}
            else ch[x][i]=ch[fail[x]][i];
        }
    }
}
int hd[MAXN+5],to[MAXN+5],nxt[MAXN+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int tim=0,bgt[MAXN+5],edt[MAXN+5];
void dfs(int x){bgt[x]=++tim;for(int e=hd[x];e;e=nxt[e]) dfs(to[e]);edt[x]=tim;}
ll sum[MAXN+5],ss[MAXN+5];
struct query{int l,r,k;} q[MAXN+5];
vector<int> large;
vector<pii> qv[MAXN+5];
ll ans[MAXN+5];
ll t[MAXN+5];
void add(int x,int v){for(int i=x;i<=ncnt+1;i+=(i&(-i))) t[i]+=v;}
ll query(int x){int ret=0;for(int i=x;i;i&=(i-1)) ret+=t[i];return ret;}
int main(){
    scanf("%d%d",&n,&qu);
    for(int i=1;i<=n;i++) cin>>s[i],m+=s[i].size(),insert(s[i],i);
    blk=(int)m/sqrt(qu*log(m)/log(2));
    getfail();for(int i=1;i<=ncnt;i++) adde(fail[i],i);dfs(0);
    for(int i=1;i<=qu;i++){
        scanf("%d%d%d",&q[i].l,&q[i].r,&q[i].k);
        if(s[q[i].k].size()>blk) large.pb(i);
        else qv[q[i].l-1].pb(mp(i,-1)),qv[q[i].r].pb(mp(i,1));
    }
    sort(large.begin(),large.end(),[](int x,int y){return q[x].k<q[y].k;});
    for(int i=0;i<large.size();i++){
        if(i==0||q[large[i]].k!=q[large[i-1]].k){
            int id=q[large[i]].k,cur=0;
            memset(sum,0,sizeof(sum));
            memset(ss,0,sizeof(ss));
            for(int j=0;j<s[id].size();j++) cur=ch[cur][s[id][j]-'a'],sum[bgt[cur]]++;
            for(int j=1;j<=ncnt+1;j++) sum[j]+=sum[j-1];
            for(int j=1;j<=n;j++) ss[j]=ss[j-1]+sum[edt[ed[j]]]-sum[bgt[ed[j]]-1];
        } ans[large[i]]=ss[q[large[i]].r]-ss[q[large[i]].l-1];
    }
    for(int i=1;i<=n;i++){
        add(bgt[ed[i]],1);add(edt[ed[i]]+1,-1);
        for(int j=0;j<qv[i].size();j++){
            int id=qv[i][j].fi,mul=qv[i][j].se,cur=0;
            ll ret=0;
            for(int k=0;k<s[q[id].k].size();k++){
                cur=ch[cur][s[q[id].k][k]-'a'];
                ret+=query(bgt[cur]);
            } ans[id]+=ret*mul;
        }
    }
    for(int i=1;i<=qu;i++) printf("%lld\n",ans[i]);
    return 0;
}