其他分享
首页 > 其他分享> > P5298[PKUWC2018]Minimax (线段树合并)

P5298[PKUWC2018]Minimax (线段树合并)

作者:互联网

题目链接

  思路:因为所有点的权值是互不相同的,并且概率\(0 < p_x < 1\),也就是所有的点都会被选到。所以用\(dp[i][j]\)来表示节点\(i\)权值为\(j\)的概率。首先考虑叶子节点,叶子节点都没有子节点所以他们的权值是确定的,\(dp[i][j] = [i = val]\);再考虑只有一个子节点的节点,那么它也没有什么选择,直接继承子节点的贡献就好了\(dp[i][j] = dp[son][j]\);再考虑两个子节点的点,用\(lc\)表示左儿子\(rc\)表示右儿子。对于\(dp[lc][j] > 0\),那么节点一定在左子树中,右子树同理,所以仅分析左子树。因为\(i\)节点的权值有\(p_x\)的可能是子节点的最大值,\(1 - p_x\)的概率是子节点的最小值,而我们限定了这个点是在左子树里面,所以有\(p_x\)的概率比右子树的最大值要大也就是\(p_x × \sum\limits_{k = 1} ^ {j - 1} dp[rc][k]\),有\(1 - p_x\)的概率比右子树的最小值要小,\((1 - p_x) × \sum\limits_{k = j + 1}^{n} dp[rc][k]\)所以\(dp[i][j] = dp[lc][j] × \left(p_x × \sum\limits_{k = 1} ^ {j - 1} dp[rc][k] + \left(1 - p_x\right) × \sum\limits_{k = j + 1} ^ {n} dp[rc][k] \right)\).
    不难发现\(\sum\limits_{k = 1} ^ {j - 1} dp[rc][k]\), \(\sum\limits_{k = j + 1}^{n} dp[rc][k]\)这两个一个是前缀和一个是后缀和,那么就需要一种能够维护区间信息和前后缀和并且可以方便转移\(dp\)状态方程的一个数据结构,那么线段树就可以,主要到了叶节点的状态是固定的,并且所有的状态都是从叶节点转移过去的,所以就一直递归到叶子节点之后,一直往上合并子树,这样就转移了前后缀和,每次乘上概率的时候就是做一个区间乘法就好了。

    int n;
    std::cin >> n;
    std::vector<std::array<int, 2>> p(n + 1);
    std::vector<int> cnt(n + 1);
    for (int i = 1; i <= n; i++) {
        int fa; std::cin >> fa;
        if (fa) p[fa][cnt[fa]++] = i;
    }

    std::vector<int> val(n + 1);
    int tmp[n + 10] = {};
    for (int i = 1; i <= n; i++) std::cin >> val[i];
    int idx = 0;
    for (int i = 1; i <= n; i++) {
        if (cnt[i]) val[i] = 1ll * val[i] * qpow(10000) % mod;
        else tmp[++idx] = val[i];
    }
    std::sort(tmp + 1, tmp + idx + 1);

    for (int i = 1; i <= n; i++) 
        if (!cnt[i]) val[i] = std::lower_bound(tmp + 1, tmp + 1 + idx, val[i]) - tmp;

    std::vector<int> root(n + 1);
    std::vector<Info> tr((n << 5) + 1);
    int tot = 0;

    auto newnode = [&]() -> int {
        int x = ++tot;
        tr[x].sum = tr[x].ch[0] = tr[x].ch[1] = 0;
        tr[x].tag = 1;
        return x;
    };

    auto settag = [&] (int u, i64 x) {
        if (!u) return ;
        tr[u].sum = 1ll * tr[u].sum * x % mod;
        tr[u].tag = 1ll * tr[u].tag * x % mod;
    };

    auto push = [&] (int u) -> void {
        if (tr[u].tag == 1) return ;
        if (tr[u].ch[0]) settag(tr[u].ch[0], tr[u].tag);
        if (tr[u].ch[1]) settag(tr[u].ch[1], tr[u].tag);
        tr[u].tag = 1;
    };

    std::function<void(int&, int, int, int, int)> update = [&] (int& u, int l, int r, int x, int v) -> void {
        if (!u) u = newnode();
        if (l == r) return void(tr[u].sum = v);
        int mid = (l + r) >> 1;
        push(u);
        if (mid >= x) update(tr[u].ch[0], l, mid, x, v);
        else update(tr[u].ch[1], mid + 1, r, x, v);
        tr[u].sum = (tr[tr[u].ch[0]].sum + tr[tr[u].ch[1]].sum) % mod;
    };

    std::function<int(int, int, int, int, int, int, int)> merge = [&] (int x, int y, int l, int r, int tag1, int tag2, int v) -> int {
        if (!x || !y) {
            settag(x, tag1), settag(y, tag2);
            return x | y;
        }
        push(x), push(y);
        int mid = l + r >> 1;
        i64 lpre = tr[tr[x].ch[0]].sum, lsuf = tr[tr[y].ch[0]].sum;
        i64 rpre = tr[tr[x].ch[1]].sum, rsuf = tr[tr[y].ch[1]].sum;
        tr[x].ch[0] = merge(tr[x].ch[0], tr[y].ch[0], l, mid, (tag1 + 1ll * rsuf % mod * (1 - v + mod)) % mod, 
                            (tag2 + 1ll * rpre % mod * (1 - v + mod)) % mod, v);
        tr[x].ch[1] = merge(tr[x].ch[1], tr[y].ch[1], mid + 1, r, (tag1 + 1ll * lsuf % mod * v) % mod, 
                            (tag2 + 1ll * lpre % mod * v) % mod, v);
        tr[x].sum = (tr[tr[x].ch[0]].sum + tr[tr[x].ch[1]].sum) % mod;
        return x;
    };

    std::function<void(int)> calc = [&] (int u) -> void {
        if (!cnt[u]) update(root[u], 1, idx, val[u], 1);
        if (cnt[u] == 1) calc(p[u][0]), root[u] = root[p[u][0]];
        if (cnt[u] == 2) calc(p[u][0]), calc(p[u][1]), root[u] = merge(root[p[u][0]], root[p[u][1]], 1, idx, 0, 0, val[u]);
    };

    std::vector<int> ans(idx + 1);

    std::function<void(int, int, int)> dfs = [&] (int u, int l, int r) -> void {
        if (!u) return ;
        if (l == r) return void(ans[l] = tr[u].sum);
        push(u);
        int mid = (l + r) >> 1;
        dfs(tr[u].ch[0], l, mid);
        dfs(tr[u].ch[1], mid + 1, r);
    };

    calc(1);
    dfs(root[1], 1, idx);
    i64 res = 0;
    for (int i = 1; i <= idx; i++) {
        res = (res + 1ll * i * tmp[i] % mod * ans[i] % mod * ans[i]) % mod;
    }
    std::cout << res << "\n";

标签:ch,int,Minimax,tr,dp,P5298,sum,PKUWC2018,mod
来源: https://www.cnblogs.com/Haven-/p/16670382.html