线段树基础模板题
练习题
HDU - 4027 Can you answer these queries? (同时维护单点更新 + 区间更新)
题意:维护两种操作
- 0 a b 代表将区间 [ a , b ] [a,b] [a,b] 内的数值,变为自己的平方根
- 1 a b 代表查询区间 [ a , b ] [a,b] [a,b] 内总值
思路:可以发现,一个数最多暴力更新 7 次,所以就是维护单点更新和区间更新。
#include <bits/stdc++.h>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define ll long long
using namespace std;
const int maxn=1e5+5;
int n,m;
ll a[maxn];
ll sum[maxn<<2];
void build(int rt,int L,int R)
{
if(L==R)
{
sum[rt]=a[L];
return;
}
int mid=(L+R)>>1;
build(ls,L,mid);
build(rs,mid+1,R);
sum[rt]=sum[ls]+sum[rs];
}
void update(int rt,int l,int r,int L,int R)
{
if(l<=L&&R<=r&&R-L+1==sum[rt]) return;
if(L==R)
{
sum[rt]=sqrt(sum[rt]);
return;
}
int mid=(L+R)>>1;
if(l<=mid) update(ls,l,r,L,mid);
if(r>mid) update(rs,l,r,mid+1,R);
sum[rt]=sum[ls]+sum[rs];
}
ll query(int rt,int l,int r,int L,int R)
{
if(l<=L&&R<=r) return sum[rt];
int mid=(L+R)>>1;
ll ans=0;
if(l<=mid) ans+=query(ls,l,r,L,mid);
if(r>mid) ans+=query(rs,l,r,mid+1,R);
return ans;
}
int main()
{
int Case=0;
while(~scanf("%d",&n))
{
for(int i=1; i<=n; ++i) scanf("%lld",&a[i]);
build(1,1,n);
scanf("%d",&m);
printf("Case #%d:\n",++Case);
while(m--)
{
int op,l,r;
scanf("%d%d%d",&op,&l,&r);
if(l>r) swap(l,r);
if(op==0) update(1,l,r,1,n);
else printf("%lld\n",query(1,l,r,1,n));
}
puts("");
}
return 0;
}
POJ - 2528 D - Mayor’s posters (离散化 + 单点查询)
题意:在 [ 1 , n ] [1, n] [1,n] 上染色,每次选择一个区间 [ l , r ] [l,r] [l,r] 染色,问最后能够看到的不同颜色的个数。
思路:区分表示的是线段、还是点。这里表示的是线段,将线段离散化为坐标轴上的点。然后,访问一下整棵树的叶节点,累积答案。
#include <cstdio>
#include <vector>
#include <map>
#include <algorithm>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define fi first
#define se second
#define ll long long
using namespace std;
const int maxn=20000+5;
int t;
vector<int> allx;
map<int,bool> mp;
int st[maxn<<2],ans,total,n;
pair<int,int> p[maxn];
void pushDown(int rt)
{
if(st[rt]!=-1)
{
st[ls]=st[rs]=st[rt];
st[rt]=-1;
}
}
void build(int rt,int L,int R)
{
st[rt]=-1;
if(L==R) return;
int mid=(L+R)>>1;
build(ls,L,mid);
build(rs,mid+1,R);
}
void update(int rt,int l,int r,int L,int R,int val)
{
if(l<=L&&R<=r)
{
st[rt]=val;
return;
}
pushDown(rt);
int mid=(L+R)>>1;
if(l<=mid) update(ls,l,r,L,mid,val);
if(r>mid) update(rs,l,r,mid+1,R,val);
}
void query(int rt,int L,int R)
{
if(st[rt]!=-1)
{
int col=st[rt];
if(!mp[col]) ans++,mp[col]=1;
return;
}
if(L==R) return;
pushDown(rt);
int mid=(L+R)>>1;
query(ls,L,mid);
query(rs,mid+1,R);
}
int main()
{
scanf("%d",&t);
while(t--)
{
allx.clear();
mp.clear();
scanf("%d",&n);
int l,r;
for(int i=1; i<=n; ++i)
{
scanf("%d%d",&l,&r);
l--;
p[i].fi=l;
p[i].se=r;
allx.push_back(l);
allx.push_back(r);
}
sort(allx.begin(),allx.end());
allx.resize(unique(allx.begin(),allx.end())-allx.begin());
total=allx.size()-1;
build(1,1,total);
for(int i=1; i<=n; ++i)
{
int l=lower_bound(allx.begin(),allx.end(),p[i].fi)-allx.begin()+1;
int r=lower_bound(allx.begin(),allx.end(),p[i].se)-allx.begin();
update(1,l,r,1,total,i);
}
ans=0;
query(1,1,total);
printf("%d\n",ans);
}
return 0;
}
HDU - 1540 Tunnel Warfare (求最大连续区间)
题意:有 n 个村庄。维护三种操作
- D x:摧毁 x
- Q x:查询与 x 相连的村庄数量
- R :恢复上一个被摧毁的村庄
思路:
- 线段树求最大连续区间的长度
方法一
- 维护区间从左端点开始最大连续个数 l,区间从右端点开始最大连续的个数 r。
- query 的时候,找的是包含 x 的位置,如果找到叶节点也找不到,那么这个点就是不存在的。否则,必然会被包含在 [ m i d − s t [ l s ] . r + 1 , m i d + s t [ r s ] . l ] [mid-st[ls].r+1,mid+st[rs].l] [mid−st[ls].r+1,mid+st[rs].l] 这个范围内。
方法二
- 也可以在多维护一个值 mx,表示区间的最大连续个数。mx 在向上合并的时候,注意可能会被 s t [ l s ] . r + s t [ r s ] . l st[ls].r+st[rs].l st[ls].r+st[rs].l 更新。
- 查询的时候,注意 pos 是否落在 [ m i d − s t [ l s ] . r + 1 , m i d ] [mid-st[ls].r+1,mid] [mid−st[ls].r+1,mid] 和 [ m i d + 1 , m i d + s t [ r s ] . l ] [mid+1,mid+st[rs].l] [mid+1,mid+st[rs].l] 这两个区间内,若在就可以往左往右延伸查询。
方法三
- 用maxx表示区间内被破坏村子最大的序号,用minn表示区间内被破坏村子最小的序号。叶节点维护最大值maxx和最小值minn,恢复村子时maxx=0,minn=n+1,破坏村子时maxx=c,minn=c。根节点维护区间的最大值maxx和最小值minn
- 每当查询一个点时,就查询该点左边区间的最大的被破坏村子maxx,以及该点右边的最小的被破坏的村子minn。统计答案,就是minn-maxx-1,当然如果是自己被破坏了答案就是minn-maxx
#include <bits/stdc++.h>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define ll long long
using namespace std;
const int maxn=50000+5;
int n,q,x;
int sta[maxn],top;
char op[5];
struct Node
{
int l,r;
} st[maxn<<2];
void pushUp(int rt,int L,int R)
{
int mid=(L+R)>>1;
st[rt].l=st[ls].l;
st[rt].r=st[rs].r;
if(st[ls].l==mid-L+1) st[rt].l+=st[rs].l;
if(st[rs].r==R-mid) st[rt].r+=st[ls].r;
}
void build(int rt,int L,int R)
{
st[rt].l=st[rt].r=R-L+1;
if(L==R) return ;
int mid=(L+R)>>1;
build(ls,L,mid);
build(rs,mid+1,R);
}
void update(int rt,int pos,int L,int R,int v)
{
if(L==R)
{
st[rt].l=st[rt].r=v;
return;
}
int mid=(L+R)>>1;
if(pos<=mid) update(ls,pos,L,mid,v);
else update(rs,pos,mid+1,R,v);
pushUp(rt,L,R);
}
int query(int rt,int pos,int L,int R)
{
if(L==R) return 0;
int mid=(L+R)>>1;
if(pos>=mid-st[ls].r+1&&pos<=mid+st[rs].l)
return st[ls].r+st[rs].l;
if(pos<=mid) return query(ls,pos,L,mid);
else return query(rs,pos,mid+1,R);
}
int main()
{
while(~scanf("%d%d",&n,&q))
{
build(1,1,n);
while(q--)
{
scanf("%s",op);
if(op[0]=='D')
{
scanf("%d",&x);
sta[++top]=x;
update(1,x,1,n,0);
}
else if(op[0]=='R')
{
x=sta[top--];
update(1,x,1,n,1);
}
else
{
scanf("%d",&x);
printf("%d\n",query(1,x,1,n));
}
}
}
return 0;
}
维护三个值
#include <bits/stdc++.h>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define ll long long
using namespace std;
const int maxn=50000+5;
int n,q,x;
int sta[maxn],top;
char op[5];
struct Node
{
int l,r,mx;
} st[maxn<<2];
void pushUp(int rt,int L,int R)
{
int mid=(L+R)>>1;
st[rt].l=st[ls].l;
st[rt].r=st[rs].r;
if(st[ls].l==mid-L+1) st[rt].l+=st[rs].l;
if(st[rs].r==R-mid) st[rt].r+=st[ls].r;
st[rt].mx=max({st[ls].mx,st[rs].mx,st[rt].l,st[rt].r,st[ls].r+st[rs].l});
}
void build(int rt,int L,int R)
{
st[rt].l=st[rt].r=st[rt].mx=R-L+1;
if(L==R) return ;
int mid=(L+R)>>1;
build(ls,L,mid);
build(rs,mid+1,R);
}
void update(int rt,int pos,int L,int R,int v)
{
if(L==R)
{
st[rt].l=st[rt].r=st[rt].mx=v;
return;
}
int mid=(L+R)>>1;
if(pos<=mid) update(ls,pos,L,mid,v);
else update(rs,pos,mid+1,R,v);
pushUp(rt,L,R);
}
int query(int rt,int pos,int L,int R)
{
if(L==R||R-L+1==st[rt].mx) return st[rt].mx;
int mid=(L+R)>>1;
if(pos<=mid)
{
if(pos>=mid-st[ls].r+1) return query(ls,pos,L,mid)+query(rs,mid+1,mid+1,R);
else return query(ls,pos,L,mid);
}
else
{
if(pos<=mid+st[rs].l) return query(ls,mid,L,mid)+query(rs,pos,mid+1,R);
else return query(rs,pos,mid+1,R);
}
}
int main()
{
while(~scanf("%d%d",&n,&q))
{
build(1,1,n);
while(q--)
{
scanf("%s",op);
if(op[0]=='D')
{
scanf("%d",&x);
sta[++top]=x;
update(1,x,1,n,0);
}
else if(op[0]=='R')
{
x=sta[top--];
update(1,x,1,n,1);
}
else
{
scanf("%d",&x);
printf("%d\n",query(1,x,1,n));
}
}
}
return 0;
}