2019 ICPC Asia Nanchang Regional (dsu on tree+treap平衡树)
题意:
给你一棵树,每个点有一个val
让你找树上有多少有序对(x,y)满足以下条件:
1.x!=y
2.x不是y的祖先;y不是x的祖先
3.x与y的最短路径长度<=k
4.x与y的最小公共祖先的值vz,满足2vz=vx+vy
解析:
就是启发式合并,同时建n棵权值treap树,来维护该权值下,有多少深度的点插入进来。
在遍历轻孩子找答案时,每一次,已经确定根和一个点x,z,找y,那么我们只需要找需要权值的treap里面有多少点距离x的
长度<=z。
这里有几点坑的地方:1.treap树的rank操作找的是<=v的最大排名 2.treap的空间复杂度是O(nlogn)所以要开2e6
3.注意在找y的时候,一些计算出来的y的权值是大于n的,所以要判断是否在[0,n]里,不然就会RE
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <string>
#include <vector>
#include <map>
#define lc t[x].l
#define rc t[x].r
using namespace std;
typedef long long ll;
const int N = 1e5+10;
const int MM =1E7+26;
vector<int> g[N];
int val[N];
int sz[N];
bool badboy[N];
int k;
ll ans=0;
int root[N];
struct node{
int l,r,v,w,size,rnd;
}t[MM];
int cnt=0;
int n;
inline void update(int x){t[x].size=t[lc].size+t[rc].size+t[x].w;}
inline void rturn(int &x){
int c=lc;lc=t[c].r;t[c].r=x;
t[c].size=t[x].size;update(x);x=c;
}
inline void lturn(int &x){
int c=rc;rc=t[c].l;t[c].l=x;
t[c].size=t[x].size;update(x);x=c;
}
void ins(int &x,int v){
if(x==0){
cnt++;x=cnt;
t[cnt].l=t[cnt].r=0;t[cnt].size=t[cnt].w=1;
t[cnt].v=v;t[cnt].rnd=rand();
return;
}
t[x].size++;
if(t[x].v==v) {t[x].w++;return;}
if(v<t[x].v){
ins(lc,v);
if(t[lc].rnd<t[x].rnd) rturn(x);
}else{
ins(rc,v);
if(t[rc].rnd<t[x].rnd) lturn(x);
}
}
void del(int &x,int v){
if(x==0) return;
if(t[x].v==v){
if(t[x].w>1){t[x].w--,t[x].size--;return;}
if(lc*rc==0) x=lc+rc;
else if(t[lc].rnd<t[rc].rnd) rturn(x),del(x,v);
else lturn(x),del(x,v);
}else{
t[x].size--;
if(v<t[x].v) del(lc,v);
else del(rc,v);
}
}
int rnk(int x,int v){ //<=v的最大排名
if(!x) return 0;
if(t[x].v==v) return t[lc].size+1;
if(v<t[x].v) return rnk(lc,v);
else return t[lc].size+t[x].w+rnk(rc,v);
}
int kth(int x,int k){
if(x==0) return 0;
if(k<=t[lc].size) return kth(lc,k);
else if(k>t[lc].size+t[x].w) return kth(rc,k-t[lc].size-t[x].w);
else return t[x].v;
}
/************************************************/
int dastan(int v,int p)
{
sz[v]=1;
for(auto &u:g[v])
{
if(u!=p)
dastan(u,v),sz[v]+=sz[u];
}
}
void rem(int v,int dep,int p,int inc)
{
//mp[dep][val[v]]+=inc;
if(inc+1) ins(root[val[v]],dep);
else del(root[val[v]],dep);
for(auto& u:g[v])
{
if(u!=p&&!badboy[u])
rem(u,dep+1,v,inc);
}
}
void add(int v,int dep,int p,int root_dep,int root_val)
{
if(k>dep-root_dep)
{
int ind=k-(dep-root_dep)+root_dep;
int tar=root_val*2-val[v];
if(tar>=0&&tar<=n) {
int va = rnk(root[tar], ind + 1);
if (va && kth(root[tar], va) == ind + 1) va--;
ans += va;
}
}
for(auto& u:g[v])
{
if(u!=p&&!badboy[u]) {
add(u, dep + 1, v, root_dep, root_val);
}
}
}
void dfs(int v,int p,bool hrh,int dep)
{
int mx=0,big=-1;
for(auto& u:g[v])
{
if(u!=p && sz[u]>mx)
mx=sz[u],big=u;
}
for(auto& u:g[v])
{
if(u!=p&u!=big)
dfs(u,v,1,dep+1);
}
if(big+1)
{
dfs(big,v,0,dep+1);
badboy[big]=1;
}
for(auto& u:g[v])
{
if(u!=p&&!badboy[u]) {
add(u, dep + 1, v, dep, val[v]);
rem(u,dep+1,v,1);
}
}
//mp[dep][val[v]]++;
ins(root[val[v]],dep);
if(big+1)
badboy[big]=0;
if(hrh)
{
rem(v,dep,p,-1);
}
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
for(int i=2;i<=n;i++)
{
int u;
scanf("%d",&u);
g[u].push_back(i);
}
dastan(1,-1);
dfs(1,-1,0,0);
printf("%lld\n",(ans<<1));
}