其他分享
首页 > 其他分享> > 基础树上问题之 树的直径 + 最近公共祖先 例题及学习笔记(入门版)

基础树上问题之 树的直径 + 最近公共祖先 例题及学习笔记(入门版)

作者:互联网

本篇博客是关于洛谷题单【图论2-1】基础树上问题 的题目题解合集
紫题还不会,先鸽
同时附加一点我的个人学习心得

基础树上问题 除了 树形dp 外,还有 树的直径LCA 等问题

树的直径

树的直径即树上最长路的长度

求法是首先任取一点作为根,求出一个到根最远的点,此为直径的一端;再以这个端点为根再进行一次dfs,求到根最远的点,为直径的另一端点

先放个树的直径的板子:

树的直径
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, k;
vector<int>e[N];
int d[N];   //点的实际深度
int maxd[N];//点可以到达的最大深度
int s, t, mxd;
int f[N], ans[N]; //到其他点的最大距离

void dfs1(int now, int fa) {
    d[now] = d[fa] + 1;
    if(d[now] > mxd){
        mxd = d[now];
        s = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs1(i, now);
    }
}

void dfs2(int now, int fa) {
    d[now] = d[fa] + 1;
    f[now] = fa;
    if(d[now] > mxd){
        mxd = d[now];
        t = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs2(i, now);
    }
}

void solve() {
    //两次dfs求直径
    mxd = -1;
    dfs1(1, 0);  
    d[0] = -1; mxd = -1;
    dfs2(s, 0);
    //s 和 t 即为树的直径
}

int main(){
    cin >> n >> k;
    for(int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        e[u].pb(v);
        e[v].pb(u);
    }
    solve();
    system("pause");
    return 0;
}

----------------接下来是例题-----------------------

P1395 会议

题意
求到n个人距离之和最小的树上的点

思路
其实就是先任选一个点,求出距离,可以 \(O(n)\) 更新其他的点

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 5e5+10, mod = 998244353;
int t, n, q, u, v;
struct Edge{
	int to,nex;
}e[2*N]; 
int head[N],d[N],fa[N][30], sz[N];
ll f[N], mx;
int ind, cnt;

void add(int u,int v){
	e[++cnt].to=v;
	e[cnt].nex=head[u];
	head[u]=cnt;
} 

void dfs(int now,int father){
	sz[now] = 1;
    d[now] = d[father] + 1;
	for(int i=head[now];i;i=e[i].nex){
		if(e[i].to!=father){
            dfs(e[i].to,now);
            sz[now] += sz[e[i].to];
        }
	}
}

void dfs2(int now,int father){
    f[now] = f[father] - sz[now] + (n - sz[now]);
	for(int i=head[now];i;i=e[i].nex){
        int x = e[i].to;
		if(x != father){
            dfs2(x, now);
        }
	}
}

int main(){
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    dfs(1, 0);
    
    f[1] = 0;
    for (int j = 1; j <= n; j++ ) {
        f[1] += d[j] - d[1];
    }
    mx = f[1];
    ind = 1;

    for (int i = head[1]; i; i = e[i].nex) {
        dfs2(e[i].to, 1);
    }

    for (int i = 1; i <= n; i++) {
        if(mx > f[i]){
            mx = f[i];
            ind = i;
        }
    }

    cout<<ind<<' '<<mx<<endl;

    system("pause");
    return 0;
}

P5536 【XR-3】核心城市

题意
选k个不经过其他城市就两两可达的点作为核心城市,求非核心城市到核心城市的最大距离的最小值

思路
如果 $k = 1 $ ,那这个城市就是树的直径的中点
如果 $k > 1 $ ,先找到第一个核心城市,然后从这个城市开始dfs,贪心地选取剩下的城市。具体见代码

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, k;
vector<int>e[N];
int d[N];   //点的实际深度
int maxd[N];//点可以到达的最大深度
int s, t, mxd;
int f[N], ans[N]; //到其他点的最大距离

void dfs1(int now, int fa) {
    d[now] = d[fa] + 1;
    if(d[now] > mxd){
        mxd = d[now];
        s = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs1(i, now);
    }
}

void dfs2(int now, int fa) {
    d[now] = d[fa] + 1;
    f[now] = fa;
    if(d[now] > mxd){
        mxd = d[now];
        t = now;
    }
    for(auto i:e[now]) {
        if(i == fa) continue;
        dfs2(i, now);
    }
}

void dfs_k(int now, int fa) {
    d[now] = d[fa] + 1;
    maxd[now] = d[now];
    for(auto i:e[now]){
        if(i == fa) continue;
        dfs_k(i, now);
        maxd[now] = max(maxd[now], maxd[i]);
    }
}

void solve() {
    //两次dfs求直径
    mxd = -1;
    dfs1(1, 0);  
    d[0] = -1; mxd = -1;
    dfs2(s, 0);

    //找直径中点t
    int tt = t;
    for(int i = 1; i <= (d[tt] - d[s]) / 2 ; i++) t = f[t];

    //确定k个点 , 首先求出每个点能到达(往下走)的最大深度
    d[0] = -1;
    dfs_k(t, 0);
    for(int i = 1; i <= n; i++) {
        // cout<<i<<' '<<d[i]<<' '<<maxd[i]<<endl; ///
        ans[i] = maxd[i] - d[i];
    }
    sort(ans + 1, ans + n + 1, greater<int>());
    printf("%d\n", ans[k + 1] + 1);
}

int main(){
    cin >> n >> k;
    for(int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        e[u].pb(v);
        e[v].pb(u);
    }
    solve();
    system("pause");
    return 0;
}

P1099 [NOIP2007 提高组] 树网的核

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int n, S;
vector<pii>e[N];
int d[N];   //点的实际深度
int s, t, mxd;
int f[N];
bool vis[N];  //直径上的点

void dfs1(int now, int fa) {
    if(d[now] > mxd){
        mxd = d[now];
        s = now;
    }
    for(auto i:e[now]) {
        if(i.first == fa) continue;
        d[i.first] = d[now] + i.second;
        dfs1(i.first, now);
    }
}

void dfs2(int now, int fa) {
    if(d[now] > mxd){
        mxd = d[now];
        t = now;
    }
    for(auto i:e[now]) {
        if(i.first == fa) continue;
        d[i.first] = d[now] + i.second;
        f[i.first] = now;
        dfs2(i.first, now);
    }
}

void solve() {
    //两次dfs求直径
    mxd = -1; d[1] = 0;
    dfs1(1, 1);  
    mxd = -1; d[s] = 0;
    dfs2(s, s);
    f[s] = 0;

    int ans = 1e9;
    //答案第一种来源:直径上的
    for(int i = t; i; i = f[i]){
        vis[i] = 1;
        for(int j = i; j; j = f[j]){
            if(d[i] - d[j] <= S){
                ans = min(ans, max(d[j], d[t] - d[i]));
            }
        }
    }
    // printf("%d\n", ans);

    //答案另外一种来源:直径之外的
    for(int j = 1; j <= n; j++){
        if(vis[j]) continue;
        int mx = 1e9;
        for(int i = t; i; i = f[i]) {
            if(d[j] > d[i]) mx = min(mx, d[j] - d[i]);
        }
        ans = max(ans, mx);
    }

    printf("%d\n", ans);
}

int main(){
    cin >> n >> S;
    for(int i = 1, u, v, w; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        e[u].pb({v,w});
        e[v].pb({u,w});
    }
    solve();
    system("pause");
    return 0;
}

最近公共祖先(LCA)

先放个LCA的板子,亲测能通过洛谷上LCA相关的题目

LCA

#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int t, n, q;
vector<int> e[N];
int a[2], b[2];
int f[N][33], d[N];

void dfs(int now, int fa) {
	d[now] = d[fa] + 1;
	f[now][0] = fa;
	for(int i = 1; (1 << i) <= d[now]; i++) {
		f[now][i] = f[f[now][i - 1]][i - 1];
	}
	for(auto i:e[now]) {
		if(i == fa) continue;
		dfs(i, now);
	}
}

int lca(int a, int b) {
	if(d[a] < d[b]) swap(a, b);
	int dep;
	for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
	for(int i = dep; i >= 0 ; i--) {
		if(d[a] - (1 << i) >= d[b]) a = f[a][i];
	}
	if(a == b) return a;
	for(int i = dep; i >= 0; i--) {
		if(f[a][i] == f[b][i]) continue;
		else {
			a = f[a][i];
			b = f[b][i];
		}
	}
	return f[a][0];
}

inline int dis(int a, int b) {
	return d[a] + d[b] - 2 * d[lca(a, b)];
}

inline bool check(int a, int b, int ff) {
	if(dis(a, ff) + dis(b, ff) == dis(a, b)) return 1;
	return 0;
}

int main(){
    cin >> n >> q;
	for(int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].pb(v);
		e[v].pb(u);
	}

	dfs(1, 0);
	while(q--) {
		scanf("%d%d%d%d", &a[0], &b[0], &a[1], &b[1]);
		int f1 = lca(a[0], b[0]); int low1 = max(d[a[0]], d[b[0]]);
		int f2 = lca(a[1], b[1]); int low2 = max(d[a[1]], d[b[1]]);
		int f = lca(f1, f2);
		// cout<< f1 <<' '<<f2<<endl; ///
		if(check(a[0], b[0], f2) || check(a[1], b[1], f1) ) puts("Y");
		else puts("N");
	}
    system("pause");
    return 0;
}

----------------接下来是例题-----------------------

P5836 [USACO19DEC]Milk Visits S

题意
一棵树上,每个点有一种品种的奶牛,总共有两种奶牛。
有 \(q\) 位客人要从 \(u\) 点到 \(v\) 点参观,问能否经过特定种类的奶牛。

思路
随便指定一个点为根,可以用 \(dfs\) 求出每个点到根这条路径上两种牛的数目,询问一条到祖先的路径上牛的数目只要用这个点的减去祖先的即可
对于每个询问求出 \(u\) 到 \(lca(u,v)\)、 \(v\) 到 \(lca(u,v)\) 上的牛的数目,大于零即puts("Y")

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5+10, mod = 998244353;
int t, n, q, u, v;
char s[N];
vector<int> g[N];
struct Edge{
	int to,nex;
}e[2*N]; 
int head[N],d[N],fa[N][30],num[N][2];
int cnt;
char ch;

void add(int u,int v){
	e[++cnt].to=v;
	e[cnt].nex=head[u];
	head[u]=cnt;
} 

void dfs(int now,int father){
	fa[now][0]=father;
	d[now]=d[father]+1;
    num[now][0] = num[father][0] + (s[now] == 'H');
    num[now][1] = num[father][1] + (s[now] == 'G');
	for(int i=1;(1<<i)<=d[now];i++){
		fa[now][i]=fa[fa[now][i-1]][i-1];
	}
	for(int i=head[now];i;i=e[i].nex){
		if(e[i].to!=father) dfs(e[i].to,now);
	}
}

int lca(int a,int b) {                                         //非常标准的lca查找{
    if(d[a]<d[b]) swap(a,b);    //d[a]大 
    int dep;
    for(dep=0;(1<<dep)<=d[a];dep++);
	dep--;
    for(int i=dep;i>=0;i--)
        if(d[a]-(1<<i)>=d[b])
            a=fa[a][i];             //先把b移到和a同一个深度
    if(a==b) return a;                 //特判,如果b上来和就和a一样了,那就可以直接返回答案了
    for(int i=dep;i>=0;i--){
        if(fa[a][i]==fa[b][i])
            continue;
        else
            a=fa[a][i],b=fa[b][i];           //A和B一起上移
    }
    return fa[a][0];            
}

int main(){
    scanf("%d%d", &n, &q);
    scanf("%s", s+1);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    dfs(1, 1);
    while (q--) {
        scanf("%d%d", &u, &v); cin>>ch;
        int f = lca(u, v);
        int ans;
        if(ch == 'H') ans = num[u][0] + num[v][0] - num[f][0] - num[fa[f][0]][0];
        else ans = num[u][1] + num[v][1] - num[f][1] - num[fa[f][0]][1];
        if(ans) printf("1");
        else printf("0");
    }
    system("pause");
    return 0;
}

P3398 仓鼠找 sugar

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
int t, n, q;
vector<int> e[N];
int a[2], b[2];
int f[N][33], d[N];

//yes 的情况:一条路径的lca在另外一条路径上

//怎么知道一条路径上包含另外一个点?

void dfs(int now, int fa) {
	d[now] = d[fa] + 1;
	f[now][0] = fa;
	for(int i = 1; (1 << i) <= d[now]; i++) {
		f[now][i] = f[f[now][i - 1]][i - 1];
	}
	for(auto i:e[now]) {
		if(i == fa) continue;
		dfs(i, now);
	}
}

int lca(int a, int b) {
	if(d[a] < d[b]) swap(a, b);
	int dep;
	for(dep = 0; (1 << dep) <= d[a]; dep++) ; dep--;
	for(int i = dep; i >= 0 ; i--) {
		if(d[a] - (1 << i) >= d[b]) a = f[a][i];
	}
	if(a == b) return a;
	for(int i = dep; i >= 0; i--) {
		if(f[a][i] == f[b][i]) continue;
		else {
			a = f[a][i];
			b = f[b][i];
		}
	}
	return f[a][0];
}

inline int dis(int a, int b) {
	return d[a] + d[b] - 2 * d[lca(a, b)];
}

inline bool check(int a, int b, int ff) {
	if(dis(a, ff) + dis(b, ff) == dis(a, b)) return 1;
	return 0;
}

int main(){
    cin >> n >> q;
	for(int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].pb(v);
		e[v].pb(u);
	}

	dfs(1, 0);
	while(q--) {
		scanf("%d%d%d%d", &a[0], &b[0], &a[1], &b[1]);
		int f1 = lca(a[0], b[0]); int low1 = max(d[a[0]], d[b[0]]);
		int f2 = lca(a[1], b[1]); int low2 = max(d[a[1]], d[b[1]]);
		int f = lca(f1, f2);
		// cout<< f1 <<' '<<f2<<endl; ///
		if(check(a[0], b[0], f2) || check(a[1], b[1], f1) ) puts("Y");
		else puts("N");
	}
    system("pause");
    return 0;
}

标签:例题,入门,int,笔记,fa,long,mxd,now,define
来源: https://www.cnblogs.com/re0acm/p/16613422.html