【Codeforces 809E】Surprise me!(莫比乌斯反演 & 虚树)
阅读原文时间:2023年07月09日阅读:3

Description

给定一颗 \(n\) 个顶点的树,顶点 \(i\) 的权值为 \(a_i\)。求:

\[\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^n\varphi(a_i\times a_j)\times\text{dist}(i, j)
\]

其中 \(a\) 为一个 \(1\sim n\) 的排列。

Hint

\(1\le n\le 2\times 10^5\)

Solution

据说是套路题 然而我不会这个套路于是我觉得是神题 开一个 blog 记一下。

首先这个 \(\varphi(a_i\times a_j)\) 很烦,因为这样就不能分开计算了,于是考虑拆开:

\[\begin{aligned}
& \varphi(x) = x\prod_{i=1}^s\left(1-\tfrac{1}{p_i} \right),\qquad \varphi(y) = y\prod_{i=1}^t\left(1-\tfrac{1}{q_i} \right), \qquad\varphi(xy) = xy\prod_{i=1}^k\left(1-\tfrac{1}{r_i}\right)\\
& \varphi(x)\times\varphi(y) = xy\prod_{i=1}^s\left(1-\tfrac{1}{p_i} \right)\prod_{i=1}^t\left(1-\tfrac{1}{q_i} \right)
\end{aligned}
\]

观察到 \(\varphi(x)\times \varphi(y)\) 相对 \(\varphi(xy)\) 多乘了一些东西,不难发现,其实就是 \(p\) 和 \(q\) 公共的部分:\(\gcd(x, y)\)。

那么我们除掉这些东西,有:

\[\begin{aligned}
& \varphi(xy) = \frac{\varphi(x)\times \varphi(y)\times \gcd(x, y)}{\varphi(\gcd(x, y))} \\
& \sum_i \sum_j \varphi(a_i\times a_j)\times \text{dist}(i, j) = \sum_i \sum_j \frac{\varphi(a_i)\varphi(a_j) \gcd(a_i, a_j)}{\varphi(\gcd(a_i, a_j))}\text{dist}(i, j)
\end{aligned}
\]


\(\varphi(a_i\times a_j)\) 拆完了,但是冒出了个 \(\gcd\),怎么处理?

于是考虑枚举 \(\gcd\),记为 \(d\)。那么上式 \(=\)

\[\begin{aligned}
& \sum_{d\le n} \sum_i\sum_j [d = \gcd(a_i, a_j)]\frac{\varphi(a_i)\varphi(a_j) d}{\varphi(d)}\text{dist}(i, j)\\
= & \sum_{d\le n} \frac{d}{\varphi(d)} \sum_i\sum_j [d = \gcd(a_i, a_j)]\varphi(a_i)\varphi(a_j)\text{dist}(i, j)
\end{aligned}
\]

然而 \([d = \gcd(a_i, a_j)]\) 还是不好搞。不过我们发现,如果这里是 \([d | \gcd(a_i, a_j)]\) 的话这个艾弗森括号就可以这样消:由于对于每个 \(d\) 都仅有满足 \(\gcd\) 为 \(d\) 的倍数时才有贡献,那么只要计算所有满足 \(d|a_i\) 的 \(i\) 点即可。

不难想到以建虚树的形式“抽出”这些点,由于 \(a\) 为 \(1\sim n\) 的一个排列,那么对于每个 \(d\),都有 \(\frac{n}{d}\) 个点满足条件。显然虚树的总点数不会超过 \(O(n\log n)\)。


扯了这么多,还是先得想办法把等号化成整除才行。

我们定义上面那个式子的后半部分为一个函数 \(f\),同时定义一个长得很像的 \(g\):

\[\begin{aligned}
f(d) &= \sum_i\sum_j [d = \gcd(a_i, a_j)]\varphi(a_i)\varphi(a_j)\text{dist}(i, j)\\
g(d) &= \sum_i\sum_j [d | \gcd(a_i, a_j)]\varphi(a_i)\varphi(a_j)\text{dist}(i, j)
\end{aligned}
\]

可以观察到 \(g(x) = \sum_{x|d} f(d)\),那么可以通过莫比乌斯反演互转化:\(f(x) = \sum_{x|d} g(d)\mu\left(\frac d x\right)\)。只要算出了 \(g\),\(f\) 就能在 \(O(n\log n)\) 计算出来。


考虑 \(g\) 的计算。我们建好虚树 \(T\) 后,就先把艾弗森括号搞掉,为了方便我们记 \(v_i = \varphi(a_i)\):\(g(d) = \sum\limits_{i\in T}\sum\limits_{j\in T}v_iv_j\text{dist}(i, j)\)。

将 \(\text{dist}\) 拆成深度之和减去两倍 LCA 的深度之和的形式,那么有(\(d_i\) 表示点 \(i\) 在原树中的深度):

\[g(d) = \sum_{i\in T}\sum_{j\in T}v_iv_j(d_i+d_j-2d_{\text{LCA}(i, j)}) = 2\sum_{i\in T}v_id_i\sum_{j\in T}v_j-2\sum_{i\in T}\sum_{j\in T} v_iv_jd_{\text{LCA}(i, j)} = X-Y
\]

其中 \(X\) 在求出子树 \(\sum v\) 之后可以轻松得到,\(Y\) 的话只要一波树形 dp 即可。

然后这题就完了,总时间复杂度 \(O(n\log^2 n)\)。

Code

/*
 * Author : _Wallace_
 * Source : https://www.cnblogs.com/-Wallace-/
 * Problem : Codeforces 809E Surprise me!
 */
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>

using namespace std;
const int mod = 1e9 + 7;
const int N = 2e5 + 5;
typedef long long LL;
inline int read() {
  int x(0); char c; while (!isdigit(c = getchar()));
  do x = (x << 1) + (x << 3) + c - 48; while (isdigit(c = getchar()));
  return x;
}

LL fastpow(LL a, LL b) {
  if (!b) return 1ll;
  LL t = fastpow(a, b / 2);
  if (b & 1) return t * t % mod * a % mod;
  else return t * t % mod;
}
LL inv(LL x) {
  return fastpow(x % mod, mod - 2);
}

int p[N], tot = 0;
bool flag[N];
int mu[N], phi[N];
void sieve(int n) {
  mu[1] = phi[1] = 1;
  for (int i = 2; i <= n; i++) {
    if (!flag[i]) mu[i] = -1, phi[i] = i - 1, p[++tot] = i;
    for (int j = 1; j <= tot && i * p[j] <= n; j++) {
      flag[i * p[j]] = 1;
      if (i % p[j] == 0) { mu[i * p[j]] = 0, phi[i * p[j]] = phi[i] * p[j]; break; }
      mu[i * p[j]] = -mu[i], phi[i * p[j]] = phi[i] * (p[j] - 1);
    }
  }
}

int n, a[N], loc[N];
vector<int> adj[N];
LL f[N], g[N];

int fa[N], siz[N], dep[N];
int wson[N], wtop[N];
int dfn[N], timer = 0;
void dfs1(int x, int f) {
  dep[x] = dep[fa[x] = f] + 1, siz[x] = 1;
  for (auto y : adj[x]) if (y != f) {
    dfs1(y, x), siz[x] += siz[y];
    if (siz[wson[x]] < siz[y]) wson[x] = y;
  }
}
void dfs2(int x, int t) {
  wtop[x] = t, dfn[x] = ++timer;
  if (wson[x]) dfs2(wson[x], t);
  for (auto y : adj[x]) if (y != fa[x] && y != wson[x])
    dfs2(y, y);
}
int lca(int x, int y) {
  while (wtop[x] != wtop[y]) {
    if (dep[wtop[x]] < dep[wtop[y]]) swap(x, y);
    x = fa[wtop[x]];
  }
  return dep[x] < dep[y] ? x : y;
}

bool tag[N];
vector<int> vir[N];

template<bool d> inline void link(vector<int>* g, int x, int y) {
  g[x].push_back(y); if (d) g[y].push_back(x);
}
LL sum[N], X, Y;
void dfs(int x) {
  sum[x] = tag[x] ? a[x] : 0ll;
  for (auto y : vir[x]) {
    dfs(y);
    (Y += sum[x] * sum[y] % mod * dep[x] % mod) %= mod;
    (sum[x] += sum[y]) %= mod;
  }
  vir[x].clear(), tag[x] = 0;
}
LL solve(int rt, int* vtx, int cnt) {
  X = Y = 0ll, dfs(rt);
  for (int i = 1; i <= cnt; i++) (X += a[vtx[i]] * 1ll * dep[vtx[i]] % mod * sum[rt] % mod) %= mod;
  (Y *= 2) %= mod;
  for (int i = 1; i <= cnt; i++) (Y = Y + a[vtx[i]] * 1ll * a[vtx[i]] % mod * dep[vtx[i]] % mod + mod) % mod;
  (X *= 2) %= mod, (Y *= 2) %= mod;
  return (X - Y + mod) % mod;
}

signed main() {
  n = read(), sieve(n);
  for (int i = 1; i <= n; i++) loc[a[i] = read()] = i, a[i] = phi[a[i]];
  for (int i = 1, u, v; i < n; i++) u = read(), v = read(), link<1>(adj, u, v);
  dfs1(1, 0), dfs2(1, 1);

  for (int d = 1; d <= n; d++) {
    static int vtx[N]; int cnt = 0;
    for (int x = d; x <= n; x += d) tag[vtx[++cnt] = loc[x]] = 1;
    sort(vtx + 1, vtx + 1 + cnt, [](int a, int b) { return dfn[a] < dfn[b]; });

    static int stk[N]; int top = 0;
    stk[top = 1] = vtx[1];
    for (int i = 2; i <= cnt; i++) {
      int x = vtx[i], l = lca(x, stk[top]);
      for (; top > 1 && dep[l] <= dep[stk[top - 1]]; --top) link<0>(vir, stk[top - 1], stk[top]);
      if (stk[top] != l) link<0>(vir, l, stk[top]), stk[top] = l;
      stk[++top] = x;
    }
    for (; top > 1; --top) link<0>(vir, stk[top - 1], stk[top]);
    g[d] = solve(stk[1], vtx, cnt);
  }

  for (int x = 1; x <= n; x++)
    for (int d = x; d <= n; d += x)
      (f[x] += g[d] * mu[d / x] + mod) %= mod;

  LL ans = 0ll;
  for (int i = 1; i <= n; i++) (ans += i * inv(phi[i]) % mod * f[i] % mod) %= mod;
  printf("%lld\n", ans * inv(n * 1ll * (n - 1)) % mod);
  return 0;
}

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器

你可能感兴趣的文章