Skip to the content.

⚠️ library/formalpowerseries/Sqrt.hpp

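Depends on: library/formalpowerseries/Base.hpp and library/math/ModularSqrt.hpp (which in turn pull in library/util/Valarray.hpp, library/mod/Modint.hpp and library/math/ExtraGCD.hpp)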

Code

#pragma once
#include "library/formalpowerseries/Base.hpp"
#include "library/math/ModularSqrt.hpp"
#include <optional>

// Computes a square root g of the formal power series f, i.e. g * g == f
// modulo x^{FPS::max_size}.
//
// f == 0 yields the zero series. Otherwise write f = x^d * h with h[0] != 0:
// a root exists only if d is even and h[0] has a square root (mod_sqrt), and
// the result is x^{d/2} * sqrt(h), truncated to max_size coefficients.
// Returns std::nullopt if the square root does not exist.
template <typename FPS> std::optional<FPS> sqrt(FPS f) {
    using T = typename FPS::value_type;
    constexpr int MX = FPS::max_size;

    f.shrink();
    // d = lowest degree with a nonzero coefficient. After shrink() the last
    // coefficient is nonzero, so d == size() only for the zero series.
    int d = 0;
    while (d < int(f.size()) && f[d] == 0) {
        d++;
    }
    if (d == int(f.size())) {
        return FPS(0);
    }
    if (d % 2 != 0) {
        return std::nullopt;
    }

    f >>= d;

    std::optional<T> s0 = mod_sqrt(f[0]);
    if (!s0) {
        return std::nullopt;
    }

    // Only the first MX - d/2 coefficients of sqrt(h) survive the final
    // multiplication by x^{d/2}, so there is no need to compute more.
    const int need = MX - d / 2;
    FPS res(1, *s0);
    // Newton iteration g <- (g + h / g) / 2 doubles the number of correct
    // coefficients each step.
    for (int n = 1; n < need;) {
        n <<= 1;
        res = (res + f.pre(n) * res.inv(n)) / 2;
        res.strict(n); // keep only the n coefficients that are now correct
    }
    res.strict(need);
    res <<= (d / 2);
    return res;
}
#line 1 "library/util/Valarray.hpp"
#include <functional>
#include <ranges>
#include <vector>

// std::vector<T> with element-wise arithmetic.
//
// A binary operation zero-pads the LEFT operand up to the right operand's
// length and then combines the first other.size() elements pairwise; elements
// of the left operand beyond other.size() are left untouched.
// NOTE(review): for * and / this makes the result order-dependent, e.g.
// {1,2,3} * {10,20} == {10,40,3} but {10,20} * {1,2,3} == {10,40,0}.
template <typename T> struct Valarray : std::vector<T> {
    using std::vector<T>::vector; // inherit std::vector's constructors
    Valarray(const std::vector<T> &v) : std::vector<T>(v.begin(), v.end()) {}

  private:
    // (*this)[i] = op((*this)[i], other[i]) for every i < other.size(),
    // after zero-padding *this if it is shorter than other.
    template <typename Op>
    Valarray &apply_inplace(const Valarray &other, Op op) {
        if (this->size() < other.size())
            this->resize(other.size(), T(0));

        // Plain index loop instead of std::views::zip (C++23-only). After the
        // resize above, zip's length min(size(), other.size()) is other.size().
        for (std::size_t i = 0; i < other.size(); ++i)
            (*this)[i] = op((*this)[i], other[i]);

        return *this;
    }

  public:
    Valarray &operator+=(const Valarray &other) {
        return apply_inplace(other, std::plus<>());
    }
    Valarray &operator-=(const Valarray &other) {
        return apply_inplace(other, std::minus<>());
    }
    Valarray &operator*=(const Valarray &other) {
        return apply_inplace(other, std::multiplies<>());
    }
    Valarray &operator/=(const Valarray &other) {
        return apply_inplace(other, std::divides<>());
    }

    // The left operand is taken by value and returned by name, so it is moved
    // into the result; returning the reference expression `a += b` would copy.
    friend Valarray operator+(Valarray a, const Valarray &b) {
        a += b;
        return a;
    }
    friend Valarray operator-(Valarray a, const Valarray &b) {
        a -= b;
        return a;
    }
    friend Valarray operator*(Valarray a, const Valarray &b) {
        a *= b;
        return a;
    }
    friend Valarray operator/(Valarray a, const Valarray &b) {
        a /= b;
        return a;
    }

    // Element-wise negation.
    Valarray operator-() const {
        Valarray g = *this;
        for (T &a : g)
            a = -a;
        return g;
    }
};
#line 3 "library/formalpowerseries/Base.hpp"

// Formal power series with coefficients in T, stored as a coefficient vector
// (index i = coefficient of x^i) and kept modulo x^MX. Operations keep the
// vector free of trailing zero coefficients, so the zero series is empty.
template <typename T, int MX> struct FormalPowerSeries : Valarray<T> {
    using FPS = FormalPowerSeries;
    static constexpr int max_size = MX;
    using Valarray<T>::Valarray;
    // Members of a dependent base are not found by unqualified lookup inside
    // a template, so the ones used below are brought in explicitly.
    using Valarray<T>::size;
    using Valarray<T>::resize;
    using Valarray<T>::at;
    using Valarray<T>::begin;
    using Valarray<T>::end;
    using Valarray<T>::back;
    using Valarray<T>::pop_back;
    using value_type = T;

    // Truncates to the first n coefficients (reduces mod x^n), then removes
    // trailing zero coefficients.
    void strict(int n) {
        if (size() > n)
            resize(n);
        shrink();
    }
    // Removes trailing zero coefficients.
    void shrink() {
        while (size() and back() == 0)
            pop_back();
    }

    FormalPowerSeries() = default;

    // Builds a series from its coefficient vector, truncated to MX terms.
    // strict() also drops trailing zeros.
    FormalPowerSeries(const std::vector<T> &f) : Valarray<T>(f) { strict(MX); }

    static FPS unit() { return {1}; }
    static FPS x() { return {0, 1}; }
#pragma region operator
    FPS operator-() const { return FPS(Valarray<T>::operator-()); }

    FPS &operator+=(const FPS &g) {
        Valarray<T>::operator+=(g);
        shrink();
        return *this;
    }
    FPS operator+(const FPS &g) const { return FPS(*this) += g; }

    FPS &operator-=(const FPS &g) {
        Valarray<T>::operator-=(g);
        shrink();
        return *this;
    }
    FPS operator-(const FPS &g) const { return FPS(*this) -= g; }

    // Scalar ops act on the constant term. They shrink() like the series ops
    // above, so e.g. f + c with f == -c becomes the empty (zero) series.
    FPS &operator+=(const T &a) {
        if (!size())
            resize(1);
        at(0) += a;
        shrink();
        return *this;
    }
    FPS operator+(const T &a) const { return FPS(*this) += a; }
    friend FPS operator+(const T &a, const FPS &f) { return f + a; }

    FPS &operator-=(const T &a) {
        if (!size())
            resize(1);
        at(0) -= a;
        shrink();
        return *this;
    }
    // const-qualified so it also works on const series.
    FPS operator-(const T &a) const { return FPS(*this) -= a; }
    friend FPS operator-(const T &a, const FPS &f) { return a + (-f); }

    FPS operator*(const FPS &g) const { return FPS(convolution(*this, g)); }
    FPS &operator*=(const FPS &g) { return (*this) = (*this) * g; }

    FPS &operator*=(const T &a) {
        for (size_t i = 0; i < size(); i++)
            at(i) *= a;
        shrink(); // a may be zero
        return *this;
    }
    FPS operator*(const T &a) const { return FPS(*this) *= a; }
    friend FPS operator*(const T &a, const FPS &f) { return f * a; }

    FPS operator/(const FPS &g) const { return (*this) * g.inv(); }
    FPS &operator/=(const FPS &g) { return (*this) = (*this) / g; }

    FPS &operator/=(const T &a) { return *this *= a.inv(); }
    // const-qualified so it also works on const series.
    FPS operator/(const T &a) const { return FPS(*this) /= a; }

    // Multiplies by x^d (mod x^MX). A negative d shifts right instead.
    FPS &operator<<=(const int d) {
        if (d < 0)
            return *this >>= -d;
        if (d >= MX)
            return *this = FPS(0);
        resize(std::min(MX, int(size()) + d));
        for (int i = int(size()) - 1 - d; i >= 0; i--)
            at(i + d) = at(i);
        for (int i = d - 1; i >= 0; i--)
            at(i) = 0;
        // Truncation at MX, or shifting the zero series, can leave trailing
        // zeros behind.
        shrink();
        return *this;
    }
    FPS operator<<(const int d) const { return FPS(*this) <<= d; }
    // Divides by x^d, discarding the coefficients of x^0 .. x^{d-1}.
    // A negative d shifts left instead.
    FPS &operator>>=(const int d) {
        if (d < 0)
            return *this <<= -d;
        if (d >= int(size()))
            return *this = FPS(0);
        for (int i = d; i < int(size()); i++)
            at(i - d) = at(i);
        strict(int(size()) - d);
        return *this;
    }
    FPS operator>>(const int d) const { return FPS(*this) >>= d; }
#pragma endregion operator

    // Returns the first n coefficients (i.e. *this mod x^n).
    FPS pre(int n) const {
        if (size() <= n)
            return *this;
        return FPS(Valarray<T>(this->begin(), this->begin() + n));
    }

    // Returns the smallest degree with a nonzero coefficient, or size() if
    // every coefficient is zero.
    int order() const {
        for (int i = 0; i < int(size()); i++) {
            if (at(i) != 0)
                return i;
        }
        return int(size());
    }

    // Multiplicative inverse modulo x^SZ via Newton iteration
    // res <- res * (2 - f * res), doubling the precision each step.
    // Requires a nonzero constant term (asserted).
    FPS inv(int SZ = MX) const {
        assert(size() and at(0) != 0);
        FPS res = {at(0).inv()};
        for (int n = 1; n < SZ; n <<= 1) {
            res *= (2 - this->pre(n << 1) * res);
            res.strict(n << 1);
        }
        res.strict(SZ);
        return res;
    }

    // Splits *this = f_1 + f_2 * x^n: keeps f_1 (the terms below x^n) in
    // *this and returns f_2.
    FPS separate(int n) {
        if (size() <= n)
            return FPS(0);
        FPS f_2(size() - n);
        for (size_t i = n; i < size(); i++)
            f_2[i - n] = at(i);
        strict(n);
        return f_2;
    }

    // Evaluates the series (as a polynomial) at a.
    T operator()(T a) const {
        T res = 0, b = 1;
        for (size_t i = 0; i < size(); i++, b *= a)
            res += at(i) * b;
        return res;
    }
};
#line 2 "library/math/ExtraGCD.hpp"
using ll = long long;
// Extended Euclidean algorithm.
// Returns (x, y) with a * x + b * y == g, where g is the last nonzero
// remainder of the Euclidean sequence built with C++'s truncating `/` and `%`
// (so g may be negative for negative inputs). ext_gcd(a, 0) == (1, 0).
std::pair<ll, ll> ext_gcd(ll a, ll b) {
    // Invariant with the original inputs a0, b0:
    //   a == a0 * cur_x + b0 * cur_y  and  b == a0 * nxt_x + b0 * nxt_y.
    ll cur_x = 1, cur_y = 0;
    ll nxt_x = 0, nxt_y = 1;
    while (b != 0) {
        const ll q = a / b;

        const ll rem = a % b;
        a = b;
        b = rem;

        const ll new_x = cur_x - q * nxt_x;
        cur_x = nxt_x;
        nxt_x = new_x;

        const ll new_y = cur_y - q * nxt_y;
        cur_y = nxt_y;
        nxt_y = new_y;
    }
    return {cur_x, cur_y};
}
#line 3 "library/mod/Modint.hpp"
// Integer modulo MOD. The representative v is kept in [0, MOD).
template <typename T, T MOD = 998244353> struct Mint {
    inline static constexpr T mod = MOD;
    T v;
    Mint() : v(0) {}
    // Reduces into [0, MOD). Storing the int verbatim would let Mint(-1) or
    // Mint(x) with x >= MOD escape the range that ==, += and -= rely on.
    Mint(signed t) : Mint(static_cast<long long>(t)) {}
    Mint(long long t) {
        v = t % MOD;
        if (v < 0)
            v += MOD;
    }

    // Wraps v without reduction; the caller must guarantee 0 <= v < MOD.
    static Mint raw(int v) {
        Mint x;
        x.v = v;
        return x;
    }

    // Binary exponentiation. A negative k means inv() raised to -k.
    Mint pow(long long k) const {
        Mint res(1);
        Mint tmp = k < 0 ? inv() : *this;
        // Iterate on the unsigned magnitude: right-shifting a negative k
        // never reaches 0, and -k would overflow for LLONG_MIN.
        unsigned long long e = k < 0 ? 0ULL - static_cast<unsigned long long>(k)
                                     : static_cast<unsigned long long>(k);
        while (e) {
            if (e & 1)
                res *= tmp;
            tmp *= tmp;
            e >>= 1;
        }
        return res;
    }

    static Mint add_identity() { return Mint(0); }
    static Mint mul_identity() { return Mint(1); }

    // Inverse via the extended Euclidean algorithm; only meaningful when
    // gcd(v, MOD) == 1 (e.g. MOD prime and v != 0).
    // Mint inv()const{return pow(MOD-2);}
    Mint inv() const { return Mint(ext_gcd(v, mod).first); }

    Mint &operator+=(Mint a) {
        v += a.v;
        if (v >= MOD)
            v -= MOD;
        return *this;
    }
    Mint &operator-=(Mint a) {
        v += MOD - a.v;
        if (v >= MOD)
            v -= MOD;
        return *this;
    }
    // Assumes (MOD - 1)^2 fits in long long.
    Mint &operator*=(Mint a) {
        v = 1LL * v * a.v % MOD;
        return *this;
    }
    Mint &operator/=(Mint a) { return (*this) *= a.inv(); }

    Mint operator+(Mint a) const { return Mint(v) += a; }
    Mint operator-(Mint a) const { return Mint(v) -= a; }
    Mint operator*(Mint a) const { return Mint(v) *= a; }
    Mint operator/(Mint a) const { return Mint(v) /= a; }
#define FRIEND(op)                                                             \
    friend Mint operator op(int a, Mint b) { return Mint(a) op b; }
    FRIEND(+);
    FRIEND(-);
    FRIEND(*);
    FRIEND(/);
#undef FRIEND
    Mint operator+() const { return *this; }
    Mint operator-() const { return v ? Mint(MOD - v) : Mint(v); }

    bool operator==(const Mint a) const { return v == a.v; }
    bool operator!=(const Mint a) const { return v != a.v; }

    // Binomial coefficient C(n, k) = n (n-1) ... (n-k+1) / k!; requires k! to
    // be invertible modulo MOD.
    static Mint comb(long long n, int k) {
        Mint num(1), dom(1);
        for (int i = 0; i < k; i++) {
            num *= Mint(n - i);
            dom *= Mint(i + 1);
        }
        return num / dom;
    }

    friend std::ostream &operator<<(std::ostream &os, const Mint &m) {
        os << m.v;
        return os;
    }
    friend std::istream &operator>>(std::istream &is, Mint &m) {
        is >> m.v;
        m.v %= MOD;
        if (m.v < 0)
            m.v += MOD;
        return is;
    }
};
#line 3 "library/math/ModularSqrt.hpp"
#include <optional>
#include <random>
#include <chrono>

// Euler's criterion: for an odd prime MOD, a nonzero a is a square modulo MOD
// exactly when a^((MOD-1)/2) == 1. Zero counts as a residue (0 == 0 * 0).
template <typename T, T MOD>
bool is_quadratic_residue(Mint<T, MOD> a) {
    return a == 0 || a.pow((MOD - 1) / 2) == 1;
}

// Returns some x with x * x == a (mod MOD), or std::nullopt if a is not a
// quadratic residue. Assumes MOD is prime (Euler's criterion and
// Tonelli-Shanks are only valid modulo a prime).
template <typename T, T MOD>
std::optional<Mint<T, MOD>> mod_sqrt(Mint<T, MOD> a) {
    using M = Mint<T, MOD>;
    if (a == 0) return M(0);
    if (MOD == 2) return a;
    if (!is_quadratic_residue(a)) return std::nullopt;

    if (MOD % 4 == 3) {
        return a.pow((MOD + 1) / 4);
    }

    // Tonelli-Shanks: write MOD - 1 = q * 2^s with q odd.
    long long s = 0, q = MOD - 1;
    while (q % 2 == 0) {
        q /= 2;
        s++;
    }

    // Find a quadratic non-residue z by a deterministic scan from 2. For an
    // odd prime half of [1, MOD) are non-residues, so this stops quickly, and
    // the returned root is reproducible across runs. (Assigning the old
    // `rng() % MOD` to a Mint was also an ambiguous conversion: unsigned long
    // converts equally well via Mint(signed) and Mint(long long).)
    M z(2);
    while (is_quadratic_residue(z))
        z += 1;

    long long m = s;
    M c = z.pow(q);
    M t = a.pow(q);
    M r = a.pow((q + 1) / 2);

    while (t != 1) {
        if (t == 0) return M(0);
        // Smallest i with t^(2^i) == 1.
        long long i = 0;
        M temp = t;
        while (temp != 1) {
            temp *= temp;
            i++;
            if (i == m) return std::nullopt; // Should not happen for quadratic residues
        }

        M b = c.pow(1LL << (m - i - 1));
        m = i;
        c = b * b;
        t *= c;
        r *= b;
    }
    return r;
}
#line 5 "library/formalpowerseries/Sqrt.hpp"

// Computes a square root g of the formal power series f, i.e. g * g == f
// modulo x^{FPS::max_size}.
//
// f == 0 yields the zero series. Otherwise write f = x^d * h with h[0] != 0:
// a root exists only if d is even and h[0] has a square root (mod_sqrt), and
// the result is x^{d/2} * sqrt(h), truncated to max_size coefficients.
// Returns std::nullopt if the square root does not exist.
template <typename FPS> std::optional<FPS> sqrt(FPS f) {
    using T = typename FPS::value_type;
    constexpr int MX = FPS::max_size;

    f.shrink();
    // d = lowest degree with a nonzero coefficient. After shrink() the last
    // coefficient is nonzero, so d == size() only for the zero series.
    int d = 0;
    while (d < int(f.size()) && f[d] == 0) {
        d++;
    }
    if (d == int(f.size())) {
        return FPS(0);
    }
    if (d % 2 != 0) {
        return std::nullopt;
    }

    f >>= d;

    std::optional<T> s0 = mod_sqrt(f[0]);
    if (!s0) {
        return std::nullopt;
    }

    // Only the first MX - d/2 coefficients of sqrt(h) survive the final
    // multiplication by x^{d/2}, so there is no need to compute more.
    const int need = MX - d / 2;
    FPS res(1, *s0);
    // Newton iteration g <- (g + h / g) / 2 doubles the number of correct
    // coefficients each step.
    for (int n = 1; n < need;) {
        n <<= 1;
        res = (res + f.pre(n) * res.inv(n)) / 2;
        res.strict(n); // keep only the n coefficients that are now correct
    }
    res.strict(need);
    res <<= (d / 2);
    return res;
}
Back to top page