Codeforces Round #833 (Div. 2)E. Yet Another Array Counting Problem(笛卡尔树+树形DP)

题目链接:Problem - E - Codeforces

 样例输入:

4
3 3
1 3 2
4 2
2 2 2 2
6 9
6 9 6 9 6 9
9 100
10 40 20 20 100 60 80 60 60

样例输出:

8
5
11880
351025663

题意:给定一个长度为n的数组a[],对于每一个区间[l,r],这个区间的leftmost定义为区间[l,r]中的值等于该区间内元素的最大值的最小下标,现在问我们有多少种长度为n的数组b[],满足对于任意区间都有leftmost值等于数组a的对应区间的leftmost值,且b数组中的元素值是介于1到m之间的。

分析:其实这道题要先知道笛卡尔树这个预备知识。

笛卡尔树是二叉排序树和堆的结合。

对于一个笛卡尔树,对笛卡尔树进行中序遍历即可得到原序列。

笛卡尔树中每个节点的左子树上的节点的坐标都是小于该节点的,右子树上的节点的坐标都是大于该节点的。而且每个节点的值是大于/小于其子树上的节点的值的,这个取决于是和大根堆结合还是和小根堆结合。

对于这道题目而言,我们就让根节点的值大于其左子节点的值,大于等于其右子节点的值,那么对于a数组和b数组需要构造出相同形态的笛卡尔树才算是满足题意的。所以我们可以按照a数组构造出笛卡尔树,然后利用树形DP向每个节点进行填值求方案数即可。

建树的过程,可以发现

每个右子树根节点的父亲节点就是他左边第一个大于他的数的位置

每个节点的左子树根节点就是他左边小于该值的数中的最大值的最小下标

这个显然可以用单调栈来维护,那么建树复杂度就是O(n),细节可以见代码

最后就是根据笛卡尔树求解方案数了。

设f[x][y]代表x节点取值为1~y时的方案数和。那么更新过程就有四种情况:

x节点有左右子树,那么就有f[x][y]=f[x][y-1]+f[l[x]][y-1]*f[r[x]][y]

x节点只有左子树,那么就有f[x][y]=f[x][y-1]+f[l[x]][y-1]

x节点只有右子树,那么就有f[x][y]=f[x][y-1]+f[r[x]][y]

如果没有孩子节点,直接返回y

利用这个方法我们即可完成本道题目的更新。

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<map>
#include<queue>
#include<vector>
#include<cmath>
using namespace std;
const int N=2e6+10,mod=1e9+7;
int l[N],r[N],fa[N];
int s[N],top;
long long f[N],v[N];
int n,m;
void build(int n)
{
	top=0;
	for(int i=1;i<=n;i++)
		r[i]=l[i]=0;
    for(int i=1;i<=n;i++)
	{
		scanf("%d",&v[i]);
        while(top&&v[s[top]]<v[i])    l[i]=s[top],top--;//i节点的左子树的根节点就是下标小于i且值小于i的数的最大值对应的下标 
        fa[i]=s[top];//i节点的父节点就是他左边第一个值大于等于他的下标 
		fa[l[i]]=i;
        if(fa[i]) r[fa[i]]=i;
        s[++top]=i;
    }
}
int find(int x,int y)
{
	return (x-1)*m+y;
}
long long dp(int x,int R)//代表节点x的可取值范围为[1,R]的方案数 
{
	if(R<1) return 0;
	if((!l[x])&&(!r[x])) return R; 
	if(f[find(x,R)]!=-1) return f[find(x,R)];
	long long ans=0;
	if(l[x]&&r[x]) ans=(dp(x,R-1)+dp(l[x],R-1)*dp(r[x],R))%mod;
	else if(l[x]) ans=(dp(x,R-1)+dp(l[x],R-1))%mod;
	else if(r[x]) ans=(dp(x,R-1)+dp(r[x],R))%mod;
	return f[find(x,R)]=ans;
}
int main()
{
	int T;
	cin>>T;
	while(T--)
	{
		scanf("%d%d",&n,&m);
		for(int i=1;i<=n*m;i++)
			f[i]=-1;
		build(n);
		int root,mx=0;
		for(int i=1;i<=n;i++)
			if(v[i]>mx) mx=v[i],root=i;
		printf("%lld\n",dp(root,m));
	}
	return 0;
}