线段树入门
作者:互联网
区间问题
区间问题可以抽象为:
- 修改 \(a_i\) 为 \(x\)
- 查询 \([l,r]\) 中最小值
我们可以使用线段树来解决这类问题。
线段树
定义以及概念
我们用一棵二叉树来表示线段树。线段树中每个节点都表示一个区间,每个区间都有左右两个子树,分别对应左半和右半。为了方便起见,令根结点 id 为 \(1\),对于结点 \(i\),其左节点编号为 \(2\times i\),右节点编号为 \(2\times i+1\)。
对于一个结点,如果令其表示的区间为 \([l,r]\)。分情况如果 \(l=r\),那么这个是一个叶子结点;否则令 \(mid=\lfloor\frac{l+r}{2}\rfloor\),左儿子对应的区间为 \([l,mid]\),右儿子的区间为 \([mid+1,r]\)。
建树
我们为了维护区间最值,我们需要用一个额外的数组 minv
记录每个节点对应的区间最小值。
对于叶子结点,最小值就是一个数。而对于非叶子结点,区间最小值就是左儿子最小值和右儿子最小值中的最小值。
例如 \(n=10,a=[1,3,5,7,9,10,2,4,8,6]\) 时,线段树如下:
其实建树就是一个递归的过程,父节点的信息需要通过子结点来更新,所以我们需要先递归建好左右子树。
const int maxn = 100010;
int minv[4 * maxn], a[maxn];
void build(int id, int l, int r) {
if (l == r) {
minv[id] = a[l]; //或者写 a[r] 也可以
return;
}
int mid = (l + r) >> 1; // 等价于 (l + r) / 2
build(id << 1, l, mid); // 其中 id << 1 等价于 id * 2
build(id << 1 | 1, mid + 1, r); // 其中 id << 1 | 1 等价于 id * 2 + 1
minv[id] = min(minv[id << 1], minv[id << 1 | 1]); // 更新父节点的值
return;
}
总建树时间复杂度为 \(\mathcal{O}(n)\)。
#include <iostream>
using namespace std;
const int maxn = 110;
int a[maxn];
int minv[4 * maxn];
void build(int id, int l, int r) {
if (l == r) {
minv[id] = a[l];
return;
}
int mid = (l + r) >> 1;
build(id << 1, l, mid);
build(id << 1 | 1, mid + 1, r);
minv[id] = min(minv[id << 1], minv[id << 1 | 1]);
}
int main() {
int n;
cin >> n;
for (int i = 1;i <= n;i++) {
cin >> a[i];
}
build(1, 1, n);
return 0;
}
其中 minv[id] = min(minv[id << 1], minv[id << 1 | 1]);
可以使用一个 pushup
函数来代替。
void pushup(int id) {
minv[id] = min(minv[id << 1], minv[id << 1 | 1]);
}
然后 build
函数中将 minv[id] = min(minv[id << 1], minv[id << 1 | 1]);
改为 pushup(id)
。
单点更新
如果仅修改一个值却要重新建树的话,时间消耗过大。
注意到,一个点的修改只会影响到包含这个点的区间,而包含这个点的区间在树上实际上是一条链。
例如 \(a_6\) 修改为 \(1\),那么我们的更新方式如下:
一般而言,我们可以认为线段树的最大深度为 \(\log n\),所以这条链最长为 \(\log n\),可以认为时间复杂度为 \(\mathcal{O}(\log n)\)。
void update(int id, int l, int r, int x, int v) {
if (l == r) {
minv[id] = v;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) {
update(id << 1, l, mid, x, v);
} else {
update(id << 1 | 1, mid + 1, r, x, v);
}
pushup(id);
return;
}
如果将 \(a_x\) 的值改为 \(v\),那么调用方式为 update(1, 1, n, x, v)
。
#include <iostream>
using namespace std;
const int maxn = 110;
int a[maxn];
int minv[4 * maxn];
void pushup(int id) {
minv[id] = min(minv[id << 1], minv[id << 1 | 1]);
}
void build(int id, int l, int r) {
if (l == r) {
minv[id] = a[l];
return;
}
int mid = (l + r) >> 1;
build(id << 1, l, mid);
build(id << 1 | 1, mid + 1, r);
pushup(id);
}
void update(int id, int l, int r, int x, int v) {
if (l == r) {
minv[id] = v;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) {
update(id << 1, l, mid, x, v);
} else {
update(id << 1 | 1, mid + 1, r, x, v);
}
pushup(id);
}
int main() {
int n;
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
build(1, 1, n);
int q;
cin >> q;
for (int i = 0;i < q;i++) {
int x, v;
cin >> x >> v;
update(1, 1, n, x, v);
}
return 0;
}
单点查询
很简单,一直沿着链走到叶子结点即可。
int query(int id, int l, int r, int x) {
if (l == r) {
return minv[id];
}
int mid = (l + r) >> 1;
if (mid <= x) {
return query(mid << 1, l, mid, x);
} else {
return query(mid << 1 | 1. mid + 1, r, x);
}
}
区间查询
单点查询实际上是区间查询的特殊情况。
对于区间 \([x,y]\),有如下几种可能:
- 区间 \([x,y]\) 完全包含 \([l,r]\) 区间,则直接返回
minv[id]
。 - 如果左区间包含,则查询左子树
- 如果右区间包含,则查询右子树
int query(int id, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return minv[id];
}
int mid = (l + r) >> 1;
int ans = inf;
if (x <= mid) {
ans = min(ans, query(id << 1, l, mid, x, y));
}
if (y > mid) {
ans = min(ans, query(id << 1 | 1, mid + 1, r, x, y));
}
return ans;
}
#include <iostream>
using namespace std;
const int inf = 0x3f3f3f3f;
const int maxn = 110;
int a[maxn];
int minv[4 * maxn];
void pushup(int id) {
minv[id] = min(minv[id << 1], minv[id << 1 | 1]);
}
void build(int id, int l, int r) {
if (l == r) {
minv[id] = a[l];
return;
}
int mid = (l + r) >> 1;
build(id << 1, l, mid);
build(id << 1 | 1, mid + 1, r);
pushup(id);
}
void update(int id, int l, int r, int x, int v) {
if (l == r) {
minv[id] = v;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) {
update(id << 1, l, mid, x, v);
} else {
update(id << 1 | 1, mid + 1, r, x, v);
}
pushup(id);
}
int query(int id, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return minv[id];
}
int mid = (l + r) >> 1;
int ans = inf;
if (x <= mid) {
ans = min(ans, query(id << 1, l, mid, x, y));
}
if (y > mid) {
ans = min(ans, query(id << 1 | 1, mid + 1, r, x, y));
}
return ans;
}
int main() {
int n;
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
build(1, 1, n);
int q;
cin >> q;
for (int i = 0; i < q; ++i) {
int x, v;
cin >> x >> v;
update(1, 1, n, x, v);
}
int p;
cin >> p;
for (int i = 0;i < p;i++) {
int l, r;
cin >> l >> r;
cout << query(1, 1, n, l, r) << endl;
}
return 0;
}
标签:入门,int,线段,mid,maxn,区间,id,minv 来源: https://www.cnblogs.com/luogu-int64/p/15580087.html