其他分享
首页 > 其他分享 > CF620E NewYearTree

CF620E NewYearTree

作者:互联网

题目链接

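  按 DFS 序编号后,每棵子树对应一段连续区间,于是问题转化为区间赋值(覆盖)与区间查询不同颜色数。看到区间赋值操作,可能会想到用 \(ODT\)(珂朵莉树)来实现,查询不同颜色数时另外开一个数组记录出现过的颜色就好了。但很可惜 \(TLE\) 了:初始时每个点各占一个区间,若查询前没有足够多的赋值把区间合并,每次查询都要遍历区间内的所有节点,最坏为 \(O(n)\)。代码如下: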

    // Chtholly tree ("ODT"): the sequence is stored as disjoint intervals
    // [l, r], each carrying one value v, kept in a std::set ordered by l.
    // NOTE: query() visits every stored interval inside [l, r], so its cost is
    // linear in the number of intervals in that range, not logarithmic.
    struct ODT {
        struct Node {
            i64 l, r;          // inclusive bounds of the interval
            mutable i64 v;     // mutable: may be changed while inside the set

            Node (i64 l, i64 r = 0, i64 v = 0) : l(l), r(r), v(v) {}

            // Intervals are disjoint, so ordering by the left end suffices.
            bool operator < (const Node& lhs) const {
                return l < lhs.l;
            }
        };
        std::set<Node> s;

        // Makes some interval start exactly at pos (splitting the interval that
        // contains pos if necessary) and returns an iterator to the first
        // interval whose left end is >= pos; s.end() if there is none.
        std::set<Node>::iterator split(i64 pos) {
            std::set<Node>::iterator it = s.lower_bound(Node(pos));
            // Test for end() BEFORE dereferencing; the original checked
            // it->l first, which is undefined behaviour when it == s.end().
            if (it != s.end() && it->l == pos)
                return it;
            // pos precedes every interval: nothing to split (and --it would be UB).
            if (it == s.begin())
                return it;

            std::set<Node>::iterator prev = it;
            --prev;
            // pos is not covered by prev (a gap, or past the last interval):
            // the first interval at or after pos is `it` (s.end() at the tail).
            if (prev->r < pos)
                return it;

            const i64 l = prev->l, r = prev->r, v = prev->v;
            s.erase(prev);
            s.insert(Node(l, pos - 1, v));
            return s.insert(Node(pos, r, v)).first;
        }

        // Sets every position in [l, r] to x, collapsing the range into one interval.
        void assign(int l, int r, i64 x) {
            // Split the right end first: split(r + 1) may erase the interval
            // an earlier split(l) iterator points to when both ends share one.
            std::set<Node>::iterator itr = split(r + 1), itl = split(l);
            s.erase(itl, itr);
            s.insert(Node(l, r, x));
        }

        // Returns the number of distinct values over positions [l, r].
        // Assumes every stored value lies in [0, 100) — TODO confirm for callers.
        i64 query(i64 l, i64 r) {
            auto itr = split(r + 1), itl = split(l);
            bool seen[100] = {};   // fixed-size, no heap allocation per query
            i64 res = 0;
            for (auto it = itl; it != itr; ++it) {
                if (!seen[it->v]) {
                    seen[it->v] = true;
                    ++res;
                }
            }
            return res;
        }
    };

    // Reads a rooted tree (root 1) with vertex colours and answers m queries:
    //   "1 u c": recolour the whole subtree of u with colour c;
    //   "2 u"  : print the number of distinct colours in the subtree of u.
    // Subtrees are mapped to contiguous DFS-order ranges and handled by ODT.
    void solve() {
        int n, m;
        std::cin >> n >> m;
        std::vector<int> col(n + 1);
        for (int i = 1; i <= n; i++) std::cin >> col[i];

        // Adjacency list as a vector of vectors; the original
        // `std::vector<int> G[n + 1]` is a variable-length array, which is a
        // compiler extension rather than standard C++.
        std::vector<std::vector<int>> G(n + 1);
        for (int i = 1; i < n; i++) {
            int a, b;
            std::cin >> a >> b;
            G[a].push_back(b);
            G[b].push_back(a);
        }

        // parent: parent vertex; siz: subtree size; son: heavy child (0 = leaf);
        // dfn: DFS order (heavy child first); rnk: inverse of dfn.
        // The subtree of u occupies the contiguous range [dfn[u], dfn[u] + siz[u] - 1].
        std::vector<int> parent(n + 1), siz(n + 1), son(n + 1);
        std::vector<int> dfn(n + 1), rnk(n + 1);
        int idx = 0;

        // First pass: parents, subtree sizes and heavy children.
        std::function<void(int, int)> dfs1 = [&] (int u, int fa) {
            parent[u] = fa;
            siz[u] = 1;
            for (int v : G[u]) {
                if (v == fa) continue;
                dfs1(v, u);
                siz[u] += siz[v];
                if (siz[v] > siz[son[u]])
                    son[u] = v;
            }
        };

        // Second pass: assign DFS numbers, visiting the heavy child first.
        std::function<void(int)> dfs2 = [&] (int u) {
            dfn[u] = ++idx;
            rnk[idx] = u;

            if (!son[u]) return;
            dfs2(son[u]);

            for (int v : G[u]) {
                if (v == parent[u] || v == son[u]) continue;
                dfs2(v);
            }
        };

        dfs1(1, 0);
        dfs2(1);

        // One singleton interval per position. dfn[rnk[i]] == i, so positions
        // arrive in increasing order and hinting at end() makes each insert O(1).
        ODT odt;
        for (int i = 1; i <= n; i++)
            odt.s.emplace_hint(odt.s.end(), i, i, col[rnk[i]]);

        for (int i = 0; i < m; i++) {
            int op, u;
            std::cin >> op >> u;
            if (op == 1) {
                int cc;
                std::cin >> cc;
                odt.assign(dfn[u], dfn[u] + siz[u] - 1, cc);
            } else {
                std::cout << odt.query(dfn[u], dfn[u] + siz[u] - 1) << "\n";
            }
        }
    }

  既然 \(TLE\) 了,那还是想想线段树怎么做吧。注意到数据范围只有 \(0\leq c \leq 60\),而 long long 类型除符号位外有 \(63\) 位可用,于是可以用一个 long long 的二进制位表示一个区间内出现过的颜色集合(第 \(c\) 位为 \(1\) 表示颜色 \(c\) 出现过)。合并两个区间时对掩码做或运算即可统计区间内所有颜色的种类,区间查询的答案就是该掩码在二进制下 \(1\) 的个数。

    // Segment tree over positions [1, n] supporting range assignment and
    // range-OR queries on 64-bit masks. With one bit per colour, the OR over a
    // range is the set of colours present in it.
    struct Seg {
        struct Node {
            int l, r;       // inclusive segment covered by this node
            i64 val, tag;   // val: OR of the segment; tag: pending assignment (0 = none)
        };

        const int n;
        std::vector<Node> tr;

        // Builds an all-zero tree over [1, n]. Requires n >= 1 (std::__lg(0) is
        // undefined). 4 << floor(log2 n) is at least 2 * 2^ceil(log2 n), which
        // exceeds every node index the recursive build touches.
        Seg(int n) : n(n), tr(4 << std::__lg(n)) {
            std::function<void(int, int, int)> build = [&] (int u, int l, int r) -> void {
                tr[u] = {l, r};   // val and tag are value-initialised to 0
                if (l == r) return ;
                int mid = (l + r) >> 1;
                build(u << 1, l, mid);
                build(u << 1 | 1, mid + 1, r);
                tr[u].val = tr[u << 1].val | tr[u << 1 | 1].val;
            };
            build(1, 1, n);
        }

        // Overwrites the whole segment of node u with mask x and records x as
        // the pending assignment for u's children.
        void spread(int u, i64 x) {
            tr[u].val = tr[u].tag = x;
        }

        // Pushes a pending assignment of u down to its two children.
        // A zero tag means "nothing pending"; change() only stores 1ll << x,
        // which is never 0.
        void push(int u) {
            if (tr[u].tag) {
                spread(u << 1, tr[u].tag);
                spread(u << 1 | 1, tr[u].tag);
                tr[u].tag = 0;
            }
        }

        // Sets position pos (must lie in [1, n]) to mask x.
        void modify(int u, int pos, i64 x) {
            if (tr[u].l == tr[u].r && tr[u].l == pos)
                return void(tr[u].val = x);
            // Push pending assignments first; otherwise the parent's OR below
            // would use stale children and a later push would overwrite this leaf.
            push(u);
            int mid = (tr[u].l + tr[u].r) >> 1;
            if (mid >= pos) modify(u << 1, pos, x);
            else modify(u << 1 | 1, pos, x);
            tr[u].val = tr[u << 1].val | tr[u << 1 | 1].val;
        }

        // Assigns colour x to every position in [l, r].
        // Requires 0 <= x < 63 so that 1ll << x is well defined.
        void change(int u, int l, int r, int x) {
            if (tr[u].l >= l && tr[u].r <= r) {
                spread(u, 1ll << x);
                return ;
            }
            push(u);
            int mid = (tr[u].l + tr[u].r) >> 1;
            if (mid >= l) change(u << 1, l, r, x);
            if (mid < r) change(u << 1 | 1, l, r, x);
            tr[u].val = tr[u << 1].val | tr[u << 1 | 1].val;
        }

        // Returns the OR of the masks over positions [l, r] (intersected with
        // node u's segment).
        i64 query(int u, int l, int r) {
            if (tr[u].l >= l && tr[u].r <= r) return tr[u].val;
            push(u);
            int mid = (tr[u].l + tr[u].r) >> 1;
            i64 ans = 0;
            if (mid >= l) ans |= query(u << 1, l, r);
            if (mid < r) ans |= query(u << 1 | 1, l, r);
            return ans;
        }
    };

    // Reads a rooted tree (root 1) with vertex colours and answers m queries:
    //   "1 u c": recolour the whole subtree of u with colour c;
    //   "2 u"  : print the number of distinct colours in the subtree of u.
    // Each position stores the one-hot mask 1ll << colour; a subtree query is
    // the popcount of the OR over its DFS-order range.
    void solve() {
        int n, m;
        std::cin >> n >> m;
        std::vector<int> col(n + 1);
        for (int i = 1; i <= n; i++) std::cin >> col[i];

        // Adjacency list as a vector of vectors; the original
        // `std::vector<int> G[n + 1]` is a variable-length array, which is a
        // compiler extension rather than standard C++.
        std::vector<std::vector<int>> G(n + 1);
        for (int i = 1; i < n; i++) {
            int a, b;
            std::cin >> a >> b;
            G[a].push_back(b);
            G[b].push_back(a);
        }

        // parent: parent vertex; siz: subtree size; son: heavy child (0 = leaf);
        // dfn: DFS order (heavy child first); rnk: inverse of dfn.
        // The subtree of u occupies the contiguous range [dfn[u], dfn[u] + siz[u] - 1].
        std::vector<int> parent(n + 1), siz(n + 1), son(n + 1);
        std::vector<int> dfn(n + 1), rnk(n + 1);
        int idx = 0;

        // First pass: parents, subtree sizes and heavy children.
        std::function<void(int, int)> dfs1 = [&] (int u, int fa) {
            parent[u] = fa;
            siz[u] = 1;
            for (int v : G[u]) {
                if (v == fa) continue;
                dfs1(v, u);
                siz[u] += siz[v];
                if (siz[v] > siz[son[u]])
                    son[u] = v;
            }
        };

        // Second pass: assign DFS numbers, visiting the heavy child first.
        std::function<void(int)> dfs2 = [&] (int u) {
            dfn[u] = ++idx;
            rnk[idx] = u;

            if (!son[u]) return;
            dfs2(son[u]);

            for (int v : G[u]) {
                if (v == parent[u] || v == son[u]) continue;
                dfs2(v);
            }
        };

        dfs1(1, 0);
        dfs2(1);

        // dfn values span exactly [1, n], so n leaves are enough.
        Seg SGT(n);
        for (int i = 1; i <= n; i++)
            SGT.modify(1, i, (1ll << col[rnk[i]]));   // dfn[rnk[i]] == i

        for (int i = 1; i <= m; i++) {
            int op, u;
            std::cin >> op >> u;
            const int l = dfn[u], r = dfn[u] + siz[u] - 1;
            if (op == 1) {
                int cc;
                std::cin >> cc;
                SGT.change(1, l, r, cc);
            } else {
                std::cout << __builtin_popcountll(SGT.query(1, l, r)) << "\n";
            }
        }
    }

标签:std,CF620E,siz,son,i64,int,vector,NewYearTree
来源: https://www.cnblogs.com/Haven-/p/16644217.html