From 5c24ca38ad9c2bcc57d046be0ce48900c1bb2dde Mon Sep 17 00:00:00 2001 From: subcrip Date: Tue, 4 Jun 2024 23:22:06 +0800 Subject: [PATCH] Add trees/binary_trie.cc Signed-off-by: subcrip --- trees/binary_trie.cc | 75 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 trees/binary_trie.cc diff --git a/trees/binary_trie.cc b/trees/binary_trie.cc new file mode 100644 index 0000000..33b51b3 --- /dev/null +++ b/trees/binary_trie.cc @@ -0,0 +1,75 @@ +template +struct binary_trie { + vector, DataT, array>> tr; + binary_trie() : tr(1) {} + void emplace(ll x, const DataT& data = {}) { + int ptr = 0; + for (int i = N - 1; ~i; --i) { + int bit = x >> i & 1; + if (not std::get<0>(tr[ptr])[bit]) { + std::get<0>(tr[ptr])[bit] = tr.size(); + tr.emplace_back(); + } + std::get<2>(tr[ptr])[bit] += 1; + ptr = std::get<0>(tr[ptr])[bit]; + } + std::get<1>(tr[ptr]) = data; + } + + /// WARN: Don't try to erase a non-existent element! + void erase(ll x) { + int ptr = 0; + vector st; + for (int i = N - 1; ~i; --i) { + int bit = x >> i & 1; + st.emplace_back(ptr, bit); + ptr = std::get<0>(tr[ptr])[bit]; + } + while (st.size() and std::get<2>(tr[st.back().first])[st.back().second] == 1) { + std::get<0>(tr[st.back().first])[st.back().second] = 0; + std::get<2>(tr[st.back().first])[st.back().second] = 0; + st.pop_back(); + } + tr[ptr] = {}; + } + + optional get(ll x) { + int ptr = 0; + for (int i = N - 1; ~i; --i) { + int bit = x >> i & 1; + if (not std::get<0>(tr[ptr])[bit]) return {}; + ptr = std::get<0>(tr[ptr])[bit]; + } + return { std::get<1>(tr[ptr]) }; + } + + int count_prefix(ll x, int bits = N) { + int ptr = 0; + int res = 0; + for (int i = N - 1; i > max(-1, N - 1 - bits); --i) { + int bit = x >> 1 & 1; + if (not std::get<0>(tr[ptr])[bit]) return 0; + ptr = std::get<0>(tr[ptr])[bit]; + res = std::get<2>(tr[ptr])[bit]; + } + return res; + } + + optional> get_max_xor(ll x) { + int ptr = 0; + ll res = 0; + for (int i = N - 1; ~i; --i) { + int bit = x >> i & 1; + if (std::get<0>(tr[ptr])[1 ^ bit]) { + ptr = std::get<0>(tr[ptr])[1 ^ bit]; + res |= (1 ^ bit) << i; + } else if (std::get<0>(tr[ptr])[bit]) { + ptr = std::get<0>(tr[ptr])[bit]; + res |= bit << i; + } else { + return {}; + } + } + return {{ res, std::get<1>(tr[ptr]) }}; + } +};