简单的树状数组 —— ylxmf2005的OI教程

前置芝士

  1. 前缀和
  2. 差分
  3. 最基础的位运算

下面讲的用树状数组求的内容有:

  1. 各种修改查询
  2. 逆序对
  3. 最长单调序列

什么是树状数组?

树状数组,又名蜀庄鼠族。

树状数组呢,就是长的像树的数组,你可能问,数组就是一个序列,怎么能长的像树呢?别急,我们下面会讲,先让我们看两道题:

给定一个大小为 $n$ 的数组与 $m$ 个询问,每次询问区间 $[l,r]$ 的和。

  1. 不学 OI 的做法:每次暴力求和,复杂度 $O(n \times m)$。
  2. 小学生做法:前缀和,复杂度 $O(n + m)$。
  3. 无聊做法:线段树或树状数组,复杂度 $O(mlgon)$。

给定一个大小为 $n$ 的数组与 $m$ 个询问,每次询问有两种,第一种是将位置 $p$ 的值修改为 $x$,另一种是求 $[l,r]$ 的区间和。

  1. 不学 OI 的做法:每次 $O(1)$ 修改,$O(n)$ 查询。
  2. 小学生做法:我不会,也要暴力(高爸等大佬除外)。
  3. 初中生做法:树状数组或线段树,复杂度 $O(mlgon)$。
  4. 无聊做法:分块暴力卡常

我们发现,树状数组和线段树都可以再 $O(mlogn)$ 的时间解决区间问题,怎么解决的你先别问,我先告诉你,树状数组能解决的问题,线段树也都能解决。线段树能解决的问题,树状数组不一定能解决。但是,对于一道同时能用两种算法做的题,树状数组的代码短,不容易出错,而且常数还小,占用空间也小。不过这个数据结构的最大缺陷就是能解决的问题相对于线段树少的多。比如树状数组在解决一些位运算操作,区间最值的时候会很麻烦,甚至做不到,但是线段树可以解决这些问题。

那么,让我们一起进入树状数组的学习吧。

单点修改与区间查询

刚才我们的第二题实际上就是一道单点修改,区间查询的模板题,那么我们如何用树状数组来解决本题呢?

首先,我们引入一个函数,叫做 lowbit,它的作用是:求一个十进制数的二进制状态下从右数第一个 $1$ 表示的值。这么说你可能不明白,那么我举一个栗子:

比如数字 $12$,它的二进制为 $1100$,从右数第一个位置在 $2$ 上,取出来就是二进制数 $100$,表示的数字是 $4$。那么这个东东咋求呢?很简单,负数是对一个数取反后 $+1$,我们设第一个 $1$ 的位置为 $x$,那么取反后,$x$ 前面的数字都取反了,$x$ 变成 $0$,它后面的 $0$ 都变成 $1$,然后 $+1$,那么,后面的 $1$ 都变成 $0$,一直进位到 $x$ 位,$x$ 变成 $1$,$x$ 前面的数字都不变,与操作后就可以得到答案了。

举个栗子:数字 $6$ ,它的二进制数是 $0110$,那么按照补码求法,$-6$ 的值为 $1001+1=1010$。接着,$6\& -6$,结果等于 $0010$。

然后,我们对于数组的下标取一个 $lowbit$,会长成这样子:

《简单的树状数组 —— ylxmf2005的OI教程》

是不是很像一棵树?而且,对于每一个数组下标(横坐标),我们假设这个下标为 $x$,那么他是他横坐标 $x+1$ 的左儿子(如果有的话),而且他的父亲的纵坐标 $= x + lowbit(x)$。同理,如果 $x$ 是他父亲的右儿子,那么它父亲的纵坐标是 $x – lowbit(x)$。

对于每个下标,如果他的横坐标为 $x$,那么他表示的区间为 $[x – lowbit(x) + 1, x]$。我们用 $C_i$ 来维护每个下标表示的区间。我们要求一个区间的和,实际也是求一个前缀和,比如 $[2, 5]$,我们先求出 $[1,1]$,表示这个区间横坐标为 $1$,然后再求出 $[1, 5]$,实际上等于 $C_4 + C_5$。

然后我们就写完这个树状数组了。显然时间复杂度为树高,每次查询的复杂度为 $O(logn)$。

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define lowbit(x) x & (-x)
typedef long long ll;
typedef pair<int, int> P;
const int N = 5e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
int a[N], C[N], n, m; 
void add(int x,int k){
    for (int i = x;i <= n; i += lowbit(i)) C[i] += k;
}
int get(int x){
    int sum = 0;
    for (int i = x; i; i -= lowbit(i)) sum += C[i];
    return sum;
}
int main(){
    n = read(), m = read();
    for (int i = 1; i <= n; i++) {
        a[i] = read(); add(i, a[i]);
    }
    while (m--){
        int opt = read(), x = read(), y = read();
        if (opt == 1) add(x, y);
        else printf ("%d\n", get(y) - get(x - 1));
    }
    return 0;
}

区间修改与单点查询

我们先考虑只有区间修改,查询只要在最后输出的情况。显然,这样的话差分可以 $O(n + m)$ 水过(不会的去百度)。

众所周知,差分修改的时候,其实就是单点修改 $l-1,r + 1$ 这两个点,查询的话就变成了 $[1,x]$ 前缀和,这又不和上面单点修改和区间查询又一样了吗?水一水又过去了。

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define lowbit(x) x & (-x)
typedef long long ll;
typedef pair<int, int> P;
const int N = 5e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
int a[N], C[N], n, m; 
void add(int x,int k){
    for (int i = x;i <= n; i += lowbit(i)) C[i] += k;
}
int get(int x){
    int sum = 0;
    for (int i = x; i; i -= lowbit(i)) sum += C[i];
    return sum;
}
int main(){
    n = read(), m = read();
    for (int i = 1; i <= n; i++) a[i] = read();
    for (int i = 1; i <= n; i++) add(i, a[i] - a[i - 1]);
    while (m--){
        int opt = read();
        if (opt == 1){
            int x = read(), y = read(), k = read();
            add(x, k); add(y + 1, -k);
        }else{
            int x = read();
            printf("%d\n", get(x));
        }
    }
    return 0;
}

区间修改与区间查询

我们对支持区间修改,单点查询的树状数组进行改造。

众所周知,单点修改求的是 $\sum_{i=1}^x C_i$。

我们先求 $[1,r]$ 的和,能求出的话$[l,r]$ 就好求了。要求的是 $\sum_{i=1}^ x\sum_{j=1}^i C_i$。

在这个式子中,$C_i$ 被计算了$x-i + 1$ 次。所以,我们求的是

$$
\sum_{i=1}^x (x-i + 1) \times C_i = (x + 1) \sum_{i=1}^x C_i – C_i \times i
$$

维护两个 $C$ 即可。

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define lowbit(x) x & (-x)
typedef long long ll;
typedef pair<int, int> P;
const int N = 1e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int x = 0, f = 0; char ch = 0;
    while (!isdigit(ch)) f |= ch == '-', ch = getchar();
    while (isdigit(ch)) x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
    return f ? -x : x;
}
int C1[N], a[N]; ll C2[N];
int n, m;
void add(int x, int k){
    for (int i = x; i <= n; i += lowbit(i)) C1[i] += k, C2[i] += x * k;
}
ll query(int x){
    ll ans = 0;
    for (int i = x; i; i -= lowbit(i)) ans += (x + 1) * C1[i] - C2[i];
    return ans;
}
int main(){
    n = read(), m = read();
    for (int i = 1; i <= n; i++) a[i] = read(), add(i, a[i] - a[i - 1]);
    while (m--){
        int opt = read();
        if (opt == 1){
            int l = read(), r = read(), x = read();
            add(l, x); add(r + 1, -x);
        }else{
            int l = read(), r = read();
            printf("%lld\n", query(r) - query(l - 1)); 
        }
    } 
    return 0;
}

逆序对

学这里前,你必须先会离散化。

首先,我们将输入的数离散化一下。我们用树状数组维护 大小为 $a$ 的数出现了多少次,我们从左往右扫这个序列,当我们处理到 $i$ 的时候,$b$ 数组包含的是 $1 \sim i – 1$ 的信息。设 $C$ 数组 $1 \sim a_i -1$ 的前缀和是 $x$,那么此处对逆序对的贡献为 $x$。这样求逆序对的复杂度与归并排序一样,都是 $O(nlogn)$。下面的做法懒得写离散化了,自己写吧。

#include <bits/stdc++.h>
#define lowbit(x) x & (-x)
using namespace std;
int tree[3505], a[3505], n;
void add(int x, int c){
    while (x <= n) tree[x] += c, x += lowbit(x);
}
int query(int x){
    int ans = 0;
    while (x) ans += tree[x], x -= lowbit(x);
    return ans;
}
int main(){
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    int ans = 0;
    for (int i = 1; i <= n; i++) ans += (i - query(a[i]) - 1), add(a[i], 1);
    printf("%d\n", ans);
    return 0;
}

前缀最值

注意,我们讲的前缀最值,是求的 $[1,x]$ 的最值,我建议求 $[l,r]$ 的最值时用线段树,这种情况下用线段树更正解高效。下面默认求最大值。

这个问题还是很好解决的,相信大家自己想一下就会了。所以讲解呢?这么简单不用了吧。

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define lowbit(x) x & (-x)
typedef long long ll;
typedef pair<int, int> P;
const int N = 5e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
int a[N], C[N], n, m; 
void add(int x,int k){
    for (int i = x;i <= n; i += lowbit(i)) C[i] = max(C[i], k);
}
int get(int x){
    int ans = 0;
    for (int i = x; i; i -= lowbit(i)) ans = max(ans, C[i]);
    return ans;
}
int main(){
    n = read(), m = read();
    for (int i = 1; i <= n; i++) a[i] = read();
    for (int i = 1; i <= n; i++) add(i, a[i] - a[i - 1]);
    while (m--){
        int opt = read();
        if (opt == 1){
            int x = read(), y = read(), k = read();
            add(x, k); add(y + 1, -k);
        }else{
            int x = read();
            printf("%d\n", get(x));
        }
    }
    return 0;
}

最长单调子序列

首先你要学会上面的求前缀最值和最最基础的动态规划。我们用树状数组求 LIS,其实是和二分求一样的复杂度,都是 $O(nlogn)$,不过我认为:用树状数组来求能用的范围广,直接来说就是更简单嘛。下面就求最长上升子序列吧。

我们还是先离散化一下,不过我的代码还是懒得离散化(没狼心)。我们先回顾一下 $O(n^2)$ 的做法:每次往前找值小于它的,对 $dp_j$ 取 $max$,所以遍历一遍 $+$ 往前找的复杂度是 $O(n^2)$。我们用树状数组维护向前找的过程,所以做法是:我们用 $C_i$ 表示 $a_j = i$ 当中最大的 $dp_j$。然后从左往右扫序列,在求 $dp_i$ 的时候,我们查询前缀最大值就好了。可能有点难理解,可以手算一下加强理解。

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define lowbit(x) x & (-x)
typedef long long ll;
typedef pair<int, int> P;
const int N = 5e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
int a[N], C[N], n, m, dp[N]; 
void add(int x,int k){
    for (int i = x;i <= n; i += lowbit(i)) C[i] = max(C[i], k);
}
int get(int x){
    int ans = 0;
    for (int i = x; i; i -= lowbit(i)) ans = max(ans, C[i]);
    return ans;
}
int main(){
    n = read(), m = read();
    for (int i = 1; i <= n; i++) a[i] = read();
    int ans = 0;
    for (int i = 1; i <= n; i++){
        dp[i] = get(a[i] - 1) + 1;
        ans = max(ans, dp[i]);
        add(a[i], dp[i]);
    } 
    return 0;
}

总结

树状数组不仅能高效简洁的解决区间修改查询的问题,还能优化动态规划等算法,常数小,代码短,但是局限性较大。

点赞

发表评论

电子邮件地址不会被公开。 必填项已用*标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据