其他分享
首页 > 其他分享> > HDU7162. Equipment Upgrade (2022杭电多校第3场1001)

HDU7162. Equipment Upgrade (2022杭电多校第3场1001)

作者:互联网

HDU7162. Equipment Upgrade (2022杭电多校第3场1001)

题意

有一件装备,一开始是 \(0\) 级,可以强化它,当它在第 \(i\) 级时,需要花费 \(c_i\) 强化它,有 \(p_i\) 的概率强化成功(升高一级),\(1-p_i\) 的概率强化失败(降 \(1\) 至 \(i\) 级),其中降 \(j\) 级的权重为 \(w_j\),也就是说有 \((1-p_i)\frac{w_j}{\sum\limits_{k=1}^i w_k}\) 的概率降 \(j\) 级。求从 \(0\) 级升级到 \(n\) 级的期望花费。

分析

设 \(f(i)\) 为处于第 \(i\) 级的期望代价,即要算 \(f(0)\)。

记 \(S_i=\sum\limits_{j=1}^i w_j\)

根据题意,可以写出式子

\[f(i)=p_if(i+1)+\frac{1-p_i}{S_i}\sum_{j=1}^i w_jf(i-j)+c_i \]

上面的式子后面的和式已经出现了卷积形式,但是不便于求解。

容易发现,上面式子包含 \(f(0),f(1),...,f(i+1)\)。稍作移项,得到

\[f(i+1)=\frac{f(i)-c_i-\frac{1-p_i}{S_i}\sum\limits_{j=1}^i w_jf(i-j)}{p_i} \]

这说明如果已知 \(f(0),f(1),...,f(i)\),就能线性地推出 \(f(i+1)\)

而现在已知 \(f(n)\) 是 \(0\),要求 \(f(0)\)。我们可以尝试将 \(f(i)\) 线性地用 \(f(0)\) 表示,具体地说就是设

\[f(i)=a_i f(0)+b_i \]

我们只要求出 \(a_i\) 和 \(b_i\),最终答案就是 \(-\frac{b_n}{a_n}\)

根据 \(f(i)\) 的递推式和所设 \(a_i\) 与 \(b_i\) 的含义,可以推得 \(a_i\) 和 \(b_i\) 的递推式。

\[a_{i+1}=\frac{a_i-\frac{1-p_i}{S_i}\sum\limits_{j=1}^i w_j a_{i-j}}{p_i}\\ b_{i+1}=\frac{b_i-c_i-\frac{1-p_i}{S_i}\sum\limits_{j=1}^i w_j b_{i-j}}{p_i} \]

这样就可以CDQ分治配合卷积 \(O(n\log^2n)\) 求解 \(a_i\) 和 \(b_i\) 了。

CDQ分治的时候注意 \(w_i\) 和 \(a_j\) / \(b_j\) 贡献在了 \(a_{i+j+1}\) / \(b_{i+j+1}\) 上。调用NTT卷积时,要仔细计算好相应的位置。

代码

#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
namespace NTT {
typedef int Lint;
typedef long long LLint;
// 2的幂次
const int maxn = (1 << 21) + 10;
const Lint mod = 998244353;
const Lint g = 3;
Lint fpow(Lint a, Lint b, Lint mod) {
    Lint res = 1;
    for (; b; b >>= 1) {
        if (b & 1)
            res = (LLint)res * a % mod;
        a = (LLint)a * a % mod;
    }
    return res;
}
inline Lint add(Lint a, Lint b) {
    a += b;
    return a >= mod ? a - mod : a;
}
inline Lint mul(Lint a, Lint b) {
    return (LLint)a * b % mod;
}
int r[maxn];
void cal_r(int n) {
    for (int i = 0; i < n; i++) {
        r[i] = (i & 1) * (n >> 1) + (r[i >> 1] >> 1);
    }
}
void dft(Lint* a, int n, int type) {
    for (int i = 0; i < n; i++)
        if (i < r[i])
            swap(a[i], a[r[i]]);
    for (int i = 1; i < n; i <<= 1) {
        int p = i << 1;
        Lint w = fpow(g, (mod - 1) / p, mod);
        if (type == -1)
            w = fpow(w, mod - 2, mod);
        for (int j = 0; j < n; j += p) {
            Lint t = 1;
            for (int k = 0; k < i; k++, t = mul(t, w)) {
                Lint tmp = mul(a[j + k + i], t);
                a[j + k + i] = add(a[j + k], mod - tmp);
                a[j + k] = add(a[j + k], tmp);
            }
        }
    }
    if (type == -1) {
        Lint inv = fpow(n, mod - 2, mod);
        for (int i = 0; i < n; i++)
            a[i] = mul(a[i], inv);
    }
}
Lint p[maxn], q[maxn];
vector<Lint> poly_mul(const vector<Lint>& a, const vector<Lint>& b) {
    vector<Lint> res;
    int n = a.size(), m = b.size();
    res.resize(n + m - 1);
    int len = n + m - 1;
    int lim = 1;
    while (lim < len)
        lim <<= 1;
    copy(a.begin(), a.end(), p);
    fill(p + n, p + lim, 0);
    copy(b.begin(), b.end(), q);
    fill(q + m, q + lim, 0);
    cal_r(lim);
    dft(p, lim, 1), dft(q, lim, 1);
    for (int i = 0; i < lim; i++)
        p[i] = mul(p[i], q[i]);
    dft(p, lim, -1);
    for (int i = 0; i < n + m - 1; i++)
        res[i] = p[i];
    return res;
}

};  // namespace NTT

int inv(int a) {
    return NTT::fpow(a, NTT::mod - 2, NTT::mod);
}
using NTT::add;
using NTT::mod;
using NTT::mul;
using NTT::poly_mul;

int n;
const int maxn = (1 << 18) + 10;
const int inv100 = 828542813;
int lim;
int w[maxn], p[maxn], sum[maxn], inv_sum[maxn], inv_p[maxn];
int a[maxn], b[maxn], c[maxn];
void solve_ab(int l, int r) {
    if (l == r) {
        if (l == 0)
            return;
        a[l] = mul(add(a[l - 1], mod - mul(add(1, mod - p[l - 1]), mul(inv_sum[l - 1], a[l]))), inv_p[l - 1]);
        b[l] = mul(add(add(b[l - 1], mod - c[l - 1]), mod - mul(add(1, mod - p[l - 1]), mul(inv_sum[l - 1], b[l]))), inv_p[l - 1]);
        return;
    }
    int mid = l + r >> 1;
    solve_ab(l, mid);

    vector<int> P(r - l), Q(mid - l + 1);

    for (int i = 0; i < r - l; i++)
        P[i] = w[i];
    for (int i = l; i <= mid; i++)
        Q[i - l] = a[i];
    vector<int> res = poly_mul(P, Q);
    for (int i = mid + 1; i <= r; i++) {
        a[i] = add(a[i], res[i - l - 1]);
    }

    for (int i = 0; i < r - l; i++)
        P[i] = w[i];
    for (int i = l; i <= mid; i++)
        Q[i - l] = b[i];
    res = poly_mul(P, Q);
    for (int i = mid + 1; i <= r; i++) {
        b[i] = add(b[i], res[i - l - 1]);
    }
    
    solve_ab(mid + 1, r);
}
void solve() {
    cin >> n;
    for (int i = 0; i < n; i++) {
        cin >> p[i] >> c[i];
        p[i] = mul(p[i], inv100);
        inv_p[i] = inv(p[i]);
    }
    for (int i = 1; i <= n - 1; i++) {
        cin >> w[i];
        sum[i] = add(sum[i - 1], w[i]);
        inv_sum[i] = inv(sum[i]);
    }
    fill(a + 1, a + 1 + n, 0);
    fill(b + 1, b + 1 + n, 0);
    solve_ab(0, n);
    cout << mul(mod - b[n], inv(a[n])) << '\n';
}
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    a[0] = 1;
    int T;
    cin >> T;
    while (T--)
        solve();
    return 0;
}

标签:杭电多校,Upgrade,frac,limits,HDU7162,int,res,sum,Lint
来源: https://www.cnblogs.com/Bamboo-Wind/p/16522926.html