跳转至

数据结构

数据结构

线段树

// Luogu P3372 (range add, range sum)
// Read one whitespace-separated value of type T from std::cin.
// The result is value-initialized first: if std::cin is already in a failed/EOF
// state, operator>> does not touch it, and returning an uninitialized T would be UB.
template<typename T = int>
inline T read() {
    T value{};
    std::cin >> value;
    return value;
}

// Y combinator: wraps a lambda whose first parameter is "itself", so that the
// lambda can recurse without naming a variable that is still being initialized.
template<class Fun>
class Y_combinator {
    Fun body_;  // the wrapped callable; receives the combinator as its first argument

public:
    template<class F>
    Y_combinator(F&& f) : body_(static_cast<F&&>(f)) {}

    // Call the wrapped lambda, passing this combinator along so it can call itself.
    template<class... Args>
    decltype(auto) operator () (Args&&... args) const {
        const Y_combinator& self = *this;
        return body_(self, static_cast<Args&&>(args)...);
    }
};
template<class T> Y_combinator(T) -> Y_combinator<T>;

#define MID ((l + r) >> 1)
#define LEFT (cur << 1)
#define RIGHT ((cur << 1) | 1)

int main(int argc, char* argv[]) {
    fastIO();
    int n, m;
    cin >> n >> m;
    vector<ll> arr{0};
    for (int i = 0; i < n; ++i) {
        arr.push_back(read<ll>());
    }
    vector<ll> lazy((n << 2) + 10);
    vector<ll> node((n << 2) + 10);
    Y_combinator(
        [&](auto&& build, int cur, int l, int r) -> void {
            if (l == r) {
                node[cur] = arr[l];
            } else {
                build(LEFT, l, MID);
                build(RIGHT, MID + 1, r);
                node[cur] = node[LEFT] + node[RIGHT];
            }
        }
    )(1, 1, n);
    auto&& lazyUpdate = [&](int cur, int l, int r) -> void {
        if (lazy[cur] != 0) {
            node[LEFT] += lazy[cur] * (MID - l + 1);
            node[RIGHT] += lazy[cur] * (r - MID);
            lazy[LEFT] += lazy[cur];
            lazy[RIGHT] += lazy[cur];
            lazy[cur] = 0;
        }
    };
    auto&& update = Y_combinator(
        [&](auto&& update, int cur, int l, int r, int s, int e, ll v) {
            if (s > r or e < l) return;
            if (s <= l and r <= e) {
                node[cur] += (r - l + 1) * v;
                lazy[cur] += v;
            } else {
                lazyUpdate(cur, l, r);
                update(LEFT, l, MID, s, e, v);
                update(RIGHT, MID + 1, r, s, e, v);
                node[cur] = node[LEFT] + node[RIGHT];
            }
        }
    );
    auto&& query = Y_combinator(
        [&](auto&& query, int cur, int l, int r, int s, int e)->ll {
            if (s > r or e < l) return 0;
            if (s <= l and e >= r) {
                return node[cur];
            }
            lazyUpdate(cur, l, r);
            ll ret = query(LEFT, l, MID, s, e);
            ret += query(RIGHT, MID + 1, r, s, e);
            return ret;
        }
    );
    while (m--) {
        int q, x, y, k;
        cin >> q >> x >> y;
        if (q == 1) {
            cin >> k;
            update(1, 1, n, x, y, k);
        } else {
            cout << query(1, 1, n, x, y) << endl;
        }
    }
    return 0;
}

ST表(稀疏表)(C++17)

// Sparse table for idempotent associative operations (min, max, gcd, ...):
// O(n log n) build, O(1) range query.
// arr[i][j] holds binOp over the 2^i elements starting at index j.
template<typename iter, typename BinOp>
class SparseTable {
    // Use iterator_traits::value_type rather than decltype(*it): for const iterators
    // and pointers-to-const the latter yields `const X`, and vector<const X> is ill-formed.
    using T = typename std::iterator_traits<iter>::value_type;
    std::vector<std::vector<T>> arr;
    BinOp binOp;
public:
    // Build over [begin, end). An empty range yields an empty table (no query is valid).
    SparseTable(iter begin, iter end, BinOp binOp) : binOp(binOp) {
        const int n = static_cast<int>(std::distance(begin, end));
        if (n <= 0) return;  // __builtin_clz(0) is undefined
        const int levels = 32 - __builtin_clz(static_cast<unsigned>(n));
        arr.resize(levels);
        arr[0].assign(begin, end);
        for (int i = 1; i < levels; ++i) {
            const int half = 1 << (i - 1);
            const int count = n - (1 << i) + 1;  // windows of length 2^i that fit
            arr[i].reserve(count);
            for (int j = 0; j < count; ++j) {
                arr[i].push_back(binOp(arr[i - 1][j], arr[i - 1][j + half]));
            }
        }
    }
    // binOp over the closed range [lPos, rPos]; requires 0 <= lPos <= rPos < n.
    // Two overlapping power-of-two windows cover the range.
    T query(int lPos, int rPos) {
        // floor(log2(length)), computed exactly with integer bit operations
        const int h = 31 - __builtin_clz(static_cast<unsigned>(rPos - lPos + 1));
        return binOp(arr[h][lPos], arr[h][rPos - (1 << h) + 1]);
    }
};

树状数组

// Fenwick (binary indexed) tree over positions 1..N.
// Internally it uses the "i | (i + 1)" / "(i & (i + 1)) - 1" index scheme, where
// arr[i] stores the sum of positions (i & (i + 1)) .. i. Slot 0 only ever covers
// position 0, which callers leave at zero, so prefix walks may stop before it.
template<typename T>
struct FenWick {
    int N;
    vector<T> arr;
    FenWick(int sz): N(sz), arr(sz + 1, T(0)) {}
    // Add val at position pos.
    void update(int pos, T val) {
        while (pos <= N) {
            arr[pos] += val;
            pos = pos | (pos + 1);
        }
    }
    // Sum of positions [1, pos].
    T get(int pos) {
        T total = 0;
        while (pos > 0) {
            total += arr[pos];
            pos = (pos & (pos + 1)) - 1;
        }
        return total;
    }
    // Sum of positions [l, r].
    T query(int l, int r) {
        const T upTo = get(r);
        const T before = get(l - 1);
        return upTo - before;
    }
};

珂朵莉树

namespace Chtholly{
// One run [l, r] of positions that all hold value v. Nodes are ordered by left
// endpoint only; v is mutable so it can be edited through a set iterator.
struct Node{
    int l, r;
    mutable int v;
    Node(int il, int ir, int iv): l(il), r(ir), v(iv){}
    bool operator < (const Node& arg) const{
        return l < arg.l;
    }
};
// Chtholly tree (ODT): positions 1.._sz stored as a set of disjoint constant-value runs.
class Tree{
protected:
    // Make sure some run starts exactly at pos, splitting the run that covers pos if
    // necessary, and return an iterator to it; returns odt.end() when pos > _sz.
    // Requires pos >= 1.
    auto split(int pos){
        if(pos > _sz) return odt.end();
        auto it = --odt.upper_bound(Node{pos, 0, 0});
        if(it->l == pos) return it;
        const Node old = *it;
        it = odt.erase(it);  // `it` now refers to the run after `old`: an exact insertion hint
        odt.emplace_hint(it, old.l, pos - 1, old.v);
        return odt.emplace_hint(it, pos, old.r, old.v);
    }
public:
    Tree(int sz, int ini = 1): _sz(sz), odt({Node{1, sz, ini}}) {}
    // assign() is virtual, so derived trees may be destroyed through a Tree*.
    virtual ~Tree() = default;
    Tree(const Tree&) = default;
    Tree& operator=(const Tree&) = default;
    Tree(Tree&&) = default;
    Tree& operator=(Tree&&) = default;
    // Set every position in [l, r] to v, merging the range into a single run.
    // An empty range (l > r) is ignored instead of erasing a reversed iterator range.
    virtual void assign(int l, int r, int v){
        if(l > r) return;
        // Split the right end first: split(l) cannot invalidate itr, but the reverse order could.
        auto itr = split(r + 1), itl = split(l);
        // custom per-run operations over [itl, itr) go here
        odt.erase(itl, itr);
        odt.emplace_hint(itr, l, r, v);
    }
protected:
    int _sz;
    std::set<Node> odt;
};
}

Splay树

https://loj.ac/p/104：下面的实现在此题上测试有误，暂未修复。

#include <vector>
#include <array>
#include <iostream>
#include <cassert>
using namespace std;
// Splay tree multiset: equal values share one node with a repeat counter.
// Every access splays the touched node to the root (amortized O(log n)).
template<typename T>
class SplayTree{
public:
    struct Node{
        Node *parent{};
        std::array<Node*, 2> child{};
        T val;
        // cnt: copies of `val` held by this node,
        // sz:  number of nodes (distinct values) in this subtree,
        // sum: number of elements in this subtree, counting copies
        size_t cnt, sz, sum;
        explicit Node(T value_arg): val(value_arg), cnt(1), sz(1), sum(1){}
        // Which child of its parent this node is (true = right). Requires parent != nullptr.
        bool side() const{
            return parent->child[1] == this;
        }
        // Recompute sz/sum from cnt and the children's cached values.
        // Must only be called on a real node (a `this == nullptr` test would be
        // undefined behaviour), so callers check for null themselves.
        void maintain(){
            sum = cnt;
            sz = 1;
            for(const Node* const c : child){
                if(c){
                    sum += c->sum;
                    sz += c->sz;
                }
            }
        }
        // Rotate this node above its parent (left or right rotation depending on side()).
        void rotate(){
            const auto p = parent;
            const bool i = side();
            if(p->parent){
                p->parent->attach(p->side(), this);
            }else{
                parent = nullptr;
            }
            p->attach(i, child[!i]);
            attach(!i, p);
            p->maintain();
            maintain();
        }
        // Splay this node to the root (zig / zig-zig / zig-zag). Every ancestor on the
        // access path is re-maintained by the rotations, so stale counts above are fixed.
        void splay(){
            for(;parent;rotate()){
                if(parent->parent){
                    (side() == parent->side() ? parent: this)->rotate();
                }
            }
        }
        // Make new_ (possibly nullptr) this node's child on `side`.
        void attach(bool side, Node* const new_){
            if(new_) new_->parent = this;
            child[side] = new_;
        }
    };
    // In-order bidirectional iterator over distinct values; a null node is end()/rend().
    struct iterator{
        using iterator_category = std::bidirectional_iterator_tag;
        using value_type = T;
        using pointer = T*;
        using reference = T&;
        using difference_type = long long;
    public:
        Node* node;
        void operator--(){ advance<false>();}
        void operator++(){ advance<true>();}
        const T& operator*(){return node->val;}
        explicit iterator(Node* node_arg): node(node_arg){}
        bool operator==(const iterator oth) const{
            return node == oth.node;
        }
        // Must delegate to ==; writing it as `*this != oth` would recurse forever.
        bool operator != (const iterator oth) const{
            return !(*this == oth);
        }
    private:
        // Move to the in-order neighbour in direction dir (true = successor).
        template<bool dir> void advance(){
            if(node->child[dir]){
                node = extremum<!dir>(node->child[dir]);
                return;
            }
            for(;node->parent and node->side() == dir; node = node->parent);
            node = node->parent;
        }
    };

    // Leftmost (i = false) or rightmost (i = true) node of the subtree rooted at x.
    template<bool i> static Node* extremum(Node* x){
        assert(x);
        for(;x->child[i]; x = x->child[i]);
        return x;
    }
    Node* rt{};
    explicit SplayTree()= default;
    // The tree owns its nodes; a shallow copy would free them twice.
    SplayTree(const SplayTree&) = delete;
    SplayTree& operator=(const SplayTree&) = delete;
    ~SplayTree(){ destroy(rt);}

    // Insert one copy of arg and splay its node to the root.
    void insert(const T& arg){
        if(!rt){
            rt = new Node(arg);
            return;
        }
        Node* cur = rt;
        while(true){
            if(cur->val == arg){
                ++cur->cnt;
                cur->maintain();  // ancestors are refreshed by the splay below
                break;
            }
            const bool dir = cur->val < arg;
            if(!cur->child[dir]){
                cur->attach(dir, new Node(arg));
                cur = cur->child[dir];
                break;
            }
            cur = cur->child[dir];
        }
        cur->splay();
        rt = cur;
    }

    // Counts of elements <= arg: {distinct values, total including copies}.
    std::pair<size_t, size_t> rank(const T& arg){
        std::pair<size_t, size_t> res{0, 0};
        Node* cur = rt;
        while(cur){
            if(arg < cur->val){
                cur = cur->child[0];
            }else{
                if(cur->child[0]) {
                    res.first += cur->child[0]->sz;
                    res.second += cur->child[0]->sum;
                }
                res.first ++;
                res.second += cur->cnt;
                if(arg == cur->val){
                    cur->splay();
                    rt = cur;
                    break;
                }
                cur = cur->child[1];
            }
        }
        return res;
    }
    // k-th smallest (1-based). unique = false counts copies; unique = true counts distinct values.
    template<bool unique = false>
    iterator kth(size_t k){
        assert(rt != nullptr);
        assert(k <= (unique ? rt->sz : rt->sum));
        Node* cur = rt;
        while(true){
            const size_t left = cur->child[0] ? (unique ? cur->child[0]->sz : cur->child[0]->sum) : 0;
            if(cur->child[0] and k <= left){
                cur = cur->child[0];
                continue;
            }
            k -= left;
            // a node weighs 1 in unique mode, cnt otherwise
            const size_t here = unique ? 1 : cur->cnt;
            if(k <= here){
                cur->splay();
                rt = cur;
                return iterator{cur};
            }
            k -= here;
            cur = cur->child[1];
        }
    }
    // Free every node of the subtree rooted at node. Iterative: splay trees can degenerate
    // into long chains, which would overflow the stack with recursion.
    static void destroy(Node* const node){
        std::vector<Node*> pending;
        if(node) pending.push_back(node);
        while(!pending.empty()){
            Node* const cur = pending.back();
            pending.pop_back();
            for(Node* const c : cur->child){
                if(c) pending.push_back(c);
            }
            delete cur;
        }
    }
    bool empty() const{
        return rt == nullptr;
    }
    // Number of elements, counting copies.
    size_t sum() const{
        return (rt == nullptr ? 0 : rt->sum);
    }
    // Number of distinct values.
    size_t size() const{
        return (rt == nullptr ? 0 : rt->sz);
    }

    // Smallest (side = false) or largest (side = true) element; end() when empty.
    template<bool side = false>
    iterator begin(){
        return iterator{rt ? extremum<side>(rt) : nullptr};
    }
    iterator rend(){
        return iterator{nullptr};
    }
    iterator end(){
        return iterator{nullptr};
    }
    // Node holding key, or end(). The found node (or the last visited one) is splayed.
    iterator find(const T& key){
        Node* cur = rt;
        Node* last = nullptr;
        while(cur and key != cur->val){
            last = cur;
            cur = cur->child[key > cur->val];
        }
        Node* const top = cur ? cur : last;
        if(top){
            top->splay();
            rt = top;
        }
        return iterator{cur};
    }
    // First node with val >= key, or end().
    iterator lower_bound(const T& key){
        Node* cur = rt;
        Node* ret = nullptr;
        while(cur){
            if(cur->val > key){
                ret = cur;
                cur = cur->child[0];
            }else if(cur->val == key){
                ret = cur;
                break;
            }else{
                cur = cur->child[1];
            }
        }
        if(ret){
            ret->splay();
            rt = ret;
        }
        return iterator{ret};
    }
    // Join two detached subtrees where every value in arg1 is smaller than every value
    // in arg2. Either may be nullptr; returns the new root (nullptr if both are empty).
    Node* join(Node* const arg1, Node* const arg2){
        if(!arg1){
            if(arg2) arg2->parent = nullptr;
            return arg2;
        }
        arg1->parent = nullptr;
        Node* const mx = extremum<true>(arg1);
        mx->splay();
        rt = mx;
        assert(mx->child[1] == nullptr);
        mx->child[1] = arg2;
        mx->parent = nullptr;
        if(arg2) arg2->parent = mx;
        mx->maintain();
        return mx;
    }
    // Remove the node (all copies of its value) and free it.
    void erase(const iterator itr){
        Node* const x = itr.node;
        if(!x) return;
        x->splay();
        rt = join(x->child[0], x->child[1]);
        delete x;  // Node has no destructor of its own, so its children are not touched
    }
    // Remove one copy of the value at itr.
    void extract(const iterator itr){
        Node* const x = itr.node;
        if(!x) return;
        if(x->cnt == 1){
            erase(itr);
            return;
        }
        --x->cnt;
        x->maintain();  // splay() does nothing when x is already the root
        x->splay();
        rt = x;
    }
};
using PII = pair<int, int>;

// Driver for the LOJ #104 style input: n operations, each "op x":
//   1 insert x; 2 delete one copy of x; 3 print the rank of x (1 + count of smaller
//   elements); 4 print the x-th smallest (counting copies); 5 print the largest
//   element < x; 6 print the smallest element > x.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    SplayTree<int> st;
    int n;
    cin >> n;
    while(n--){
        int op, tv;
        cin >> op >> tv;
        if(op == 1){
            st.insert(tv);
        }else if(op == 2){
            st.extract(st.find(tv));
        }else if(op == 3){
            // rank(tv).second counts elements <= tv; remove the copies equal to tv.
            // If tv is absent, find() returns end() (null node), so it contributes 0.
            const size_t notGreater = st.rank(tv).second;
            const auto itr = st.find(tv);
            const size_t equal = (itr == st.end()) ? 0 : itr.node->cnt;
            cout << notGreater - equal + 1 << '\n';
        }else if(op == 4){
            auto itr = st.kth(tv);
            cout << itr.node->val << '\n';
        }else if(op == 5){
            // predecessor: the node just before the first element >= tv
            auto itr = st.lower_bound(tv);
            if(itr == st.end()) itr = st.begin<true>();
            else --itr;
            cout << itr.node->val << '\n';
        }else{
            // successor: first element >= tv, skipping tv itself (copies share one node)
            auto itr = st.lower_bound(tv);
            if(itr.node->val == tv) ++itr;
            cout << itr.node->val << '\n';
        }
    }
    return 0;
}

AVL树

/**
 * @brief AVL 树
 * @author dianhsu
 * @date 2021/05/25
 * @ref https://zh.wikipedia.org/wiki/AVL树
 * */
#include <bits/stdc++.h>

// Node of an AVL tree. It owns its subtree: deleting a node deletes both children.
template<class T>
struct AVLNode {
    T data;
    AVLNode<T> *leftChild;
    AVLNode<T> *rightChild;
    int height;  // height of the subtree rooted here; a leaf has height 1

    // Initializers listed in declaration order (data, children, height).
    AVLNode(T data): data(data), leftChild(nullptr), rightChild(nullptr), height(1) { }

    // Copying would make two nodes own (and later delete) the same children.
    AVLNode(const AVLNode&) = delete;
    AVLNode& operator=(const AVLNode&) = delete;

    ~AVLNode() {
        delete leftChild;
        delete rightChild;
    }
};

/**
 * @brief AVL tree storing distinct values of T (ordered by operator<, matched by operator==).
 * */
template<class T>
class AVL {
public:
    AVL() : root(nullptr) { }

    // The tree owns its nodes; a shallow copy would free them twice.
    AVL(const AVL&) = delete;
    AVL& operator=(const AVL&) = delete;

    ~AVL() {
        delete root;
    }

    /**
     * @brief Insert val into the tree.
     * @param val value to insert
     * @note If val is already present, the tree is left unchanged.
     * */
    void insert(T val) {
        _insert(&root, val);
    }

    /**
     * @brief Check whether val is stored in the tree.
     * @param val value to look for
     * @return true if some node holds val
     * @note Walks a local cursor; the tree links are never modified.
     * */
    bool exist(T val) const {
        const AVLNode<T> *cur = root;
        while (cur != nullptr) {
            if (val == cur->data) {
                return true;
            }
            cur = val < cur->data ? cur->leftChild : cur->rightChild;
        }
        return false;
    }

    /**
     * @brief Locate the link that points to the node holding val.
     * @param val target value
     * @return address of the child pointer (or of root) for val; that pointer is
     *         nullptr when val is absent (it is where val would be attached).
     */
    AVLNode<T> **find(T val) {
        AVLNode<T> **ptr = &root;
        while (*ptr != nullptr && !(val == (*ptr)->data)) {
            ptr = val < (*ptr)->data ? &((*ptr)->leftChild) : &((*ptr)->rightChild);
        }
        return ptr;
    }

    /**
     * @brief Remove the node holding val, if any.
     * @param val value to delete
     * @note Standard AVL deletion: a node with two children takes its in-order
     *       successor's value, and the successor is removed from the right subtree
     *       instead. Every node on the way back up is rebalanced and its height refreshed.
     * */
    void remove(T val) {
        _remove(&root, val);
    }


private:
    // Height of a possibly-empty subtree.
    static int _height(const AVLNode<T> *node) {
        return node != nullptr ? node->height : 0;
    }

    void _remove(AVLNode<T> **ptr, T val) {
        AVLNode<T> *node = *ptr;
        if (node == nullptr) {
            return;
        }
        if (node->data == val) {
            if (node->leftChild != nullptr && node->rightChild != nullptr) {
                AVLNode<T> *successor = node->rightChild;
                while (successor->leftChild != nullptr) {
                    successor = successor->leftChild;
                }
                node->data = successor->data;
                _remove(&(node->rightChild), successor->data);
            } else {
                // At most one child: splice it into this node's place and free the node.
                AVLNode<T> *child = node->leftChild != nullptr ? node->leftChild : node->rightChild;
                node->leftChild = nullptr;  // detach so ~AVLNode does not free the spliced child
                node->rightChild = nullptr;
                delete node;
                *ptr = child;  // the child subtree is unchanged, so it is still balanced
                return;
            }
        } else if (node->data < val) {
            _remove(&(node->rightChild), val);
        } else {
            _remove(&(node->leftChild), val);
        }
        // Backtrack: rebalance this subtree and refresh its height.
        _balance(ptr);
        _updateHeight(*ptr);
    }

    /**
     * @brief Insert val below *ptr, then rebalance on the way back up.
     * */
    void _insert(AVLNode<T> **ptr, T val) {
        if (*ptr == nullptr) {
            *ptr = new AVLNode<T>(val);
            return;
        }
        if (val < (*ptr)->data) {
            _insert(&((*ptr)->leftChild), val);
        } else if (val > (*ptr)->data) {
            _insert(&((*ptr)->rightChild), val);
        } else {
            // Value already present: nothing to do.
            return;
        }
        _balance(ptr);
        _updateHeight(*ptr);
    }

    /**
     * @brief Restore the AVL property at *ptr (children heights must be up to date).
     * @note When the heavy child is itself balanced (possible after deletion), a single
     *       rotation is required; a double rotation there can leave an imbalance of 2.
     * */
    void _balance(AVLNode<T> **ptr) {
        AVLNode<T> *node = *ptr;
        if (node == nullptr) return;
        const int diff = _height(node->leftChild) - _height(node->rightChild);
        if (diff > 1) {
            AVLNode<T> *left = node->leftChild;
            if (_height(left->leftChild) < _height(left->rightChild)) {
                _leftRotate(&(node->leftChild));  // LR: reduce to LL first
            }
            _rightRotate(ptr);  // LL
        } else if (diff < -1) {
            AVLNode<T> *right = node->rightChild;
            if (_height(right->rightChild) < _height(right->leftChild)) {
                _rightRotate(&(node->rightChild));  // RL: reduce to RR first
            }
            _leftRotate(ptr);  // RR
        }
    }

    /**
     * @brief Right rotation: the left child becomes the subtree root.
     * */
    void _rightRotate(AVLNode<T> **ptr) {
        AVLNode<T> *oldRoot = *ptr;
        AVLNode<T> *newRoot = oldRoot->leftChild;
        oldRoot->leftChild = newRoot->rightChild;
        newRoot->rightChild = oldRoot;
        _updateHeight(oldRoot);  // now a child: refresh it before its new parent
        _updateHeight(newRoot);
        *ptr = newRoot;
    }

    /**
     * @brief Left rotation: the right child becomes the subtree root.
     * */
    void _leftRotate(AVLNode<T> **ptr) {
        AVLNode<T> *oldRoot = *ptr;
        AVLNode<T> *newRoot = oldRoot->rightChild;
        oldRoot->rightChild = newRoot->leftChild;
        newRoot->leftChild = oldRoot;
        _updateHeight(oldRoot);  // now a child: refresh it before its new parent
        _updateHeight(newRoot);
        *ptr = newRoot;
    }

    // Recompute ptr's height from its children (no-op for nullptr).
    void _updateHeight(AVLNode<T> *ptr) {
        if (ptr == nullptr) return;
        ptr->height = std::max(_height(ptr->leftChild), _height(ptr->rightChild)) + 1;
    }

    AVLNode<T> *root;
};

int main() {
    auto avl = new AVL<int>();
    int n = 20;
    std::random_device rd{};
    std::mt19937 gen{rd()};
    std::normal_distribution<> d{100, 100};
    std::uniform_int_distribution<int> u(0, INT_MAX >> 1);
    std::vector<int> vec;
    for (int i = 0; i < n; ++i) {
        vec.push_back((int) std::round(d(gen)));
        //vec.push_back(i);
    }
    for (auto it : vec) {
        avl->insert(it);
    }
    avl->remove(15);
    avl->remove(32);
    avl->remove(31);
    std::cout << *avl << std::endl;
    delete avl;
    return 0;
}

评论