Skip to the content.

✔ test/library-checker/DataStructure/UnionfindWithPotential.test.cpp

Depends on: library/algebra/group/Add.hpp, library/datastructure/unionfind/PotentialUnionFind.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/unionfind_with_potential"

#include <atcoder/modint>
#include <bits/stdc++.h>

#include "library/algebra/group/Add.hpp"
#include "library/datastructure/unionfind/PotentialUnionFind.hpp"

using mint = atcoder::modint998244353;

// Driver for the Library Checker problem "unionfind_with_potential".
//   query "0 u v x": add the constraint via uf.merge(u, v, x) and print
//                    1 if it was accepted, 0 if it contradicted earlier ones.
//   query "1 u v"  : print uf.diff(u, v) if u and v are connected, else -1.
int main() {
    // Many queries, one line of output each: untie the streams and avoid
    // std::endl so the output is not flushed on every line.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, q;
    std::cin >> n >> q;

    PotentialUnionFind<GroupAdd<mint>> uf(n);
    while (q--) {
        int t, u, v;
        std::cin >> t >> u >> v;
        if (t == 0) {
            int x;
            std::cin >> x;

            // NOTE(review): mint::raw skips modular reduction; this assumes
            // 0 <= x < 998244353 per the problem constraints.
            // A bool is printed as 1/0.
            std::cout << uf.merge(u, v, mint::raw(x)) << '\n';
        } else {
            auto d = uf.diff(u, v);
            if (d.has_value())
                std::cout << d.value().val() << '\n';
            else
                std::cout << -1 << '\n';
        }
    }
}
#line 1 "test/library-checker/DataStructure/UnionfindWithPotential.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/unionfind_with_potential"

#include <atcoder/modint>
#include <bits/stdc++.h>

#line 2 "library/algebra/group/Add.hpp"
template<typename X>
// Additive group over X: the operation is +, the inverse is unary -, and the
// identity is X(0). Declared commutative.
struct GroupAdd {
  using value_type = X;
  // x + y
  static constexpr X op(const X &x, const X &y) noexcept { return x + y; }
  // In place: x <- x + y
  static constexpr void Rchop(X&x, const X&y){ x+=y; }
  // In place: y <- y + x
  static constexpr void Lchop(const X&x, X&y){ y+=x; }
  // -x
  static constexpr X inverse(const X &x) noexcept { return -x; }
  // x combined with itself n times, computed as X(n) * x.
  static constexpr X power(const X &x, long long n) noexcept { return X(n) * x; }
  // Identity element X(0).
  static constexpr X unit() { return X(0); }
  static constexpr bool commute = true;
};
#line 2 "library/datastructure/unionfind/PotentialUnionFind.hpp"
// Union-find (disjoint set union) that also keeps, for every element, a
// "potential" value in the abelian group AbelGroup. Within one component,
// diff(x, y) yields potential(y) - potential(x). Uses union by size and path
// compression.
//
// AbelGroup must provide: value_type, unit(), inverse(x), op(x, y),
// Rchop(x, y) (x <- x op y) and a compile-time constant `commute`.
template <typename AbelGroup> class PotentialUnionFind {
    using T = typename AbelGroup::value_type;
    // Initialized so that a default-constructed object is a valid empty
    // structure (count() == 0) instead of holding indeterminate values.
    int n = 0;                   // number of elements
    int num = 0;                 // current number of components
    std::vector<int> sz, parent; // sz[r] is meaningful only when r is a root
    std::vector<T> potential;    // value of x measured relative to parent[x]
  public:
    PotentialUnionFind() = default;
    // Creates n singleton components, each with potential unit().
    PotentialUnionFind(int n)
        : n(n), num(n), sz(n, 1), parent(n, 0),
          potential(n, AbelGroup::unit()) {
        // merge() rearranges d + pot(x) - pot(y), which is only valid when the
        // group operation is commutative; reject other groups at compile time.
        static_assert(AbelGroup::commute,
                      "PotentialUnionFind requires a commutative group");
        std::iota(parent.begin(), parent.end(), 0);
    }

    // Returns {root of x, potential of x relative to that root}. As a side
    // effect it compresses the path, so parent[x] becomes the root.
    // Recursion depth is at most the tree height, which union by size in
    // merge() keeps O(log n).
    std::pair<int, T> from_root(int x) {
        assert(0 <= x and x < n);
        if (x == parent[x])
            return {x, AbelGroup::unit()};
        auto [r, add] = from_root(parent[x]);
        parent[x] = r;
        // potential[x] was relative to the old parent; rebase it onto r.
        AbelGroup::Rchop(potential[x], add);
        return {r, potential[x]};
    }

    // Representative (root) of the component containing x.
    int leader(int x) { return from_root(x).first; }

    // True iff x and y are in the same component.
    bool same(int x, int y) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        return leader(x) == leader(y);
    }

    // Adds the constraint  potential(y) - potential(x) = d.
    // If x and y are already connected, nothing changes; the function returns
    // whether the existing difference equals d. Otherwise the two components
    // are joined so that the constraint holds, and true is returned.
    bool merge(int x, int y, T d) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        auto [rx, dx] = from_root(x);
        auto [ry, dy] = from_root(y);
        // Turn d into the required potential of ry relative to rx:
        //   d + pot(x) - pot(y)
        AbelGroup::Rchop(d, dx);
        AbelGroup::Rchop(d, AbelGroup::inverse(dy));
        if (rx == ry)
            return d == AbelGroup::unit();
        // Union by size: hang the smaller tree under the larger root. If the
        // roles of rx and ry swap, the offset must be inverted too.
        if (sz[rx] < sz[ry]) {
            std::swap(rx, ry);
            d = AbelGroup::inverse(d);
        }
        sz[rx] += sz[ry];
        parent[ry] = rx;
        potential[ry] = d;
        num--;
        return true;
    }

    // Returns potential(y) - potential(x), i.e. y measured with x as the
    // reference, or std::nullopt when x and y are not connected.
    std::optional<T> diff(int x, int y) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        auto [rx, dx] = from_root(x);
        auto [ry, dy] = from_root(y);
        if (rx != ry)
            return std::nullopt;
        return AbelGroup::op(dy, AbelGroup::inverse(dx));
    }

    // Number of elements in the component containing x.
    int size(const int x) {
        assert(0 <= x and x < n);
        return sz[leader(x)];
    }

    // Current number of components.
    int count() const { return num; }
};
#line 8 "test/library-checker/DataStructure/UnionfindWithPotential.test.cpp"

using mint = atcoder::modint998244353;

// Driver for the Library Checker problem "unionfind_with_potential".
//   query "0 u v x": add the constraint via uf.merge(u, v, x) and print
//                    1 if it was accepted, 0 if it contradicted earlier ones.
//   query "1 u v"  : print uf.diff(u, v) if u and v are connected, else -1.
int main() {
    // Many queries, one line of output each: untie the streams and avoid
    // std::endl so the output is not flushed on every line.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, q;
    std::cin >> n >> q;

    PotentialUnionFind<GroupAdd<mint>> uf(n);
    while (q--) {
        int t, u, v;
        std::cin >> t >> u >> v;
        if (t == 0) {
            int x;
            std::cin >> x;

            // NOTE(review): mint::raw skips modular reduction; this assumes
            // 0 <= x < 998244353 per the problem constraints.
            // A bool is printed as 1/0.
            std::cout << uf.merge(u, v, mint::raw(x)) << '\n';
        } else {
            auto d = uf.diff(u, v);
            if (d.has_value())
                std::cout << d.value().val() << '\n';
            else
                std::cout << -1 << '\n';
        }
    }
}
Back to top page