library/datastructure/BinaryTrie.hpp
Depends on
Verified with
Code
// Binary trie storing a multiset of LOG-bit unsigned integers.
//   LOG   : number of low-order bits of each key that are stored (<= 64).
//   COUNT : type used for multiplicities (e.g. int, long long).
// Supports adding/removing copies of a value, querying a value's count, and
// finding the k-th smallest value after XOR-ing every element with a mask.
template <int LOG, typename COUNT> class BinaryTrie {
  static_assert(LOG <= 64, "Binary Trie overflow");
  // Unsigned type wide enough to hold LOG bits.
  using T = std::conditional_t<LOG <= 32, unsigned int, unsigned long long>;
  struct Node {
    // Child node index for bit 0 / bit 1; -1 means "no child".
    // (Default member initializer instead of std::ranges::fill, so this also
    // compiles as C++17.)
    std::array<int, 2> nxt_node{{-1, -1}};
    // Total multiplicity of all values whose path passes through this node.
    COUNT count = COUNT(0);
  };
  std::vector<Node> nodes;  // nodes[0] is the root
  int &nxt(int now, bool f) { return nodes[now].nxt_node[f]; }
  int nxt(int now, bool f) const { return nodes[now].nxt_node[f]; }
  static bool bit(const T &a, int i) { return (a >> i) & 1; }

public:
  BinaryTrie() : nodes(1, Node()) {}

  // Adds `num` copies of `a` (only the low LOG bits of `a` are used).
  // With a signed COUNT, a negative `num` removes copies; removing more copies
  // than are present leaves negative counts and breaks k_th/min/max.
  // Returns the index of the leaf node for `a`.
  int add(const T &a, COUNT num = 1) {
    int now = 0;
    for (int i = LOG - 1; i >= 0; i--) {
      const bool b = bit(a, i);
      if (nxt(now, b) == -1) {
        // Assign before emplace_back: the reference from nxt() would dangle
        // if the vector reallocated first.
        nxt(now, b) = static_cast<int>(nodes.size());
        nodes.emplace_back();
      }
      nodes[now].count += num;
      now = nxt(now, b);
    }
    nodes[now].count += num;
    return now;
  }

  // Returns the leaf index for `a`, or -1 if its path was never created.
  // A path may exist with count 0 after removals.
  int node_idx(const T &a) const {
    int now = 0;
    for (int i = LOG - 1; i >= 0; i--) {
      now = nxt(now, bit(a, i));
      if (now == -1) return -1;
    }
    return now;
  }

  // Number of copies of `a` currently stored.
  COUNT count(const T &a) const {
    const int id = node_idx(a);
    return id == -1 ? COUNT(0) : nodes[id].count;
  }

  // Total number of stored values (with multiplicity).
  COUNT size() const { return nodes[0].count; }

  // After XOR-ing every element with `xor_add`, returns the k-th (0-indexed)
  // smallest value. Requires 0 <= k < size().
  T k_th(COUNT k, T xor_add = 0) const {
    // A negative k could walk into a missing child (index -1), so reject it.
    assert(COUNT(0) <= k && k < size());
    T res = 0;
    int now = 0;
    for (int i = LOG - 1; i >= 0; i--) {
      const bool f = bit(xor_add, i);  // child whose bit becomes 0 after XOR
      const int low = nxt(now, f);
      if (low != -1 && k < nodes[low].count) {
        now = low;
      } else {
        if (low != -1) k -= nodes[low].count;
        now = nxt(now, !f);
        res |= T{1} << i;
      }
    }
    return res;
  }

  // Smallest / largest value after XOR-ing every element with `xor_add`.
  T min(T xor_add = 0) const { return k_th(0, xor_add); }
  T max(T xor_add = 0) const { return k_th(size() - 1, xor_add); }
};
#line 1 "library/datastructure/BinaryTrie.hpp"
// Binary trie storing a multiset of LOG-bit unsigned integers.
//   LOG   : number of low-order bits of each key that are stored (<= 64).
//   COUNT : type used for multiplicities (e.g. int, long long).
// Supports adding/removing copies of a value, querying a value's count, and
// finding the k-th smallest value after XOR-ing every element with a mask.
template <int LOG, typename COUNT> class BinaryTrie {
  static_assert(LOG <= 64, "Binary Trie overflow");
  // Unsigned type wide enough to hold LOG bits.
  using T = std::conditional_t<LOG <= 32, unsigned int, unsigned long long>;
  struct Node {
    // Child node index for bit 0 / bit 1; -1 means "no child".
    // (Default member initializer instead of std::ranges::fill, so this also
    // compiles as C++17.)
    std::array<int, 2> nxt_node{{-1, -1}};
    // Total multiplicity of all values whose path passes through this node.
    COUNT count = COUNT(0);
  };
  std::vector<Node> nodes;  // nodes[0] is the root
  int &nxt(int now, bool f) { return nodes[now].nxt_node[f]; }
  int nxt(int now, bool f) const { return nodes[now].nxt_node[f]; }
  static bool bit(const T &a, int i) { return (a >> i) & 1; }

public:
  BinaryTrie() : nodes(1, Node()) {}

  // Adds `num` copies of `a` (only the low LOG bits of `a` are used).
  // With a signed COUNT, a negative `num` removes copies; removing more copies
  // than are present leaves negative counts and breaks k_th/min/max.
  // Returns the index of the leaf node for `a`.
  int add(const T &a, COUNT num = 1) {
    int now = 0;
    for (int i = LOG - 1; i >= 0; i--) {
      const bool b = bit(a, i);
      if (nxt(now, b) == -1) {
        // Assign before emplace_back: the reference from nxt() would dangle
        // if the vector reallocated first.
        nxt(now, b) = static_cast<int>(nodes.size());
        nodes.emplace_back();
      }
      nodes[now].count += num;
      now = nxt(now, b);
    }
    nodes[now].count += num;
    return now;
  }

  // Returns the leaf index for `a`, or -1 if its path was never created.
  // A path may exist with count 0 after removals.
  int node_idx(const T &a) const {
    int now = 0;
    for (int i = LOG - 1; i >= 0; i--) {
      now = nxt(now, bit(a, i));
      if (now == -1) return -1;
    }
    return now;
  }

  // Number of copies of `a` currently stored.
  COUNT count(const T &a) const {
    const int id = node_idx(a);
    return id == -1 ? COUNT(0) : nodes[id].count;
  }

  // Total number of stored values (with multiplicity).
  COUNT size() const { return nodes[0].count; }

  // After XOR-ing every element with `xor_add`, returns the k-th (0-indexed)
  // smallest value. Requires 0 <= k < size().
  T k_th(COUNT k, T xor_add = 0) const {
    // A negative k could walk into a missing child (index -1), so reject it.
    assert(COUNT(0) <= k && k < size());
    T res = 0;
    int now = 0;
    for (int i = LOG - 1; i >= 0; i--) {
      const bool f = bit(xor_add, i);  // child whose bit becomes 0 after XOR
      const int low = nxt(now, f);
      if (low != -1 && k < nodes[low].count) {
        now = low;
      } else {
        if (low != -1) k -= nodes[low].count;
        now = nxt(now, !f);
        res |= T{1} << i;
      }
    }
    return res;
  }

  // Smallest / largest value after XOR-ing every element with `xor_add`.
  T min(T xor_add = 0) const { return k_th(0, xor_add); }
  T max(T xor_add = 0) const { return k_th(size() - 1, xor_add); }
};
Back to top page