ICPC 南昌现场赛 K:Tree(dsu on tree + 动态开点线段树)

Tree

让我们找满足一下五个条件的 ( x , y (x, y (x,y)点对有多少:

  • x ≠ y x \neq y x=y
  • x x x不是 y y y的祖先
  • y y y不是 x x x的祖先
  • d i s ( x , y ) ≤ k dis(x, y)\leq k dis(x,y)k
  • z z z x , y x, y x,y的最近公共祖先, v a l u e x + v a l u e y = 2 v a l u e z value_x + value_y = 2value_z valuex+valuey=2valuez

读题目观察到每个节点的 v a l u e value value只有 [ 0 , 1 0 5 ] [0, 10 ^ 5] [0,105](如果不是的话,也可离散化处理一下吧),所以我们可以建立 1 0 5 10 ^ 5 105棵线段树,每棵线段树里面记录的是点权为 i i i的节点的深度信息,

所以我们只要做一次 d s u   o n   t r e e dsu\ on\ tree dsu on tree,动态维护这颗线段树,然后按照需要查询即可,好像并不是特别难。

#include <bits/stdc++.h>

using namespace std;

const int N = 2e5 + 10;

int head[N], to[N], nex[N], cnt = 1;

int value[N], n, m;

int son[N], sz[N], dep[N], l[N], r[N], rk[N], tot;

int root[N], ls[N << 6], rs[N << 6], sum[N << 6], num;

void add(int x, int y) {
  to[cnt] = y;
  nex[cnt] = head[x];
  head[x] = cnt++;
}

void dfs(int rt, int fa) {
  dep[rt] = dep[fa] + 1, sz[rt] = 1, l[rt] = ++tot, rk[tot] = rt;
  for (int i = head[rt]; i; i = nex[i]) {
    if (to[i] == fa) {
      continue;
    }
    dfs(to[i], rt);
    sz[rt] += sz[to[i]];
    if (!son[rt] || sz[son[rt]] < sz[to[i]]) {
      son[rt] = to[i];
    }
  }
  r[rt] = tot;
}

void push_up(int rt) {
  sum[rt] = sum[ls[rt]] + sum[rs[rt]];
}

void update(int &rt, int l, int r, int x, int value) {
  if (!rt) {
    rt = ++num;
  }
  if (l == r) {
    sum[rt] += value;
    return ;
  }
  int mid = l + r >> 1;
  if (x <= mid) {
    update(ls[rt], l, mid, x, value);
  }
  else {
    update(rs[rt], mid + 1, r, x, value);
  }
  push_up(rt);
}

int query(int rt, int l, int r, int L, int R) {
  if (!rt) {
    return 0;
  }
  if (l >= L && r <= R) {
    return sum[rt];
  }
  int mid = l + r >> 1, ans = 0;
  if (L <= mid) {
    ans += query(ls[rt], l, mid, L, R);
  }
  if (R > mid) {
    ans += query(rs[rt], mid + 1, r, L, R);
  }
  return ans;
}

long long ans;

void dfs(int rt, int fa, bool keep) {
  for (int i = head[rt]; i; i = nex[i]) {
    if (to[i] == fa || to[i] == son[rt]) {
      continue;
    }
    dfs(to[i], rt, 0);
  }
  if (son[rt]) {
    dfs(son[rt], rt, 1);
  }
  int v = 2 * value[rt], d = dep[rt];
  for (int i = head[rt]; i; i = nex[i]) {
    if (to[i] == fa || to[i] == son[rt]) {
      continue;
    }
    for (int j = l[to[i]]; j <= r[to[i]]; j++) {
      int target_v = v - value[rk[j]], last_d = m - (dep[rk[j]] - d);//目标权值,剩下的可延展的距离
      if (target_v < 0 || last_d <= 0) {//如果目标权值小于0或者剩下的可延展距离没有了,提前剪除不合法
        continue;
      }
      int l = d + 1, r = d + last_d;//深度的区间范围,然后查询即可。
      ans += query(root[target_v], 1, n,  l, r);
    }
    for (int j = l[to[i]]; j <= r[to[i]]; j++) {
      update(root[value[rk[j]]], 1, n, dep[rk[j]], 1);
    }
  }
  update(root[value[rt]], 1, n, dep[rt], 1);
  if (!keep) {
    for (int i = l[rt]; i <= r[rt]; i++) {
      update(root[value[rk[i]]], 1, n, dep[rk[i]], -1);
    }
  }
}

int main() {
  // freopen("in.txt", "r", stdin);
  // freopen("out.txt", "w", stdout);
  // ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
  scanf("%d %d", &n, &m);
  for (int i = 1; i <= n; i++) {
    scanf("%d", &value[i]);
  }
  for (int i = 2; i <= n; i++) {
    int x;
    scanf("%d", &x);
    add(x, i);
  }
  dfs(1, 0);
  dfs(1, 0, 1);
  printf("%lld\n", ans * 2);
  return 0;
}