Skip to the content.

✔ test/yukicoder/1502.test.cpp

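Depends on

- library/datastructure/unionfind/IntegerSumRuleUnionFind.hpp
- library/math/ExtraGCD.hpp
- library/mod/Modint.hpp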

Code

#define PROBLEM "https://yukicoder.me/problems/no/1502"
#include <bits/stdc++.h>

#define REP(i, n) for (int i = 0; i < (n); i++)

#include "library/datastructure/unionfind/IntegerSumRuleUnionFind.hpp"
#include "library/mod/Modint.hpp"
// Modular integer type used for the answer (modulus 1'000'000'007).
using mint = Mint<long long, 1000'000'007>;

// Lowers `a` to `b` when `b` is smaller; otherwise leaves `a` untouched.
void chmin(int &a, int b) {
    if (b < a)
        a = b;
}
// Raises `a` to `b` when `b` is larger; otherwise leaves `a` untouched.
void chmax(int &a, int b) {
    if (a < b)
        a = b;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m, k;
    std::cin >> n >> m >> k;

    IntegerSumRuleUnionFind UF(n);
    REP (_, m) {
        int x, y, z;
        std::cin >> x >> y >> z;
        x--;
        y--;
        if (!UF.merge(x, y, z)) {
            std::cout << 0 << std::endl;
            return 0;
        }
    }

    auto solve = [&](int upper) {
        std::vector<int> low(n, 1), high(n, upper);
        REP (i, n) {
            auto [r, a, b] = UF.from_root(i);
            if (UF.val(r)) {
                int v = UF.val(r).value() * a + b;
                if (v < 1 || upper < v)
                    return mint::raw(0);
                continue;
            }
            // 1 <= ra+b <= upper
            if (a == 1) {
                chmax(low[r], 1 - b);
                chmin(high[r], upper - b);
            } else {
                chmax(low[r], b - upper);
                chmin(high[r], b - 1);
            }
        }
        mint res = 1;
        REP (r, n)
            if (UF.leader(r) == r and !UF.val(r))
                res *= std::max(high[r] - low[r] + 1, 0);
        return res;
    };

    std::cout << solve(k) - solve(k - 1) << std::endl;
}
#line 1 "test/yukicoder/1502.test.cpp"
#define PROBLEM "https://yukicoder.me/problems/no/1502"
#include <bits/stdc++.h>

#define REP(i, n) for (int i = 0; i < (n); i++)

#line 1 "library/datastructure/unionfind/IntegerSumRuleUnionFind.hpp"
// Union-find over integer unknowns x_0, ..., x_{n-1} linked by equations of
// the form x + y = sum. Each unknown is stored relative to its root r as
// x = a * r + b with a in {+1, -1}. A root additionally gets a fixed value
// once the equations pin it down: an odd cycle, or a merge with a component
// whose root value is already fixed.
class IntegerSumRuleUnionFind {
    using ll = long long;
    int n = 0, num = 0; // number of unknowns / number of components
    std::vector<int> sz, parent;
    // potential[x] = {c, d} means x = c * parent[x] + d.
    std::vector<std::pair<int, ll>> potential;
    // value[r] is the determined value of root r, if known.
    std::vector<std::optional<ll>> value;

  public:
    IntegerSumRuleUnionFind() = default;
    IntegerSumRuleUnionFind(int n)
        : n(n), num(n), sz(n, 1), parent(n, 0), potential(n, {1, 0}),
          value(n, std::nullopt) {
        std::iota(parent.begin(), parent.end(), 0);
    }

    // Returns {r, a, b} such that x = a * r + b, where r is x's root.
    // Compresses the path from x to r as a side effect.
    std::tuple<int, int, ll> from_root(int x) {
        assert(0 <= x and x < n);
        if (x == parent[x])
            return {x, 1, 0LL};
        auto [r, a, b] = from_root(parent[x]);
        auto [c, d] = potential[x];
        // x = c * parent + d and parent = a * r + b.
        parent[x] = r;
        potential[x] = {a * c, b * c + d};
        return {r, a * c, b * c + d};
    }

    int leader(int x) { return std::get<0>(from_root(x)); }

    bool same(int x, int y) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        return leader(x) == leader(y);
    }

    // Adds the equation x + y = sum. Returns false, without any logical
    // change, if it contradicts earlier equations; otherwise returns true.
    bool merge(int x, int y, ll sum) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        auto [rx, a, b] = from_root(x);
        auto [ry, c, d] = from_root(y);
        if (rx == ry) {
            // (a + c) * r + b + d = sum
            if (a == -c)
                return b + d == sum; // r cancels out
            // a + c = +-2, so r is an integer only if sum - b - d is even.
            if ((sum - b - d) & 1)
                return false;
            ll r = (sum - b - d) / (a + c);
            if (value[rx] and value[rx].value() != r)
                return false; // contradicts an already fixed root value
            value[rx] = r;
            return true;
        }
        // Union by size: hang the smaller tree (ry) under rx.
        if (sz[rx] < sz[ry]) {
            std::swap(rx, ry);
            std::swap(a, c);
            std::swap(b, d);
        }
        // a * rx + b + c * ry + d == sum; since 1/a == a and 1/c == c:
        //   rx = -a*c * ry + a * (sum - b - d)
        //   ry = -a*c * rx + c * (sum - b - d)
        if (value[ry]) {
            ll k = -c * a * value[ry].value() + (sum - b - d) * a;
            if (value[rx] and value[rx].value() != k)
                return false;
            value[rx] = k;
        }
        sz[rx] += sz[ry];
        parent[ry] = rx;
        potential[ry] = {-a * c, (sum - b - d) * c};
        num--;
        return true;
    }

    // Value of x if its component's root value is fixed, else std::nullopt.
    std::optional<ll> val(int x) {
        auto [r, a, b] = from_root(x);
        if (value[r])
            return value[r].value() * a + b;
        return std::nullopt;
    }

    // x + y if it is uniquely determined, else std::nullopt.
    // Returns std::nullopt when x and y are in different components, or when
    // they share a component but the sum depends on a still-free root value.
    std::optional<ll> sum(int x, int y) {
        assert(0 <= x and x < n and 0 <= y and y < n);
        auto [rx, a, b] = from_root(x);
        auto [ry, c, d] = from_root(y);
        if (rx != ry)
            return std::nullopt;
        // x + y = (a + c) * r + (b + d); the r term cancels when a == -c.
        if (a == -c)
            return b + d;
        // Here a == c, and b and d may legitimately differ
        // (e.g. x + z = 5, z + y = 3 gives y = x - 2).
        if (value[rx])
            return (a + c) * value[rx].value() + b + d;
        return std::nullopt;
    }

    int size(const int x) {
        assert(0 <= x and x < n);
        return sz[leader(x)];
    }

    int count() const { return num; }
};
#line 2 "library/math/ExtraGCD.hpp"
using ll = long long;
// Extended Euclidean algorithm: returns (x, y) such that a*x + b*y equals the
// last nonzero remainder of Euclid's sequence on (a, b), i.e. +-gcd(a, b)
// (it can be negative when the inputs are negative).
std::pair<ll, ll> ext_gcd(ll a, ll b) {
    // Loop invariant, with A and B the original arguments:
    //   A * x0 + B * y0 == a   and   A * x1 + B * y1 == b.
    ll x0 = 1, y0 = 0;
    ll x1 = 0, y1 = 1;
    while (b != 0) {
        const ll q = a / b;
        const ll rem = a % b;
        a = b;
        b = rem;
        const ll nx = x0 - q * x1;
        x0 = x1;
        x1 = nx;
        const ll ny = y0 - q * y1;
        y0 = y1;
        y1 = ny;
    }
    return {x0, y0};
}
#line 3 "library/mod/Modint.hpp"
// Integer modulo MOD, stored as its canonical residue v with 0 <= v < MOD.
// Every constructor except raw() reduces its argument into that range, and
// the arithmetic operators rely on it.
template <typename T, T MOD = 998244353> struct Mint {
    inline static constexpr T mod = MOD;
    T v;
    Mint() : v(0) {}
    // Delegates to the long long overload so negative ints and ints >= MOD
    // are reduced too; operators assume 0 <= v < MOD.
    Mint(signed t) : Mint(static_cast<long long>(t)) {}
    Mint(long long t) {
        v = t % MOD;
        if (v < 0) // % truncates toward zero, so shift negatives up
            v += MOD;
    }

    // Wraps v as-is, without reduction; the caller guarantees 0 <= v < MOD.
    static Mint raw(int v) {
        Mint x;
        x.v = v;
        return x;
    }

    // this^k by binary exponentiation. A negative k raises the inverse
    // instead, which is meaningful only when gcd(v, MOD) == 1.
    Mint pow(long long k) const {
        if (k < 0)
            return inv().pow(-k);
        Mint res(1), base(*this);
        while (k) {
            if (k & 1)
                res *= base;
            base *= base;
            k >>= 1;
        }
        return res;
    }

    static Mint add_identity() { return Mint(0); }
    static Mint mul_identity() { return Mint(1); }

    // Modular inverse via the extended Euclidean algorithm; valid only when
    // gcd(v, MOD) == 1. (For prime MOD, pow(MOD - 2) is an alternative.)
    Mint inv() const { return Mint(ext_gcd(v, mod).first); }

    Mint &operator+=(Mint a) {
        v += a.v;
        if (v >= MOD)
            v -= MOD;
        return *this;
    }
    Mint &operator-=(Mint a) {
        v += MOD - a.v;
        if (v >= MOD)
            v -= MOD;
        return *this;
    }
    Mint &operator*=(Mint a) {
        // Widened to long long; assumes (MOD - 1)^2 fits in long long.
        v = 1LL * v * a.v % MOD;
        return *this;
    }
    Mint &operator/=(Mint a) { return (*this) *= a.inv(); }

    Mint operator+(Mint a) const { return Mint(*this) += a; }
    Mint operator-(Mint a) const { return Mint(*this) -= a; }
    Mint operator*(Mint a) const { return Mint(*this) *= a; }
    Mint operator/(Mint a) const { return Mint(*this) /= a; }
    // int (op) Mint overloads; the int is reduced by Mint(signed).
#define FRIEND(op)                                                             \
    friend Mint operator op(int a, Mint b) { return Mint(a) op b; }
    FRIEND(+);
    FRIEND(-);
    FRIEND(*);
    FRIEND(/);
#undef FRIEND
    Mint operator+() const { return *this; }
    Mint operator-() const { return v ? Mint(MOD - v) : Mint(v); }

    bool operator==(const Mint a) const { return v == a.v; }
    bool operator!=(const Mint a) const { return v != a.v; }

    // C(n, k) modulo MOD as n (n-1) ... (n-k+1) / k!; requires k! to be
    // invertible modulo MOD.
    static Mint comb(long long n, int k) {
        Mint num(1), dom(1);
        for (int i = 0; i < k; i++) {
            num *= Mint(n - i);
            dom *= Mint(i + 1);
        }
        return num / dom;
    }

    friend std::ostream &operator<<(std::ostream &os, const Mint &m) {
        os << m.v;
        return os;
    }
    friend std::istream &operator>>(std::istream &is, Mint &m) {
        is >> m.v;
        m.v %= MOD;
        if (m.v < 0)
            m.v += MOD;
        return is;
    }
};
#line 8 "test/yukicoder/1502.test.cpp"
// Modular integer type used for the answer (modulus 1'000'000'007).
using mint = Mint<long long, 1000'000'007>;

// Lowers `a` to `b` when `b` is smaller; otherwise leaves `a` untouched.
void chmin(int &a, int b) {
    if (b < a)
        a = b;
}
// Raises `a` to `b` when `b` is larger; otherwise leaves `a` untouched.
void chmax(int &a, int b) {
    if (a < b)
        a = b;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m, k;
    std::cin >> n >> m >> k;

    IntegerSumRuleUnionFind UF(n);
    REP (_, m) {
        int x, y, z;
        std::cin >> x >> y >> z;
        x--;
        y--;
        if (!UF.merge(x, y, z)) {
            std::cout << 0 << std::endl;
            return 0;
        }
    }

    auto solve = [&](int upper) {
        std::vector<int> low(n, 1), high(n, upper);
        REP (i, n) {
            auto [r, a, b] = UF.from_root(i);
            if (UF.val(r)) {
                int v = UF.val(r).value() * a + b;
                if (v < 1 || upper < v)
                    return mint::raw(0);
                continue;
            }
            // 1 <= ra+b <= upper
            if (a == 1) {
                chmax(low[r], 1 - b);
                chmin(high[r], upper - b);
            } else {
                chmax(low[r], b - upper);
                chmin(high[r], b - 1);
            }
        }
        mint res = 1;
        REP (r, n)
            if (UF.leader(r) == r and !UF.val(r))
                res *= std::max(high[r] - low[r] + 1, 0);
        return res;
    };

    std::cout << solve(k) - solve(k - 1) << std::endl;
}
Back to top page