线段树基础模板题

HDU - 4027 Can you answer these queries? (同时维护单点更新 + 区间更新)

题意:维护两种操作

  • 0 a b 代表将区间 [ a , b ] [a,b] [a,b] 内的数值,变为自己的平方根
  • 1 a b 代表查询区间 [ a , b ] [a,b] [a,b] 内总值

思路:可以发现,一个数最多暴力更新 7 次,所以就是维护单点更新和区间更新。

#include <bits/stdc++.h>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define ll long long
using namespace std;
const int maxn=1e5+5;

int n,m;
ll a[maxn];
ll sum[maxn<<2];

void build(int rt,int L,int R)
{
    if(L==R)
    {
        sum[rt]=a[L];
        return;
    }
    int mid=(L+R)>>1;
    build(ls,L,mid);
    build(rs,mid+1,R);
    sum[rt]=sum[ls]+sum[rs];
}

void update(int rt,int l,int r,int L,int R)
{
    if(l<=L&&R<=r&&R-L+1==sum[rt]) return;
    if(L==R)
    {
        sum[rt]=sqrt(sum[rt]);
        return;
    }
    int mid=(L+R)>>1;
    if(l<=mid) update(ls,l,r,L,mid);
    if(r>mid) update(rs,l,r,mid+1,R);
    sum[rt]=sum[ls]+sum[rs];
}
ll query(int rt,int l,int r,int L,int R)
{
    if(l<=L&&R<=r) return sum[rt];
    int mid=(L+R)>>1;
    ll ans=0;
    if(l<=mid) ans+=query(ls,l,r,L,mid);
    if(r>mid) ans+=query(rs,l,r,mid+1,R);
    return ans;
}

int main()
{
    int Case=0;
    while(~scanf("%d",&n))
    {
        for(int i=1; i<=n; ++i) scanf("%lld",&a[i]);
        build(1,1,n);
        scanf("%d",&m);
        printf("Case #%d:\n",++Case);
        while(m--)
        {
            int op,l,r;
            scanf("%d%d%d",&op,&l,&r);
            if(l>r) swap(l,r);
            if(op==0) update(1,l,r,1,n);
            else printf("%lld\n",query(1,l,r,1,n));
        }
        puts("");
    }
    return 0;
}

POJ - 2528 D - Mayor’s posters (离散化 + 单点查询)

题意:在 [ 1 , n ] [1, n] [1,n] 上染色,每次选择一个区间 [ l , r ] [l,r] [l,r] 染色,问最后能够看到的不同颜色的个数。

思路:区分表示的是线段、还是点。这里表示的是线段,将线段离散化为坐标轴上的点。然后,访问一下整棵树的叶节点,累积答案。

#include <cstdio>
#include <vector>
#include <map>
#include <algorithm>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define fi first
#define se second
#define ll long long
using namespace std;
const int maxn=20000+5;

int t;
vector<int> allx;
map<int,bool> mp;
int st[maxn<<2],ans,total,n;
pair<int,int> p[maxn];

void pushDown(int rt)
{
    if(st[rt]!=-1)
    {
        st[ls]=st[rs]=st[rt];
        st[rt]=-1;
    }
}
void build(int rt,int L,int R)
{
    st[rt]=-1;
    if(L==R) return;
    int mid=(L+R)>>1;
    build(ls,L,mid);
    build(rs,mid+1,R);
}

void update(int rt,int l,int r,int L,int R,int val)
{
    if(l<=L&&R<=r)
    {
        st[rt]=val;
        return;
    }
    pushDown(rt);
    int mid=(L+R)>>1;
    if(l<=mid) update(ls,l,r,L,mid,val);
    if(r>mid) update(rs,l,r,mid+1,R,val);
}
void query(int rt,int L,int R)
{
    if(st[rt]!=-1)
    {
        int col=st[rt];
        if(!mp[col]) ans++,mp[col]=1;
        return;
    }
    if(L==R) return;
    pushDown(rt);
    int mid=(L+R)>>1;
    query(ls,L,mid);
    query(rs,mid+1,R);
}
int main()
{
    scanf("%d",&t);
    while(t--)
    {
        allx.clear();
        mp.clear();
        scanf("%d",&n);
        int l,r;
        for(int i=1; i<=n; ++i)
        {
            scanf("%d%d",&l,&r);
            l--;
            p[i].fi=l;
            p[i].se=r;
            allx.push_back(l);
            allx.push_back(r);
        }
        sort(allx.begin(),allx.end());
        allx.resize(unique(allx.begin(),allx.end())-allx.begin());
        total=allx.size()-1;
        build(1,1,total);
        for(int i=1; i<=n; ++i)
        {
            int l=lower_bound(allx.begin(),allx.end(),p[i].fi)-allx.begin()+1;
            int r=lower_bound(allx.begin(),allx.end(),p[i].se)-allx.begin();
            update(1,l,r,1,total,i);
        }
        ans=0;
        query(1,1,total);
        printf("%d\n",ans);
    }
    return 0;
}

HDU - 1540 Tunnel Warfare (求最大连续区间)

题意:有 n 个村庄。维护三种操作

  • D x:摧毁 x
  • Q x:查询与 x 相连的村庄数量
  • R :恢复上一个被摧毁的村庄

思路

  • 线段树求最大连续区间的长度

方法一

  • 维护区间从左端点开始最大连续个数 l,区间从右端点开始最大连续的个数 r。
  • query 的时候,找的是包含 x 的位置,如果找到叶节点也找不到,那么这个点就是不存在的。否则,必然会被包含在 [ m i d − s t [ l s ] . r + 1 , m i d + s t [ r s ] . l ] [mid-st[ls].r+1,mid+st[rs].l] [midst[ls].r+1,mid+st[rs].l] 这个范围内。

方法二

  • 也可以在多维护一个值 mx,表示区间的最大连续个数。mx 在向上合并的时候,注意可能会被 s t [ l s ] . r + s t [ r s ] . l st[ls].r+st[rs].l st[ls].r+st[rs].l 更新。
  • 查询的时候,注意 pos 是否落在 [ m i d − s t [ l s ] . r + 1 , m i d ] [mid-st[ls].r+1,mid] [midst[ls].r+1,mid] [ m i d + 1 , m i d + s t [ r s ] . l ] [mid+1,mid+st[rs].l] [mid+1,mid+st[rs].l] 这两个区间内,若在就可以往左往右延伸查询。

方法三

  • 用maxx表示区间内被破坏村子最大的序号,用minn表示区间内被破坏村子最小的序号。叶节点维护最大值maxx和最小值minn,恢复村子时maxx=0,minn=n+1,破坏村子时maxx=c,minn=c。根节点维护区间的最大值maxx和最小值minn
  • 每当查询一个点时,就查询该点左边区间的最大的被破坏村子maxx,以及该点右边的最小的被破坏的村子minn。统计答案,就是minn-maxx-1,当然如果是自己被破坏了答案就是minn-maxx
#include <bits/stdc++.h>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define ll long long
using namespace std;
const int maxn=50000+5;

int n,q,x;
int sta[maxn],top;
char op[5];

struct Node
{
    int l,r;
} st[maxn<<2];

void pushUp(int rt,int L,int R)
{
    int mid=(L+R)>>1;
    st[rt].l=st[ls].l;
    st[rt].r=st[rs].r;
    if(st[ls].l==mid-L+1) st[rt].l+=st[rs].l;
    if(st[rs].r==R-mid) st[rt].r+=st[ls].r;
}
void build(int rt,int L,int R)
{
    st[rt].l=st[rt].r=R-L+1;
    if(L==R) return ;
    int mid=(L+R)>>1;
    build(ls,L,mid);
    build(rs,mid+1,R);
}
void update(int rt,int pos,int L,int R,int v)
{
    if(L==R)
    {
        st[rt].l=st[rt].r=v;
        return;
    }
    int mid=(L+R)>>1;
    if(pos<=mid) update(ls,pos,L,mid,v);
    else update(rs,pos,mid+1,R,v);
    pushUp(rt,L,R);
}
int query(int rt,int pos,int L,int R)
{
    if(L==R) return 0;
    int mid=(L+R)>>1;
    if(pos>=mid-st[ls].r+1&&pos<=mid+st[rs].l)
        return st[ls].r+st[rs].l;
    if(pos<=mid) return query(ls,pos,L,mid);
    else return query(rs,pos,mid+1,R);
}
int main()
{
    while(~scanf("%d%d",&n,&q))
    {
        build(1,1,n);
        while(q--)
        {
            scanf("%s",op);
            if(op[0]=='D')
            {
                scanf("%d",&x);
                sta[++top]=x;
                update(1,x,1,n,0);
            }
            else if(op[0]=='R')
            {
                x=sta[top--];
                update(1,x,1,n,1);
            }
            else
            {
                scanf("%d",&x);
                printf("%d\n",query(1,x,1,n));
            }
        }
    }
    return 0;
}

维护三个值

#include <bits/stdc++.h>
#define ls (rt<<1)
#define rs (rt<<1|1)
#define ll long long
using namespace std;
const int maxn=50000+5;

int n,q,x;
int sta[maxn],top;
char op[5];

struct Node
{
    int l,r,mx;
} st[maxn<<2];

void pushUp(int rt,int L,int R)
{
    int mid=(L+R)>>1;
    st[rt].l=st[ls].l;
    st[rt].r=st[rs].r;
    if(st[ls].l==mid-L+1) st[rt].l+=st[rs].l;
    if(st[rs].r==R-mid) st[rt].r+=st[ls].r;

    st[rt].mx=max({st[ls].mx,st[rs].mx,st[rt].l,st[rt].r,st[ls].r+st[rs].l});
}
void build(int rt,int L,int R)
{
    st[rt].l=st[rt].r=st[rt].mx=R-L+1;
    if(L==R) return ;
    int mid=(L+R)>>1;
    build(ls,L,mid);
    build(rs,mid+1,R);
}
void update(int rt,int pos,int L,int R,int v)
{
    if(L==R)
    {
        st[rt].l=st[rt].r=st[rt].mx=v;
        return;
    }
    int mid=(L+R)>>1;
    if(pos<=mid) update(ls,pos,L,mid,v);
    else update(rs,pos,mid+1,R,v);
    pushUp(rt,L,R);
}
int query(int rt,int pos,int L,int R)
{
    if(L==R||R-L+1==st[rt].mx) return st[rt].mx;
    int mid=(L+R)>>1;
    if(pos<=mid)
    {
        if(pos>=mid-st[ls].r+1) return query(ls,pos,L,mid)+query(rs,mid+1,mid+1,R);
        else return query(ls,pos,L,mid);
    }
    else
    {
        if(pos<=mid+st[rs].l) return query(ls,mid,L,mid)+query(rs,pos,mid+1,R);
        else return query(rs,pos,mid+1,R);
    }
}
int main()
{
    while(~scanf("%d%d",&n,&q))
    {
        build(1,1,n);
        while(q--)
        {
            scanf("%s",op);
            if(op[0]=='D')
            {
                scanf("%d",&x);
                sta[++top]=x;
                update(1,x,1,n,0);
            }
            else if(op[0]=='R')
            {
                x=sta[top--];
                update(1,x,1,n,1);
            }
            else
            {
                scanf("%d",&x);
                printf("%d\n",query(1,x,1,n));
            }
        }
    }
    return 0;
}