Skip to the content.

:question: library/graph/MinimumSpanningArborescence.hpp

Depends on: `library/datastructure/unionfind/UnionFind.hpp`

Verified with

Code

#pragma once
#include <cassert>
#include <numeric>
#include <optional>
#include <queue>
#include <utility>
#include <vector>

#include "library/datastructure/unionfind/UnionFind.hpp"
/**
 * @brief Minimum-weight spanning arborescence (directed MST) rooted at `r`.
 *
 * Chu-Liu/Edmonds with cycle contraction. Every component keeps a priority
 * queue of incoming edge ids plus a lazy offset `add`; an edge's current
 * (reduced) cost is `weight - add`. When a cycle is contracted, the queues of
 * its members are merged smaller-into-larger.
 *
 * The chosen edges are recovered from a contraction forest. Leaves 0..n-1 are
 * the vertices, and every contracted cycle appends a node whose children are
 * the cycle's members. Each node remembers the edge it selected. Expanding the
 * forest top-down yields exactly one incoming edge per non-root vertex.
 *
 * @tparam WG graph type with members `n` and `edges` (elements exposing
 *            `from`, `to`, `weight`) and a nested `weight_type`.
 * @param g   the graph. It is taken by value because edge weights are rewritten
 *            internally while queues are merged.
 * @param r   root vertex; must satisfy 0 <= r < g.n.
 * @return std::nullopt if some component has no incoming edge left (i.e. not
 *         every vertex is reachable from `r`). Otherwise returns
 *         {total weight, ids}, where `ids` holds n - 1 indices into `g.edges`:
 *         for every vertex v != r, in increasing order of v, the edge entering v.
 */
template <typename WG, typename W = typename WG::weight_type>
std::optional<std::pair<W, std::vector<int>>>
minimum_spanning_arborescence(WG g, int r = 0) {
    const int n = g.n;
    assert(0 <= r && r < n);
    W res = 0;
    std::vector<W> new_add(n, 0);
    std::vector<int> pre(n), state(n, 0);
    UnionFind uf(n);
    state[r] = 2;

    // Contraction forest. Nodes 0..n-1 are vertices; each contracted cycle
    // appends a node. node_of[l] is the forest node currently represented by
    // union-find leader l, in_edge[x] is the edge node x selected, and
    // forest_parent[x] is the cycle node that absorbed x (-1 if none).
    std::vector<int> node_of(n), in_edge(n, -1), forest_parent(n, -1);
    std::iota(node_of.begin(), node_of.end(), 0);

    auto compare = [&](const int &a, const int &b) {
        return g.edges[a].weight > g.edges[b].weight;
    };
    using PQ = std::priority_queue<int, std::vector<int>, decltype(compare)>;
    std::vector<std::pair<PQ, W>> pq_add(n, {PQ{compare}, 0});
    const int m = static_cast<int>(g.edges.size());
    for (int i = 0; i < m; i++)
        pq_add[g.edges[i].to].first.push(i);
    std::vector<int> pq_id(n);
    std::iota(pq_id.begin(), pq_id.end(), 0);

    // Unites the components of u and v. The smaller queue is moved into the
    // larger one, and each moved weight is rewritten so that its reduced cost
    // (weight - add) is unchanged under the destination queue's offset.
    auto merge = [&](int u, int v) {
        u = uf.leader(u);
        v = uf.leader(v);
        if (u == v)
            return;
        uf.merge(u, v);
        auto &[pq1, add1] = pq_add[pq_id[u]];
        auto &[pq2, add2] = pq_add[pq_id[v]];
        if (pq1.size() > pq2.size()) {
            while (pq2.size()) {
                int edge_id = pq2.top();
                pq2.pop();
                g.edges[edge_id].weight -= add2 - add1;
                pq1.push(edge_id);
            }
            pq_id[uf.leader(v)] = pq_id[u];
        } else {
            while (pq1.size()) {
                int edge_id = pq1.top();
                pq1.pop();
                g.edges[edge_id].weight -= add1 - add2;
                pq2.push(edge_id);
            }
            pq_id[uf.leader(v)] = pq_id[v];
        }
    };

    for (int i = 0; i < n; i++) {
        int now = uf.leader(i);
        if (state[now])
            continue;
        std::vector<int> processing;
        while (state[now] != 2) {
            processing.push_back(now);
            state[now] = 1;
            auto &[pq, add] = pq_add[pq_id[now]];
            // Edges whose tail already lies inside this component became
            // self-loops through contraction. They can never be part of the
            // answer, so drop them without charging any cost.
            while (!pq.empty() && uf.leader(g.edges[pq.top()].from) == now)
                pq.pop();
            if (pq.empty())
                return std::nullopt;
            const int edge_id = pq.top();
            pq.pop();
            const auto &e = g.edges[edge_id];
            res += e.weight - add;
            in_edge[node_of[now]] = edge_id;
            pre[now] = uf.leader(e.from);
            new_add[now] = e.weight;
            if (state[pre[now]] == 1) {
                // Cycle found: contract it into a fresh forest node.
                const int cycle_node = static_cast<int>(in_edge.size());
                in_edge.push_back(-1);
                forest_parent.push_back(-1);
                int v = now;
                do {
                    forest_parent[node_of[v]] = cycle_node;
                    // Remaining edges into v are now charged relative to the
                    // edge v just selected.
                    pq_add[pq_id[v]].second = new_add[v];
                    merge(v, now);
                    v = uf.leader(pre[v]);
                } while (!uf.same(v, now));
                now = uf.leader(now);
                node_of[now] = cycle_node;
            } else
                now = uf.leader(pre[now]);
        }
        for (int v : processing)
            state[v] = 2;
    }

    // Expand the forest top-down. Every node is created after its children,
    // so scanning ids in decreasing order visits parents first. A node that
    // has not been superseded keeps its own edge. That edge enters leaf t, and
    // every node on the path from t up to (but excluding) the owner is
    // superseded by it.
    const int forest_size = static_cast<int>(in_edge.size());
    std::vector<int> tree(n, -1);
    std::vector<char> superseded(forest_size, 0);
    for (int x = forest_size - 1; x >= 0; x--) {
        if (x == r || superseded[x])
            continue;
        const int edge_id = in_edge[x];
        const int t = g.edges[edge_id].to;
        tree[t] = edge_id;
        for (int y = t; y != x; y = forest_parent[y])
            superseded[y] = 1;
    }
    tree.erase(tree.begin() + r);
    return std::make_pair(res, std::move(tree));
}
#line 1 "library/datastructure/unionfind/UnionFind.hpp"
#include <numeric>
#include <vector>

/**
 * @brief Disjoint-set union over the elements 0..n-1.
 *
 * Uses union by size and path compression. Every index argument must lie in
 * [0, n); this is checked with assert.
 */
class UnionFind {
    // Default member initializers give a default-constructed object a valid
    // empty state, so count() and groups() are well-defined on it.
    int n = 0, num = 0;
    // sz[x] is the size of x's set; it is only kept up to date for leaders.
    std::vector<int> sz, parent;

  public:
    /// Empty structure over zero elements.
    UnionFind() = default;

    /// Creates n singleton sets {0}, ..., {n-1}.
    UnionFind(int n) : n(n), num(n), sz(n, 1), parent(n, 0) {
        std::iota(parent.begin(), parent.end(), 0);
    }

    /// Returns the representative of x's set, compressing the path to it.
    int leader(int x) {
        assert(0 <= x and x < n);
        return (x == parent[x] ? x : parent[x] = leader(parent[x]));
    }

    /// Returns whether x and y belong to the same set.
    bool same(int x, int y) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        return leader(x) == leader(y);
    }

    /// Unites the sets of x and y, attaching the smaller tree under the
    /// larger. Returns false if they were already in the same set.
    bool merge(int x, int y) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        x = leader(x);
        y = leader(y);
        if (x == y)
            return false;
        if (sz[x] < sz[y])
            std::swap(x, y);
        sz[x] += sz[y];
        parent[y] = x;
        num--;
        return true;
    }

    /// Returns the number of elements in x's set.
    int size(const int x) {
        assert(0 <= x and x < n);
        return sz[leader(x)];
    }

    /// Returns the number of disjoint sets.
    int count() const { return num; }

    /// Returns all sets. Each set lists its elements in increasing order, and
    /// the sets are ordered by the index of their leader.
    std::vector<std::vector<int>> groups() {
        std::vector<std::vector<int>> buckets(n);
        for (int i = 0; i < n; i++)
            buckets[leader(i)].push_back(i);
        // Keep the non-empty buckets in leader order. This avoids the
        // C++20-only std::erase_if, so the header also builds as C++17.
        std::vector<std::vector<int>> res;
        res.reserve(num);
        for (auto &bucket : buckets)
            if (!bucket.empty()) {
                res.emplace_back();
                res.back().swap(bucket);
            }
        return res;
    }
};
#line 3 "library/graph/MinimumSpanningArborescence.hpp"
/**
 * @brief Minimum-weight spanning arborescence (directed MST) rooted at `r`.
 *
 * Chu-Liu/Edmonds with cycle contraction. Every component keeps a priority
 * queue of incoming edge ids plus a lazy offset `add`; an edge's current
 * (reduced) cost is `weight - add`. When a cycle is contracted, the queues of
 * its members are merged smaller-into-larger.
 *
 * The chosen edges are recovered from a contraction forest. Leaves 0..n-1 are
 * the vertices, and every contracted cycle appends a node whose children are
 * the cycle's members. Each node remembers the edge it selected. Expanding the
 * forest top-down yields exactly one incoming edge per non-root vertex.
 *
 * @tparam WG graph type with members `n` and `edges` (elements exposing
 *            `from`, `to`, `weight`) and a nested `weight_type`.
 * @param g   the graph. It is taken by value because edge weights are rewritten
 *            internally while queues are merged.
 * @param r   root vertex; must satisfy 0 <= r < g.n.
 * @return std::nullopt if some component has no incoming edge left (i.e. not
 *         every vertex is reachable from `r`). Otherwise returns
 *         {total weight, ids}, where `ids` holds n - 1 indices into `g.edges`:
 *         for every vertex v != r, in increasing order of v, the edge entering v.
 */
template <typename WG, typename W = typename WG::weight_type>
std::optional<std::pair<W, std::vector<int>>>
minimum_spanning_arborescence(WG g, int r = 0) {
    const int n = g.n;
    assert(0 <= r && r < n);
    W res = 0;
    std::vector<W> new_add(n, 0);
    std::vector<int> pre(n), state(n, 0);
    UnionFind uf(n);
    state[r] = 2;

    // Contraction forest. Nodes 0..n-1 are vertices; each contracted cycle
    // appends a node. node_of[l] is the forest node currently represented by
    // union-find leader l, in_edge[x] is the edge node x selected, and
    // forest_parent[x] is the cycle node that absorbed x (-1 if none).
    std::vector<int> node_of(n), in_edge(n, -1), forest_parent(n, -1);
    std::iota(node_of.begin(), node_of.end(), 0);

    auto compare = [&](const int &a, const int &b) {
        return g.edges[a].weight > g.edges[b].weight;
    };
    using PQ = std::priority_queue<int, std::vector<int>, decltype(compare)>;
    std::vector<std::pair<PQ, W>> pq_add(n, {PQ{compare}, 0});
    const int m = static_cast<int>(g.edges.size());
    for (int i = 0; i < m; i++)
        pq_add[g.edges[i].to].first.push(i);
    std::vector<int> pq_id(n);
    std::iota(pq_id.begin(), pq_id.end(), 0);

    // Unites the components of u and v. The smaller queue is moved into the
    // larger one, and each moved weight is rewritten so that its reduced cost
    // (weight - add) is unchanged under the destination queue's offset.
    auto merge = [&](int u, int v) {
        u = uf.leader(u);
        v = uf.leader(v);
        if (u == v)
            return;
        uf.merge(u, v);
        auto &[pq1, add1] = pq_add[pq_id[u]];
        auto &[pq2, add2] = pq_add[pq_id[v]];
        if (pq1.size() > pq2.size()) {
            while (pq2.size()) {
                int edge_id = pq2.top();
                pq2.pop();
                g.edges[edge_id].weight -= add2 - add1;
                pq1.push(edge_id);
            }
            pq_id[uf.leader(v)] = pq_id[u];
        } else {
            while (pq1.size()) {
                int edge_id = pq1.top();
                pq1.pop();
                g.edges[edge_id].weight -= add1 - add2;
                pq2.push(edge_id);
            }
            pq_id[uf.leader(v)] = pq_id[v];
        }
    };

    for (int i = 0; i < n; i++) {
        int now = uf.leader(i);
        if (state[now])
            continue;
        std::vector<int> processing;
        while (state[now] != 2) {
            processing.push_back(now);
            state[now] = 1;
            auto &[pq, add] = pq_add[pq_id[now]];
            // Edges whose tail already lies inside this component became
            // self-loops through contraction. They can never be part of the
            // answer, so drop them without charging any cost.
            while (!pq.empty() && uf.leader(g.edges[pq.top()].from) == now)
                pq.pop();
            if (pq.empty())
                return std::nullopt;
            const int edge_id = pq.top();
            pq.pop();
            const auto &e = g.edges[edge_id];
            res += e.weight - add;
            in_edge[node_of[now]] = edge_id;
            pre[now] = uf.leader(e.from);
            new_add[now] = e.weight;
            if (state[pre[now]] == 1) {
                // Cycle found: contract it into a fresh forest node.
                const int cycle_node = static_cast<int>(in_edge.size());
                in_edge.push_back(-1);
                forest_parent.push_back(-1);
                int v = now;
                do {
                    forest_parent[node_of[v]] = cycle_node;
                    // Remaining edges into v are now charged relative to the
                    // edge v just selected.
                    pq_add[pq_id[v]].second = new_add[v];
                    merge(v, now);
                    v = uf.leader(pre[v]);
                } while (!uf.same(v, now));
                now = uf.leader(now);
                node_of[now] = cycle_node;
            } else
                now = uf.leader(pre[now]);
        }
        for (int v : processing)
            state[v] = 2;
    }

    // Expand the forest top-down. Every node is created after its children,
    // so scanning ids in decreasing order visits parents first. A node that
    // has not been superseded keeps its own edge. That edge enters leaf t, and
    // every node on the path from t up to (but excluding) the owner is
    // superseded by it.
    const int forest_size = static_cast<int>(in_edge.size());
    std::vector<int> tree(n, -1);
    std::vector<char> superseded(forest_size, 0);
    for (int x = forest_size - 1; x >= 0; x--) {
        if (x == r || superseded[x])
            continue;
        const int edge_id = in_edge[x];
        const int t = g.edges[edge_id].to;
        tree[t] = edge_id;
        for (int y = t; y != x; y = forest_parent[y])
            superseded[y] = 1;
    }
    tree.erase(tree.begin() + r);
    return std::make_pair(res, std::move(tree));
}
Back to top page