← 返回

树状数组学习记录

树状数组1

树状数组是一种可以实现 单点修改和区间查询 的数据结构,时间复杂度均为 $O(\log n)$。

相比同样可以实现区间修改的线段树,树状数组只能实现单点修改和区间查询,不过代码远比线段树少,且常数也更小,所以很有必要掌握。

基本原理

树状数组的基本原理是将一段数组 $[1,n]$ 拆分为 不多于 $\log n$ 段区间,要求这 $\log n$ 段区间是已知的。这样我们就可以通过合并这 $\log n$ 段区间求解出答案。

图示

那么我们可以用一个数组 $t_i$ 来表示存储一段右边界是 $x$ 的区间信息,那么如何确定其左边界呢?我们规定树状数组中,$t_i$ 表示的区间的长度为 $2^k$,其中:

我们规定 $x$ 的二进制最低位 1 及其后面所有的 0 组成的数为 $lowbit(x)$,那么 $t_x$ 表示的区间就是 $[x-lowbit(x)+1,x]$。

lowbit计算方法

首先给出式子:

$$lowbit(x)=x& (-x)$$

这个式子的原理为:首先由于 -x=~x+1(即按位取反再加一),设 $x$ 的原二进制编码为 (...)10...00,取反之后得到 [...]01...11,加 1 后得到 [...]10...00,注意这里的 (...)[...] 完全相反,所以得出 x & -x =(...)10...00 & [...]10...00 = 10...00,也就是我们要求的 lowbit

区间查询

考虑求解 $[l,r]$ 之和,我们可以将其转化为求解 $sum(r)-sum(l-1)$,($sum(x)$ 表示区间 $[1,x]$ 的和)。

那么现在我们只需要考虑如何求解 $sum(x)$ 即可。

我们可以将区间 $[1,x]$ 拆分为不多于 $\log n$ 段区间,那么怎么拆分呢?显然我们可以利用刚才的 lowbit 来求解:

int sum(int x){
    int res=0;
    while(x){
        res+=t[x];
        x-=lowbit(x);
    }
    return res;
}

单点修改

考虑如何单点修改 $a_x$。

我们可以参考构建完成的树状数组图,不难发现,单点修改 $a_x$ 影响的 $t_y$ 在一条链上,所以我们可以从 $x$ 开始往他的父亲节点跳,直至跳到根节点为止。

设数组 $a$ 的长度为 $n$,我们可以这样修改 $a_x$:

void add(int x,int y){
    a[x]+=y;
    while(x<=N){
        t[x]+=y;
        x+=lowbit(x);
    }
}

一些性质

Code

struct node
{
#define lowbit(x) (x & (-x))
    int t[MaxN];
    void init()
    {
        memset(t, 0, sizeof t);
        for (int i = 1; i <= N; i++)
        {
            for (int j = i - lowbit(i) + 1; j <= i; j++)
            {
                t[i] += a[j];
            }
        }
    }
    void add(int x, int y)
    {
        a[x] += y;
        while (x <= N)
        {
            t[x] += y;
            x += lowbit(x);
        }
    }
    int sum(int x)
    {
        int res = 0;
        while (x)
        {
            res += t[x];
            x -= lowbit(x);
        }
        return res;
    }
    int query(int x,int y){
        return sum(y)-sum(x-1);
    }
    void print(){
        for(int i=1;i<=N;i++){
            cerr<<t[i]<<" ";
        }
        cerr<<endl;
    }
}Tree;

树状数组2(区间修改+单点查询)

我们可以利用差分来求解。

为了求解区间修改,我们考虑数组 $a$ 的差分数组 $d$,其中 $d_i=a_{i}-a_{i-1}$。

那么区间修改操作(即修改 $[l,r]$ 的值),我们只需要更改 $d_l\leftarrow d_l+k,d_{r+1}\leftarrow d_{r+1}-k$。

同理,单点查询操作(即查询 $a_x$ 的值),我们只需要求解前 $x$ 个 $d_i$ 的和即可。

那么我们只需要用树状数组维护差分数组 $d_i$ 求解即可。

Code

#include <bits/stdc++.h>
using namespace std;
// #define MULTI_CASES
#define ll long long
#define int ll
#define endl '\n'
#define vi vector<int>
#define PII pair<int, int>
const int MaxN = 5e5 + 100;
const int INF = 1e9;
const int mod = 1e9 + 7;
int T = 1, N, M;
int a[MaxN];
int b[MaxN];

struct node
{
#define lowbit(x) (x & (-x))
    int t[MaxN];
    void init(){
        memset(t,0,sizeof t);
        for(int i=1;i<=N;i++){
            for(int j=i-lowbit(i)+1;j<=i;j++){
                t[i]+=a[j];
            }
        }
    }
    void add(int x,int y){
        // a[x]+=y;
        while(x<=N){
            t[x]+=y;
            x+=lowbit(x);
        }
    }
    int sum(int x){
        int res=0;
        while(x){
            res+=t[x];
            x-=lowbit(x);
        }
        return res;
    }
    int query(int x,int y){
        return sum(y)-sum(x-1);
    }
    void print(){
        for(int i=1;i<=N;i++){
            cerr<<t[i]<<" ";
        }
        cerr<<endl;
    }
}Tree;
inline void Solve()
{
    cin>>N>>M;
    for(int i=1;i<=N;i++){
        cin>>b[i];
        a[i]=b[i]-b[i-1];
    }
    Tree.init();
    for(int i=1;i<=M;i++){
        int opt;
        cin>>opt;
        if(opt==1){
            int x,y,k;
            cin>>x>>y>>k;
            Tree.add(x,k);
            Tree.add(y+1,-k);
        }
        else{
            int x;
            cin>>x;
            cout<<Tree.sum(x)<<endl;
        }
    }
}
signed main()
{
#ifdef NOI_IO
    freopen(".in", "r", stdin);
    freopen(".out", "w", stdout);
#endif
    ios::sync_with_stdio(0);
    cin.tie(nullptr), cout.tie(nullptr);
#ifdef MULTI_CASES
    cin >> T;
    while (T--)
#endif
        Solve();
    return 0;
}

树状数组3(区间修改+区间求和)

该问题考虑使用两个树状数组维护差分解决。

首先我们考虑设一个差分数组 $d_i=a_i-a_{i-1}$,并用树状数组来维护。

之后考虑 $[l,r]$ 区间求和,同理,我们可以将其转化为求解 $[1,r]$ 的前缀和与 $[1,l-1]$ 的前缀和之差。那么现在问题就是如何求解前缀和。

考虑前缀和

$$sum(x)$$

首先,其可以表示为:

$$=\sum_{i=1}^{x}a_i$$

之后,转化为差分数组:

$$=\sum_{i=1}^{x}\sum_{j=1}^{i}d_j$$

显然,不难发现每个 $d_i$ 都被重复加了 $x-i+1$ 次。那么我们进一步推导:

$$=\sum_{i=1}^{x}d_i\times (x-i+1)$$

$$=\sum_{i=1}^{x}d_i\times (x+1)-\sum_{i=1}^{x}d_i\times i$$

那么显然,我们只需要用树状数组分别维护 $d_i$ 和 $d_i\times i$ 这两个数组信息即可。

那么区间加就很简单了,只需要使用在树状数组2中提到的修改 $d_l$ 单点加 $v$,$d_{r+1}$ 单点加 $-v$,并对 $d_l\times l$ 单点加 $v\times l$,$d_{r+1}\times (r+1)$ 单点加 $-v\times (r+1)$即可。

Code

#include <bits/stdc++.h>
using namespace std;
// #define MULTI_CASES
#define ll long long
#define int ll
#define endl '\n'
#define vi vector<int>
#define PII pair<int, int>
const int MaxN = 1e6 + 100;
const int INF = 1e9;
const int mod = 1e9 + 7;
int T = 1, N, M;
int a[MaxN];
struct node
{
#define lowbit(x) (x & (-x))
    int t[MaxN];
    void init(int a[])
    {
        memset(t, 0, sizeof t);
        for (int i = 1; i <= N; i++)
        {
            for (int j = i - lowbit(i) + 1; j <= i; j++)
            {
                t[i] += a[j];
            }
        }
    }
    void add(int x, int y)
    {
        // a[x]+=y;
        while (x <= N)
        {
            t[x] += y;
            x += lowbit(x);
        }
    }
    int sum(int x)
    {
        int res = 0;
        while (x)
        {
            res += t[x];
            x -= lowbit(x);
        }
        return res;
    }
} t1, t2;
int d[MaxN], d2[MaxN];
int sum(int x)
{
    return t1.sum(x) * (x + 1) - t2.sum(x);
}
inline void Solve()
{
    cin >> N >> M;
    for (int i = 1; i <= N; i++)
    {
        cin >> a[i];
        d[i] = a[i] - a[i - 1];
        d2[i] = d[i] * i;
    }
    t1.init(d);
    t2.init(d2);
    for (int i = 1; i <= M; i++)
    {
        int opt;
        cin >> opt;
        if (opt == 1)
        {
            int l, r, x;
            cin >> l >> r >> x;
            t1.add(l, x);
            t1.add(r + 1, -x);
            t2.add(l, l * x);
            t2.add(r + 1, -(r + 1) * x);
        }
        else
        {
            int l, r;
            cin >> l >> r;
            cout << sum(r) - sum(l - 1) << endl;
        }
    }
}
signed main()
{
#ifdef NOI_IO
    freopen(".in", "r", stdin);
    freopen(".out", "w", stdout);
#endif
    ios::sync_with_stdio(0);
    cin.tie(nullptr), cout.tie(nullptr);
#ifdef MULTI_CASES
    cin >> T;
    while (T--)
#endif
        Solve();
    return 0;
}

tricks

树状数组维护不可差分信息

我们知道常规的树状数组必须满足 可差分 的要求,但我们其实也可以维护不可差分的信息,比如区间最值。

但要注意,这种方法的单点修改和区间查询的时间复杂度都是 $O(\log^2n)$,劣于线段树的 $O(\log n)$。使用时请注意时间复杂度要求。

区间查询

基于常规树状数组的思路,我们仍然采用从 $r$ 一直沿着 $lowbit$ 往前跳的思路,但由于信息不可差分,我们无法采用之前的前缀和求差值的方案,所以过程中需要判断,保证不能跳到 $l$ 的左边。

设当前节点为 $t[x]$,我们操作时需要先判断 $x-lowbit(x)$ 是否小于 $l$:

Code

int sum(int l,int r){
    int ans=-INF;
    while(r>=l){
        ans=max(ans,a[r]);
        --r;
        while(r-lowbit(r)>=l){
            ans=max(ans,t[r]);
            r-=lowbit(r);
        }
    }
    return ans;
}

单点修改

由于信息不可差分,单点修改 $a[x]$ 之后必须要把其所属的区间完全重构。

那么我们考虑一个受到影响的 $t[y]$,我们考虑其在树状上的儿子节点,由于我们的操作是从儿子一路回溯修改节点,那么显然此时的儿子节点一定是正确的,且这段节点恰好可以组成 $[lowbit(y),y-1]$ 区间。那么我们可以将单点 $a[y]$ 合并,那么 $t[y]$ 节点的信息就是正确的了,重复操作合并即可。

Code

void add(int x,int y){
    a[x]=y;
    while(x<=N){
        t[x]=a[x];
        for(int i=x-lowbit(x)+1;i<x;i++){
            t[x]=max(t[x],a[i]);
        }
        x+=lowbit(x);
    }
}

$O(N)$建树

由于每个节点都是由所有自己直接相连的子节点信息合并得出,因此我们可以在每次确定完自己的值以后,用自己的值更新父节点的值。

Code

void init(){
    memset(t,0,sizeof t);
    for(int i=1;i<=N;i++){
        t[i]=a[i];
        int j=i+lowbit(i);
        if(j<=N){
            t[j]+=t[i];
        }
    }
}

时间戳优化

这个优化不是树状数组的专属优化,但既然OI Wiki提到了,那就顺便提一下。

在多组数据的题目中,如果每次都暴力清空数组,很容易超时。所以我们可以利用时间戳来标记当前数据上次被使用的时间,操作时判断这个位置的时间戳与当前时间是否一致。这样就可以判断这个位置是有值还是应该为 $0$。

本文采用 CC BY-NC 4.0 许可协议,转载请注明出处并保持非商业用途。