HAOI 2018 Round 1 题解
阅读原文时间:2023年07月11日阅读:2

无聊了开一套省选题刷刷……u1s1 感觉三个题都不错,难度也挺有梯度,是一道标准的省选难度的题(话说 CSP 前你刷省选题干嘛/ts/ts)

小 C 珂海星

T1:P4495 [HAOI2018]奇怪的背包(基础数论)

一开始看错题了,以为是不可做题(

根据斐蜀定理,假设取出的体积集合为 \(S\),那么可以拼出 \(w\) 的充要条件是 \(\gcd_{x\in S}\gcd(x,P)\mid\gcd(w,P)\),不难发现 \(\gcd(x,P),\gcd(w,P)\) 肯定是 \(P\) 的约数,因此我们考虑给 \(P\) 的约数编号 \(1,2,3,\cdots,d(P)\)。然后我们设 \(f_i\) 表示 \(S\) 的 \(\gcd\) 是编号为 \(i\) 的约数的倍数的集合 \(S\) 的个数,\(g_i\) 表示 \(S\) 的 \(\gcd\) 恰好是编号为 \(i\) 的约数的集合 \(S\) 的个数。那么如果记 \(c_i\) 表示有多少个 \(\gcd(V_i,P)\) 是编号为 \(i\) 的约数的倍数,根据莫反那一套理论 \(f_i=2^{c_i}\),而 \(g_i\) 可以通过对 \(f_i\) 做狄利克雷差分求得,再一遍狄利克雷前缀和即可求出每个 \(\gcd\) 的答案。由于这题 \(d(P)\) 只有 \(10^3\),因此可以暴力 \(d^2(P)\) 狄利克雷前缀和/差分。时间复杂度 \(\mathcal O(\sqrt{P}+d^2(P)+(n+q)\log P)\)

const int MAXF=2000;
const int MAXN=1e6;
int n,qu,p,c[MAXF+5],b[MAXF+5],pw[MAXN+5],cs[MAXF+5],css[MAXF+5];
vector<int> fc;map<int,int> id;
int main(){
    scanf("%d%d%d",&n,&qu,&p);
    for(int i=1;i*i<=p;i++) if(p%i==0){
        fc.pb(i);
        if(p/i!=i) fc.pb(p/i);
    } sort(fc.begin(),fc.end());
    for(int i=0;i<fc.size();i++) id[fc[i]]=i;
    for(int i=1,x;i<=n;i++) scanf("%d",&x),c[id[__gcd(x,p)]]++;
    for(int i=0;i<fc.size();i++) for(int j=0;j<fc.size();j++)
        if(fc[j]%fc[i]==0) b[i]+=c[j];
    for(int i=(pw[0]=1);i<=n;i++) pw[i]=(pw[i-1]<<1)%MOD;
    for(int i=0;i<fc.size();i++) cs[i]=pw[b[i]];
    for(int i=(int)(fc.size())-1;~i;i--){
        for(int j=i+1;j<fc.size();j++) if(fc[j]%fc[i]==0)
            cs[i]=(cs[i]-cs[j]+MOD)%MOD;
    }
    for(int i=0;i<fc.size();i++) for(int j=0;j<=i;j++) if(fc[i]%fc[j]==0)
        css[i]=(css[i]+cs[j])%MOD;
    while(qu--){int w;scanf("%d",&w);printf("%d\n",css[id[__gcd(p,w)]]);}
    return 0;
}

T2:P4494 [HAOI2018]反色游戏(圆方树)

大概是这题的一个子问题?

首先注意到这题的样例答案要么是 \(0\) 要么是 \(2\) 的整数次幂,考虑什么情况下为 \(0\) 什么情况下为 \(2\) 的幂。注意到在我们操作的过程中,每个连通块中 \(1\) 的奇偶性是不会变的,因此如果一个连通块中 \(1\) 的个数是奇数答案就是 \(0\)。否则我们考虑对于每个连通块任取一个生成树,那么有这样一个性质,如果非树边选/不选的状态确定了,那么整个连通块的状态就确定了,具体构造方法就是从叶子开始,如果叶子权值是 \(1\) 就选其与父亲的边,如果是 \(0\) 则不选。因此一个连通块的方案数就是 \(2^{n'-m'+1}\),其中 \(n',m'\) 分别为该连通块的点数和边数。把所有连通块的方案乘在一起可以得到 \(2^{n-m+c}\)。其中 \(n,m,c\) 分别为点数、边数、连通块个数。

考虑怎么求这个东西,点数、边数的求法都是容易的,比较难的是连通块个数,这里有两种方法:

  • 方法一:注意到删点很困难,因此考虑线段树分治将删点变成加边,这样可以并查集维护,复杂度 \(Tn\log^2n\),然后便获得了 70 分的好成绩……
  • 方法二:注意到此题与 P5227 连通图 不同的一点是,此题每次只删一个点,因此我们不妨从割点的角度思考,我们考虑建出原图的圆方树,那么根据圆方树的一个性质:对于一个连通图中的每个点,删掉这个点后图中所有连通块,就是圆方树上所有与这个点相邻的所有方点所在的连通块,这样即可计算出删掉一个点后会多出多少个连通块,也进而可以计算出有多少个连通块满足连通块中 \(1\) 的个数为奇数。时间复杂度 \(\mathcal O(\sum n+m)\),可以通过。

这道题竟码了我 170+ 行,主要我把之前写的线段树分治给注释掉了导致代码里一车注释(

const int MAXN=1e5;
int n,m;char str[MAXN+5];
//struct node{int l,r;vector<pii> e;} s[MAXN*4+5];
//int bad_cnt=0,mul=1;
//void build(int k,int l,int r){
//    s[k].l=l;s[k].r=r;s[k].e.clear();if(l==r) return;
//    int mid=l+r>>1;build(k<<1,l,mid);build(k<<1|1,mid+1,r);
//}
//void ins(int k,int l,int r,pii v){
//    if(l>r) return;
//    if(l<=s[k].l&&s[k].r<=r) return s[k].e.pb(v),void();
//    int mid=s[k].l+s[k].r>>1;
//    if(r<=mid) ins(k<<1,l,r,v);else if(l>mid) ins(k<<1|1,l,r,v);
//    else ins(k<<1,l,mid,v),ins(k<<1|1,mid+1,r,v);
//}
//int f[MAXN+5],dep[MAXN+5],sum[MAXN+5];stack<pair<pii,int> > stk;
//int find(int x){return (!f[x])?x:find(f[x]);}
//void merge(int x,int y){
//    x=find(x);y=find(y);if(x==y) return mul=2ll*mul%MOD,void();
//    if(dep[x]<dep[y]) swap(x,y);stk.push(mp(mp(x,y),(dep[x]==dep[y])));
//    f[y]=x;dep[x]+=(dep[x]==dep[y]);
//    bad_cnt-=(sum[x]&1);bad_cnt-=(sum[y]&1);
//    sum[x]+=sum[y];bad_cnt+=(sum[x]&1);
//}
//void undo(){
//    pair<pii,int> p=stk.top();int x=p.fi.fi,y=p.fi.se,z=p.se;
//    dep[x]-=z;sum[x]-=sum[y];f[y]=0;stk.pop();
//}
//int res[MAXN+5];
//void iterate(int k){
//    int tmp=stk.size(),tmp_bad_cnt=bad_cnt,tmp_mul=mul;
//    for(pii p:s[k].e) merge(p.fi,p.se);
//    if(s[k].l==s[k].r){
//        if(s[k].l>0) bad_cnt-=(str[s[k].l]=='1');
//        res[s[k].l]=(bad_cnt)?0:mul;
//        if(s[k].l>0) bad_cnt+=(str[s[k].l]=='1');
//    } else iterate(k<<1),iterate(k<<1|1);
//    while(stk.size()>tmp) undo();
//    bad_cnt=tmp_bad_cnt;mul=tmp_mul;
//}
int deg[MAXN+5],pw[MAXN*2+5];
link_list<int,MAXN,MAXN*2> g;
link_list<int,MAXN*2,MAXN*4> t;
int stk[MAXN*2+5],dfn[MAXN+5],tp=0,low[MAXN+5];
int tim=0,ncnt,degt[MAXN*2+5];
void tarjan(int x){
    dfn[x]=low[x]=++tim;stk[++tp]=x;
    for(int e=g.hd[x];e;e=g.nxt[e]){
        int y=g.val[e];
        if(!dfn[y]){
            tarjan(y);chkmin(low[x],low[y]);
            if(low[y]>=dfn[x]){
                ++ncnt;int o;do {
                    o=stk[tp--];degt[o]++;
                    t.ins(ncnt,o);t.ins(o,ncnt);
//                    printf("%d %d\n",ncnt,o);
                } while(o^y);//printf("%d %d\n",ncnt,x);
                t.ins(ncnt,x);t.ins(x,ncnt);degt[x]++;
            }
        } else chkmin(low[x],dfn[y]);
    }
}
int rt[MAXN*2+5],siz[MAXN*2+5],fa[MAXN*2+5],cntb[MAXN*2+5];
void dfs(int x,int f,int r){
    rt[x]=r;siz[x]=1;fa[x]=f;cntb[x]=(x<=n)?(str[x]=='1'):0;
    for(int e=t.hd[x];e;e=t.nxt[e]){
        int y=t.val[e];if(y==f) continue;dfs(y,x,r);
        siz[x]+=siz[y];cntb[x]+=cntb[y];
    }
}
void clear(){
    memset(deg,0,sizeof(deg));g.clear();t.clear();
    memset(dfn,0,sizeof(dfn));memset(low,0,sizeof(low));
    tim=0;memset(degt,0,sizeof(degt));tp=0;
    memset(rt,0,sizeof(rt));memset(siz,0,sizeof(siz));
    memset(fa,0,sizeof(fa));memset(cntb,0,sizeof(cntb));
}
void solve(){
    scanf("%d%d",&n,&m);
//    build(1,0,n);
//    for(int i=1,u,v;i<=m;i++){
//        scanf("%d%d",&u,&v);if(u>v) swap(u,v);
//        ins(1,0,u-1,mp(u,v));ins(1,u+1,v-1,mp(u,v));
//        ins(1,v+1,n,mp(u,v));
//    } scanf("%s",str+1);bad_cnt=0;mul=1;
//    for(int i=1;i<=n;i++) bad_cnt+=(str[i]=='1'),sum[i]=(str[i]-'0');
//    iterate(1);
//    for(int i=0;i<=n;i++) printf("%d%c",res[i]," \n"[i==n]);
    clear();ncnt=n;
    for(int i=1,u,v;i<=m;i++){
        scanf("%d%d",&u,&v);g.ins(u,v);g.ins(v,u);
        deg[u]++;deg[v]++;
    } scanf("%s",str+1);vector<int> roots;int c=0,sum_bad=0;
    for(int i=1;i<=n;i++) if(!dfn[i]) tarjan(i),roots.pb(i),++c;
    for(int r:roots) dfs(r,0,r),sum_bad+=(cntb[r]&1);
//    for(int i=1;i<=ncnt;i++) printf("%d\n",cntb[i]);
    printf("%d",(sum_bad)?0:pw[m-n+c]);
    for(int i=1;i<=n;i++){
        int _m=m-deg[i],_n=n-1,_c=c-1+degt[i],_sum_bad=sum_bad;
        if(cntb[rt[i]]&1) _sum_bad--;
        for(int e=t.hd[i];e;e=t.nxt[e]){
            int y=t.val[e];
            if(y==fa[i]) _sum_bad+=((cntb[rt[i]]-cntb[i])&1);
            else _sum_bad+=((cntb[y])&1);
        } printf(" %d",(_sum_bad)?0:pw[_m-_n+_c]);
    } printf("\n");
}
int main(){
//    freopen("game.in","r",stdin);freopen("game.out","w",stdout);
    for(int i=(pw[0]=1);i<=MAXN*2;i++) pw[i]=(pw[i-1]<<1)%MOD;
    int qu;scanf("%d",&qu);while(qu--) solve();
    return 0;
}
/*
1
5 5
1 2
2 3
3 4
2 4
3 5
00000
*/

T3:P4493 [HAOI2018]字串覆盖(SAM+数据分治)

首先有一个显然的性质是我们肯定会每次选出 \(P\) 在 \(S\) 中左端点 \(\ge s\) 且最靠左的出现位置,然后将这些位置全部设为访问过并且贪心地向后取,直到后面不存在合法的位置或者 \(P\) 下一次出现位置的右端点 \(\le t\)。

考虑优化这个过程,注意到此题有一个奇奇怪怪的限制:\(51\le r−l\le 2000\) 的询问不超过 \(11000\) 个且随机生成。因此考虑数据分治,具体来说对于 \(r-l>2000\),我们建出两串的 SAM 然后线段树合并维护每个字符串在 endpos,那么每次询问我们找到 \(T[l…r]\) 在 SAM 上对应的节点 \(x\),然后每次在 \(x\) 的 edp 中二分找出结尾位置 \(>\) 上一个串的结尾位置 \(-l+r\) 且结尾位置最靠前的字符串,如果超过了 \(r\) 或者不存在则 break,否则答案加一,不难注意到这样最多跳 \(\dfrac{n}{2000}\) 次,复杂度 \(q·\dfrac{n}{2000}·\log n\)。对于 \(r-l\le 50\),注意到每次都暴力向后跳 edp 复杂度过高不能接受,而可能的情况并不是很多(只有 \(50\) 个可能的长度),对于这样的模型我们考虑倍增,具体来说我们将询问离线下来,然后枚举长度 \(len\),记 \(nxt_{i,j}\) 表示对于 \(S\) 中以 \(i\) 结尾长度为 \(len\) 的字符串,从 \(i\) 开始取(当作第 \(0\) 次取的位置)向后取到的第 \(2^j\) 个子串的结束位置是什么(如果不存在则 \(nxt_{i,j}=n+1\)),\(nxt_{i,0}\) 可以通过 \(r-l>2000\) 的倍增求出,这部分复杂度 \(50n\log n\)。预处理出 \(nxt_{i,j}\) 之后即可通过倍增在 \(\log n\) 时间内回答每个询问。至于 \(51\le r−l\le 2000\)……由于题目中那个奇奇怪怪的条件,把这样的询问按照 \(r-l>2000\) 的方法处理也可通过。

然后就是实现的事情了,对于这种码农题硬着头皮码就完事了,如果你代码水平像我一样都比较那啥,可能会像我一样码个 200 行(

const int MAXN=1e5;
const int MAXP=4e5;
const int MAX_ND=MAXP<<6;
const int LOG_N=18;
const int B=50;
int n,K;char s[MAXN+5],t[MAXN+5];
link_list<int,MAXP,MAXP> g;
namespace SAM{
    int ch[MAXP+5][26],len[MAXP+5],lnk[MAXP+5],ncnt=1,cur=1;
    int ed[MAXN*2+5];
    void extend(char c,int ps){
        int id=c-'a',nw=++ncnt,p=cur;
        ed[ps]=nw;len[nw]=len[cur]+1;cur=nw;
        while(p&&!ch[p][id]) ch[p][id]=nw,p=lnk[p];
        if(!p) return lnk[nw]=1,void();
        int q=ch[p][id];
        if(len[q]==len[p]+1) return lnk[nw]=q,void();
        int cl=++ncnt;len[cl]=len[p]+1;
        lnk[cl]=lnk[q];lnk[q]=lnk[nw]=cl;
        for(int i=0;i<26;i++) ch[cl][i]=ch[q][i];
        while(p&&ch[p][id]==q) ch[p][id]=cl,p=lnk[p];
    }
    void build(){
        for(int i=2;i<=ncnt;i++) g.ins(lnk[i],i);
    }
}
int rt[MAXP+5];
namespace segtree{
    struct node{int ch[2],fst,lst;} s[MAX_ND+5];
    void pushup(int k){
        s[k].fst=(s[k].ch[0])?s[s[k].ch[0]].fst:s[s[k].ch[1]].fst;
        s[k].lst=(s[k].ch[1])?s[s[k].ch[1]].lst:s[s[k].ch[0]].lst;
    }
    int ncnt=0;
    void insert(int &k,int l,int r,int p){
        if(!k) k=++ncnt;
        if(l==r) return s[k].fst=s[k].lst=p,void();
        int mid=l+r>>1;
        if(p<=mid) insert(s[k].ch[0],l,mid,p);
        else insert(s[k].ch[1],mid+1,r,p);
        pushup(k);
    }
    int merge(int x,int y,int l,int r){
        if(!x||!y) return x+y;int z=++ncnt;
        if(l==r) return s[z].fst=s[z].lst=l,z;
        int mid=l+r>>1;
        s[z].ch[0]=merge(s[x].ch[0],s[y].ch[0],l,mid);
        s[z].ch[1]=merge(s[x].ch[1],s[y].ch[1],mid+1,r);
        return pushup(z),z;
    }
    int findgeq(int k,int l,int r,int p){
//        printf("walk %d %d %d %d\n",k,l,r,p);
        if(p>s[k].lst||l>r) return n+1;
        if(l==r) return l;
        int mid=l+r>>1;
        if(!s[k].ch[1]) return findgeq(s[k].ch[0],l,mid,p);
        if(!s[k].ch[0]) return findgeq(s[k].ch[1],mid+1,r,p);
        if(s[s[k].ch[0]].lst>=p) return findgeq(s[k].ch[0],l,mid,p);
        else return findgeq(s[k].ch[1],mid+1,r,p);
    }
}
int fa[MAXP+5][LOG_N+2];
void dfs(int x,int f){
    fa[x][0]=f;//printf("dfs %d %d\n",x,f);
    for(int e=g.hd[x];e;e=g.nxt[e]){
        int y=g.val[e];dfs(y,x);
        rt[x]=segtree::merge(rt[x],rt[y],1,n);
    }
}
int getnxt(int p,int pos){
    if(pos>n) return n+1;
    return segtree::findgeq(rt[p],1,n,pos+1);
}
int getnode(int l,int r){
    int cur=SAM::ed[r];
    for(int i=LOG_N;~i;i--) if(SAM::len[fa[cur][i]]>=r-l+1) cur=fa[cur][i];
    return cur;
}
ll ans[MAXN+5];
vector<pair<pii,int> > qv[B+5];
int nxt[MAXN+5][LOG_N+2];
ll sum[MAXN+5][LOG_N+2];
int main(){
//    freopen("cover.in","r",stdin);freopen("cover.out","w",stdout);
    scanf("%d%d%s%s",&n,&K,s+1,t+1);
    for(int i=1;i<=n;i++) SAM::extend(s[i],i);SAM::cur=1;
    for(int i=1;i<=n;i++) SAM::extend(t[i],i+n);
    SAM::build();
    for(int i=1;i<=n;i++) segtree::insert(rt[SAM::ed[i]],1,n,i);
    dfs(1,0);
    for(int i=1;i<=LOG_N;i++) for(int j=1;j<=SAM::ncnt;j++) fa[j][i]=fa[fa[j][i-1]][i-1];
    int qu;scanf("%d",&qu);
    for(int i=1;i<=qu;i++){
        int s,t,l,r;scanf("%d%d%d%d",&s,&t,&l,&r);
        int cur=getnode(l+n,r+n);
        if(r-l+1>B){
            ll res=0;
            int lim=s+(r-l)-1,len=r-l+1;
            while(1){
                int ps=getnxt(cur,lim);
//                printf("%d\n",ps);
                if(ps>t) break;res+=K-(ps-len+1);
                lim=(ps+len-1);
            } ans[i]=res;
        } else {
            int lim=s+(r-l)-1;
            int nt=getnxt(cur,lim);
            if(nt>t) ans[i]=0;
            else qv[r-l+1].pb(mp(mp(nt,t),i));
        }
    }
    for(int i=1;i<=B;i++){
        memset(nxt,0,sizeof(nxt));
        memset(sum,0,sizeof(sum));
//        printf("dealing %d\n",i);
        for(int j=i;j<=n;j++){
            int nd=getnode(j-i+1,j);
            nxt[j][0]=getnxt(nd,j+i-1);
            if(nxt[j][0]!=n+1) sum[j][0]=K-(nxt[j][0]-i+1);
//            printf("%d %d\n",j,nxt[j][0]);
        } nxt[n+1][0]=n+1;
        for(int j=1;j<=LOG_N;j++){
            for(int k=i;k+(1<<j)-1<=n+1;k++){
                nxt[k][j]=nxt[nxt[k][j-1]][j-1];
                sum[k][j]=sum[k][j-1]+sum[nxt[k][j-1]][j-1];
            }
        }
        for(pair<pii,int> p:qv[i]){
            int pos=p.fi.fi,r=p.fi.se,id=p.se;
//            printf("qry %d %d %d\n",pos,r,id);
            ll res=K-(pos-i+1);
            for(int j=LOG_N;~j;j--) if(nxt[pos][j]<=r&&nxt[pos][j]){
//                printf("%d %lld\n",j,sum[pos][j]);
                res+=sum[pos][j];pos=nxt[pos][j];
            } ans[id]=res;
        }
    }
    for(int i=1;i<=qu;i++) printf("%lld\n",ans[i]);
    return 0;
}