【Luogu P5168】xtq玩魔塔(Kruskal 重构树 & 树状数组 & set)
阅读原文时间:2023年07月08日阅读:1

Description

给定一个 \(n\) 个顶点,\(m\) 条边的无向联通图,点、边带权。

先有 \(q\) 次修改或询问,每个指令形如 \(\text{opt}\ x\ y\):

  • \(\text{opt}=1\):将顶点 \(x\) 的点权修改为 \(y\);
  • \(\text{opt}=2\):查询顶点 \(x, y\) 间所有路径中路径上最大值中,最小的哪一个最大值(瓶颈路)。
  • \(\text{opt}=3\):查询顶点 \(x\) 可以结果边权 \(\le y\) 的边能到达的所有点上有几种不同的点权。

Hint

  • \(1\le n\le 10^5, 1\le m\le 3\times 10^5, 1\le q\le 2\times 10^5\)
  • \(\text{点权、边权}\in[0, 2^{31})\)

Solution

首先对于 \(\text{opt}=2\) 的操作,这是个经典问题,我们有很多解决思路。但是看到操作三就发现 Kruskal 重构树才是最好的选择。

我们假设没有操作一,那么操作二可以转化为重构树上两个结点的 LCA,操作三则是子树数颜色。

考虑到一颗子树的 dfs 序连续,那么这又可以转化为序列问题,于是成了区间数颜色。

然而在加上修改操作操作三就变的棘手了,或者树套树应该也能过但肯定不好写。


一看清一色待修莫队,感觉这个题可以不用这样麻烦。

之后在题解区发现了 mrsrz 的题解 的一只 \(\log\) 处理方法,感觉很妙,于是学习一波。

对于一个在结点 \(x\) 刚插入的一种颜色 \(c\),它可以贡献的范围是 \(x\) 的一个深度最浅的祖先 \(a\) 满足以 \(a\) 为根的子树中原本不存在任何一个颜色 \(c\)。于是我们就可以在这上面做链加。大力树剖加树状数组是 \(O(\log^2n)\) 一次的,但直接树上差分则可以做到 \(O(\log n)\)。具体地,我们在结点 \(x\) 的位置 \(+1\),然后在 \(a\) 的父亲上 \(-1\),因为它可以贡献到的最高的位置是 \(a\)。

然后就是如何找到这样一个 \(a\) 的问题。其实这个不难处理,这个 \(a\) 必然是重构树上 dfs 序与 \(x\) 相邻的两个结点(有可能一个)\(y_1, y_2\) 的两个 \(\text{LCA}(x, y_1), \text{LCA}(x, y_2)\) 中,深度较深的那一个。如果对虚树比较熟那么这就很显然。

具体实现时,为了找到 dfs 序相邻的点,我们对每一种颜色开一个 std::set,存这个颜色的所有结点并按 dfs 序排序。

这样总复杂度是 \(O((n+m+q)\log n)\) 的。

Code

这个题非常码农,所以写的有点长。不过思路还是很清晰的。

/*
 * Author : _Wallace_
 * Source : https://www.cnblogs.com/-Wallace-/
 * Problem : Luogu P5168 xtq玩魔塔
 */
#include <algorithm>
#include <cstdio>
#include <cctype>
#include <map>
#include <set>
#include <vector>

inline int read() {
  int x(0), s(0); char c; while (!isgraph(c = getchar()));
  if (x == '-') s = 1, c = getchar();
  do x = (x << 1) + (x << 3) + c - 48; while (isdigit(c = getchar()));
  return s ? -x : x;
}

const int N = 1e5 + 5;
const int M = 3e5 + 5;
const int Q = 2e5 + 5;
const int V = N << 1;
const int logN = 19;

int n, m, q, a[N];
struct Edge {
  int u, v, w;
  bool operator < (const Edge& rhs) const {
    return w < rhs.w;
  }
} e[M];

int uset[V];
int find(int x) {
  return x == uset[x] ? x : uset[x] = find(uset[x]);
}

int vcnt;
int ch[V][2], fa[V][logN], val[V];

int timer(0);
int dfn[V], siz[V], dep[V];

void dfs(int x) {
  dfn[x] = ++timer, siz[x] = 1, dep[x] = dep[fa[x][0]] + 1;
  if (!ch[x][0] && !ch[x][1]) return;
  dfs(ch[x][0]), dfs(ch[x][1]), siz[x] += siz[ch[x][0]] + siz[ch[x][1]];
}
int lca(int x, int y) {
  if (dep[x] < dep[y]) std::swap(x, y);
  for (int j = logN - 1; ~j; j--)
    if (dep[fa[x][j]] >= dep[y]) x = fa[x][j];
  if (x == y) return x;
  for (int j = logN - 1; ~j; j--)
    if (fa[x][j] != fa[y][j]) x = fa[x][j], y = fa[y][j];
  return fa[x][0];
}
int getanc(int x, int y) {
  for (int j = logN - 1; ~j; --j)
    if (fa[x][j] && val[fa[x][j]] <= y) x = fa[x][j];
  return x;
}

namespace bit {
  int tr[V];
  void add(int p, int v) {
    for (; p <= vcnt; p += p & -p) tr[p] += v;
  }
  int get(int p) {
    int v(0);
    for (; p; p -= p & -p) v += tr[p];
    return v;
  }
}

struct cmp {
  bool operator () (const int& a, const int& b) {
    return dfn[a] < dfn[b];
  }
};
int col_tot(0);
std::map<int, int> idx;
std::set<int, cmp> pos[V + Q];

int getIdx(int col) {
  return idx.count(col) ? idx[col] : idx[col] = ++col_tot;
}
void update_col(int x, int c) {
  bit::add(dfn[x], 1);
  std::set<int>::iterator it = pos[c = getIdx(c)].insert(x).first;
  if (pos[c].size() == 1u) return;

  std::vector<int> adj; adj.reserve(2);
  if (++it != pos[c].end()) adj.push_back(*it);
  if (--it != pos[c].begin()) adj.push_back(*--it);

  std::pair<int, int> y;
  for (int i = 0; i < (int)adj.size(); i++) {
    int l = lca(x, adj[i]);
    y = std::max(y, std::make_pair(dep[l], l));
  }
  bit::add(dfn[y.second], -1);
}
void remove_col(int x, int c) {
  bit::add(dfn[x], -1);
  std::set<int>::iterator it = pos[c = getIdx(c)].find(x);
  if (pos[c].size() == 1u) { pos[c].erase(it); return; }

  std::vector<int> adj; adj.reserve(2);
  if (++it != pos[c].end()) adj.push_back(*it);
  if (--it != pos[c].begin()) adj.push_back(*--it), ++it;
  pos[c].erase(it);

  std::pair<int, int> y;
  for (int i = 0; i < (int)adj.size(); i++) {
    int l = lca(x, adj[i]);
    y = std::max(y, std::make_pair(dep[l], l));
  }
  bit::add(dfn[y.second], 1);
}

int count(int x, int y) {
  x = getanc(x, y);
  return bit::get(dfn[x] + siz[x] - 1) - bit::get(dfn[x] - 1);
}

signed main() {
  n = read(), m = read(), q = read();
  for (int i = 1; i <= n; i++) a[i] = read();
  for (int i = 1; i <= m; i++) e[i].u = read(), e[i].v = read(), e[i].w = read();

  std::sort(e + 1, e + 1 + m), vcnt = n;
  for (int i = 1; i <= n; i++) uset[i] = i;
  for (int i = 1; i <= m && vcnt != n * 2 - 1; i++) {
    int u = find(e[i].u), v = find(e[i].v);
    if (u == v) continue;
    val[++vcnt] = e[i].w, uset[u] = uset[v] = uset[vcnt] = vcnt;
    fa[ch[vcnt][0] = u][0] = fa[ch[vcnt][1] = v][0] = vcnt;
  }
  for (int j = 1; j < logN; j++)
    for (int i = 1; i <= vcnt; i++)
      fa[i][j] = fa[fa[i][j - 1]][j - 1];
  dfs(vcnt);

  for (int i = 1; i <= n; i++)
    update_col(i, a[i]);

  while (q--) {
    int opt = read(), x = read(), y = read();
    if (opt == 1) remove_col(x, a[x]), update_col(x, a[x] = y);
    if (opt == 2) printf("%d\n", val[lca(x, y)]);
    if (opt == 3) printf("%d\n", count(x, y));
  }
  return 0;
}