STDOJ pyf的树(树形dp)
作者:互联网
题目
思路
首先我们需要两个数组:
- c n t cnt cnt 数组记录每一个节点的子节点个数
- d p dp dp 数组在第一个 d f s dfs dfs 中记录的是以 1 1 1 为根节点,每个节点对树价值的贡献;在第二个 d f s dfs dfs 中记录的是以 i i i 为根节点整个树的价值;
第一个
d
f
s
dfs
dfs 没什么可以说的,重点在第二个
d
f
s
dfs
dfs ,我们从
1
1
1 为根节点开始,用
d
p
[
1
]
dp[1]
dp[1] 去更新与
1
1
1 相连的子节点,一直递归下去…
d
p
[
j
]
=
d
p
[
j
]
−
c
n
t
[
j
]
+
d
p
[
u
]
−
d
p
[
j
]
+
n
−
c
n
t
[
j
]
dp[j] = dp[j] - cnt[j] + dp[u] - dp[j] + n - cnt[j]
dp[j]=dp[j]−cnt[j]+dp[u]−dp[j]+n−cnt[j];
d
p
[
j
]
−
c
n
t
[
j
]
dp[j] - cnt[j]
dp[j]−cnt[j] 表示以
j
j
j 为根节点时的
j
j
j 和以
1
1
1 为根节点时
j
j
j 的子节点对以
j
j
j 为根的树的价值的贡献;
d
p
[
u
]
−
d
p
[
j
]
+
n
−
c
n
t
[
j
]
dp[u] - dp[j] + n - cnt[j]
dp[u]−dp[j]+n−cnt[j] 表示以
1
1
1 为根节点时除去
j
j
j 和
j
j
j 的子节点的点对以
j
j
j 为根的树的价值的贡献;
AC代码
#include <iostream>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <map>
#include <vector>
#include <queue>
#define x first
#define y second
using namespace std;
typedef long long ll;
typedef pair <int,int> PII;
const int N = 1000010;
const ll mod = 998244353;
int n;
int e[2 * N], ne[2 * N], h[N], idx;
bool st[N];
ll cnt[N], dp[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs1(int u) {
st[u] = true;
cnt[u] = 1, dp[u] = 1;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (st[j]) continue;
dfs1(j);
cnt[u] += cnt[j];
dp[u] += dp[j] + cnt[j];
}
}
void dfs2(int u) {
st[u] = true;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (st[j]) continue;
dp[j] = dp[j] - cnt[j] + dp[u] - dp[j] + n - cnt[j];
dfs2(j);
}
}
int main()
{
cin >> n;
memset(h, -1, sizeof h);
for (int i = 1; i <= n - 1; i ++){
int a, b;
scanf("%d%d", &a, &b);
add(a, b);
add(b, a);
}
dfs1(1);
memset(st, false, sizeof st);
dfs2(1);
ll ans = dp[n], id = n;
for (int i = n - 1; i >= 1; i --) {
if (ans <= dp[i]){
ans = dp[i];
id = i;
}
}
cout << id << endl;
return 0;
}
——END
标签:cnt,int,dfs,节点,dp,STDOJ,pyf,为根 来源: https://blog.csdn.net/m0_52348473/article/details/120230289