Skip to the content.

:x: test/library-checker/Polynomial/Composition.test.cpp

Depends on

Code

#define PROBLEM                                                                \
    "https://judge.yosupo.jp/problem/composition_of_formal_power_series"
#include <bits/stdc++.h>

#include "library/formalpowerseries/Base.hpp"
#include "library/formalpowerseries/functions/composition.hpp"

#include <atcoder/convolution>
#include <atcoder/modint>
using namespace atcoder;
using mint = modint998244353;
// Writes the canonical representative of `a` (as returned by val()).
std::ostream &operator<<(std::ostream &os, mint a) {
    const auto value = a.val();
    return os << value;
}
// Reads a (possibly negative) integer and stores it into `a`, reduced by mint.
// `b` is zero-initialised: if the stream is already in a failed state the
// extraction does not write to it, and reading it would otherwise be UB.
std::istream &operator>>(std::istream &is, mint &a) {
    long long b = 0;
    is >> b;
    a = b;
    return is;
}

using FPS = FormalPowerSeries<mint, 8000>;

// Reads N, then N coefficients of f and N coefficients of g, and prints the
// first N coefficients of f(g(x)) separated by spaces.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    FPS f(n), g(n);
    for (int i = 0; i < n; i++)
        std::cin >> f[i];
    for (int i = 0; i < n; i++)
        std::cin >> g[i];
    f = fps::composition(f, g);
    // The result has its trailing zeros removed and may hold fewer than n
    // coefficients; pad it back with zeros so f[i] below stays in bounds.
    f.resize(n);
    for (int i = 0; i < n; i++)
        std::cout << f[i] << "\n "[i + 1 < n];
}
#line 1 "test/library-checker/Polynomial/Composition.test.cpp"
#define PROBLEM                                                                \
    "https://judge.yosupo.jp/problem/composition_of_formal_power_series"
#include <bits/stdc++.h>

#line 2 "library/util/Valarray.hpp"
#include <ranges>
#line 4 "library/util/Valarray.hpp"

// std::vector with element-wise arithmetic. Operands may differ in length:
// the left operand is first zero-padded to the right operand's length.
template <typename T> struct Valarray : std::vector<T> {
    using std::vector<T>::vector; // inherit every std::vector constructor
    Valarray(const std::vector<T> &v) : std::vector<T>(v.begin(), v.end()) {}

  private:
    // Zero-pads *this to at least other.size(), then sets
    // (*this)[i] = op((*this)[i], other[i]) for every i < other.size().
    // Elements of *this at positions >= other.size() are not touched here.
    template <typename Op>
    Valarray &apply_inplace(const Valarray &other, Op op) {
        if (this->size() < other.size())
            this->resize(other.size(), T(0));

        for (std::size_t i = 0; i < other.size(); i++)
            (*this)[i] = op((*this)[i], other[i]);

        return *this;
    }

  public:
    Valarray &operator+=(const Valarray &other) {
        return apply_inplace(other, std::plus<>());
    }
    Valarray &operator-=(const Valarray &other) {
        return apply_inplace(other, std::minus<>());
    }
    // Element-wise product. Positions past other.size() are multiplied by an
    // implicit zero, so they become zero; this keeps a * b == b * a for
    // operands of any lengths (previously the tail of *this was kept).
    Valarray &operator*=(const Valarray &other) {
        apply_inplace(other, std::multiplies<>());
        for (std::size_t i = other.size(); i < this->size(); i++)
            (*this)[i] = T(0);
        return *this;
    }
    // Element-wise quotient. Positions past other.size() are left unchanged,
    // since dividing by the implicit zero there is undefined.
    Valarray &operator/=(const Valarray &other) {
        return apply_inplace(other, std::divides<>());
    }

    friend Valarray operator+(Valarray a, const Valarray &b) { return a += b; }
    friend Valarray operator-(Valarray a, const Valarray &b) { return a -= b; }
    friend Valarray operator*(Valarray a, const Valarray &b) { return a *= b; }
    friend Valarray operator/(Valarray a, const Valarray &b) { return a /= b; }

    // Element-wise negation.
    Valarray operator-() const {
        Valarray g = *this;
        for (T &a : g)
            a = -a;
        return g;
    }
};
#line 3 "library/formalpowerseries/Base.hpp"

// Formal power series truncated mod x^MX. The coefficient of x^i is stored at
// index i. Most operations drop trailing zeros ("shrink"), so the zero series
// is normally the empty vector and compares equal to FPS(0).
template <typename T, int MX> struct FormalPowerSeries : Valarray<T> {
    using FPS = FormalPowerSeries;
    static constexpr int max_size = MX;
    using Valarray<T>::Valarray;
    using Valarray<T>::size;
    using Valarray<T>::resize;
    using Valarray<T>::at;
    using Valarray<T>::begin;
    using Valarray<T>::end;
    using Valarray<T>::back;
    using Valarray<T>::pop_back;
    using value_type = T;

    // Keeps only the coefficients of x^0 .. x^{n-1} (a negative n keeps none),
    // then drops trailing zeros.
    void strict(int n) {
        if (n < 0)
            n = 0;
        if (static_cast<int>(size()) > n)
            resize(n);
        shrink();
    }
    // Drops trailing zero coefficients.
    void shrink() {
        while (size() and back() == 0)
            pop_back();
    }

    FormalPowerSeries() = default;

    // Takes the coefficients of f, truncated mod x^MX and shrunk.
    FormalPowerSeries(const std::vector<T> &f) : Valarray<T>(f) { strict(MX); }

    // The constant series 1.
    static FPS unit() { return {1}; }
    // The series x.
    static FPS x() { return {0, 1}; }
#pragma region operator
    FPS operator-() const { return FPS(Valarray<T>::operator-()); }

    FPS &operator+=(const FPS &g) {
        Valarray<T>::operator+=(g);
        shrink();
        return *this;
    }
    FPS operator+(const FPS &g) const { return FPS(*this) += g; }

    FPS &operator-=(const FPS &g) {
        Valarray<T>::operator-=(g);
        shrink();
        return *this;
    }
    FPS operator-(const FPS &g) const { return FPS(*this) -= g; }

    // Scalar addition/subtraction act on the constant term. The result is
    // shrunk so a cancelled constant term does not leave a non-canonical {0}.
    FPS &operator+=(const T &a) {
        if (!size())
            resize(1);
        at(0) += a;
        shrink();
        return *this;
    }
    FPS operator+(const T &a) const { return FPS(*this) += a; }
    friend FPS operator+(const T &a, const FPS &f) { return f + a; }

    FPS &operator-=(const T &a) {
        if (!size())
            resize(1);
        at(0) -= a;
        shrink();
        return *this;
    }
    // const, like operator+(const T &), so it also works on const series.
    FPS operator-(const T &a) const { return FPS(*this) -= a; }
    friend FPS operator-(const T &a, const FPS &f) { return a + (-f); }

    // Product mod x^MX (the std::vector constructor truncates and shrinks).
    FPS operator*(const FPS &g) const { return FPS(convolution(*this, g)); }
    FPS &operator*=(const FPS &g) { return (*this) = (*this) * g; }

    FPS &operator*=(const T &a) {
        for (T &c : *this)
            c *= a;
        shrink(); // a == 0 zeroes every coefficient
        return *this;
    }
    FPS operator*(const T &a) const { return FPS(*this) *= a; }
    friend FPS operator*(const T &a, const FPS &f) { return f * a; }

    // Series division via the full-precision inverse of g (requires g(0) != 0).
    FPS operator/(const FPS &g) const { return (*this) * g.inv(); }
    FPS &operator/=(const FPS &g) { return (*this) = (*this) / g; }

    FPS &operator/=(const T &a) { return *this *= a.inv(); }
    // const, like operator*(const T &), so it also works on const series.
    FPS operator/(const T &a) const { return FPS(*this) /= a; }

    // Multiplies by x^d (d >= 0), truncating mod x^MX. The zero series stays
    // empty, and the result is shrunk in case truncation exposes zeros.
    FPS &operator<<=(const int d) {
        assert(d >= 0);
        if (d >= MX)
            return *this = FPS(0);
        if (!size())
            return *this;
        resize(std::min(MX, static_cast<int>(size()) + d));
        for (int i = static_cast<int>(size()) - 1; i >= d; i--)
            at(i) = at(i - d);
        for (int i = 0; i < d; i++)
            at(i) = 0;
        shrink();
        return *this;
    }
    FPS operator<<(const int d) const { return FPS(*this) <<= d; }
    // Divides by x^d (d >= 0), discarding the coefficients of x^0..x^{d-1}.
    FPS &operator>>=(const int d) {
        assert(d >= 0);
        if (d >= static_cast<int>(size()))
            return *this = FPS(0);
        for (int i = d; i < static_cast<int>(size()); i++)
            at(i - d) = at(i);
        strict(static_cast<int>(size()) - d);
        return *this;
    }
    FPS operator>>(const int d) const { return FPS(*this) >>= d; }
#pragma endregion operator

    // The series mod x^n (coefficients of x^0..x^{n-1}); n <= 0 gives zero.
    FPS pre(int n) const {
        if (static_cast<int>(size()) <= n)
            return *this;
        if (n <= 0)
            return FPS();
        return FPS(Valarray<T>(this->begin(), this->begin() + n));
    }

    // Lowest degree with a non-zero coefficient (size() if all are zero).
    int order() const {
        for (int i = 0; i < int(size()); i++) {
            if (at(i) != 0)
                return i;
        }
        return int(size());
    }

    // Multiplicative inverse mod x^SZ by Newton iteration
    // r <- r (2 - f r), doubling the precision each step. Requires f(0) != 0.
    FPS inv(int SZ = MX) const {
        assert(size() and at(0) != 0);
        FPS res = {at(0).inv()};
        for (int n = 1; n < SZ; n <<= 1) {
            res *= (2 - this->pre(n << 1) * res);
            res.strict(n << 1);
        }
        res.strict(SZ);
        return res;
    }

    // Splits *this = f_1 + f_2 x^n: keeps f_1 (degree < n) in *this and
    // returns f_2, both shrunk. If n < 0 or size() <= n, *this is left
    // unchanged and the zero series is returned.
    FPS separate(int n) {
        if (n < 0 or static_cast<int>(size()) <= n)
            return FPS(0);
        FPS f_2(size() - n);
        for (size_t i = n; i < size(); i++)
            f_2[i - n] = at(i);
        f_2.shrink();
        strict(n);
        return f_2;
    }

    // Evaluates the stored polynomial at a.
    T operator()(T a) const {
        T res = 0, b = 1;
        for (size_t i = 0; i < size(); i++, b *= a)
            res += at(i) * b;
        return res;
    }
};
#line 3 "library/formalpowerseries/functions/differential.hpp"

namespace fps {

// Formal derivative: returns sum_{k>=1} k f_k x^{k-1}.
// A series with fewer than two coefficients yields the empty series.
template <typename T, int MX>
FormalPowerSeries<T, MX> differential(FormalPowerSeries<T, MX> f) {
    const std::size_t len = f.size();
    if (len < 2)
        return FormalPowerSeries<T, MX>{};
    for (std::size_t k = 1; k < len; ++k)
        f[k - 1] = k * f[k];
    f.pop_back();
    return f;
}

} // namespace fps
#line 3 "library/formalpowerseries/functions/integral.hpp"

namespace fps {

// Formal antiderivative with zero constant term:
//   integral(f) = sum_i f_i x^{i+1} / (i+1)   (mod x^MX).
// Input and output are shrunk, so the zero series maps to the empty series
// (previously it produced {0}, which is not equal to FPS(0)).
template <typename T, int MX>
FormalPowerSeries<T, MX> integral(FormalPowerSeries<T, MX> f) {
    f.shrink();
    if (f.size() == 0)
        return f; // the antiderivative of 0 is 0
    if (f.size() < static_cast<std::size_t>(MX))
        f.resize(f.size() + 1);
    for (int i = static_cast<int>(f.size()) - 1; i > 0; i--)
        f[i] = f[i - 1] / i;
    f[0] = 0;
    f.shrink(); // truncation at MX may leave a zero top coefficient
    return f;
}

} // namespace fps
#line 5 "library/formalpowerseries/functions/log.hpp"

namespace fps {

// Logarithm of a series whose constant term is 1:
//   log f = integral(f' * f^{-1}), with f^{-1} taken to full precision MX.
template <typename T, int MX>
FormalPowerSeries<T, MX> log(const FormalPowerSeries<T, MX> &f) {
    assert(f.size() > 0 && f[0] == 1);
    const FormalPowerSeries<T, MX> df = differential(f);
    return integral(df * f.inv());
}

} // namespace fps
#line 4 "library/formalpowerseries/functions/exp.hpp"

namespace fps {

// exp(f) mod x^MX for a series with f(0) == 0; the empty series gives 1.
// Newton iteration g <- g (1 + f - log g), doubling the precision each round.
template <typename T, int MX>
FormalPowerSeries<T, MX> exp(const FormalPowerSeries<T, MX>& f) {
    using FPS = FormalPowerSeries<T, MX>;
    if (f.size() == 0)
        return {1};
    assert(f.size() > 0 && f[0] == 0);
    FPS g = {1};
    for (int prec = 1; prec < MX;) {
        const int next = prec << 1;
        const FPS correction = f.pre(next) + 1 - log(g).pre(next);
        g = g * correction;
        g.strict(next);
        prec = next;
    }
    return g;
}

// Coefficients of e^{n x} mod x^MX: the i-th coefficient is n^i / i!.
template <typename T, int MX>
FormalPowerSeries<T, MX> exp(const T& n) {
    if (n == 0)
        return {1};
    FormalPowerSeries<T, MX> series(MX);
    T term = 1;
    series[0] = term;
    for (int i = 1; i < MX; i++) {
        term = term * n / i;
        series[i] = term;
    }
    return series;
}

} // namespace fps
#line 5 "library/formalpowerseries/functions/pow.hpp"

namespace fps {

// f^n mod x^MX for n >= 0.
// Writing f = f0 x^d h(x) with h(0) = 1 gives f^n = f0^n x^{dn} exp(n log h).
template <typename T, int MX>
FormalPowerSeries<T, MX> pow(FormalPowerSeries<T, MX> f, long long n) {
    using FPS = FormalPowerSeries<T, MX>;

    assert(n >= 0);
    f.shrink();

    // Trivial exponents.
    if (n == 0)
        return FPS::unit();
    if (n == 1)
        return f;

    // The zero series and constants need no exp/log.
    switch (f.size()) {
    case 0:
        return f;
    case 1:
        return FPS{f[0].pow(n)};
    default:
        break;
    }

    // If x^{dn} already vanishes mod x^MX, so does f^n.
    const int shift = f.order();
    if (shift > 0 && static_cast<unsigned __int128>(shift) * n >= MX)
        return FPS(0);
    f >>= shift;

    // Normalise the constant term to 1 so that log is defined.
    T lead = f[0];
    if (lead != 1)
        f /= lead;

    FPS powered = lead.pow(n) * fps::exp(n * fps::log(f));
    return powered << (shift * n);
}

} // namespace fps
#line 5 "library/formalpowerseries/functions/composition.hpp"

namespace fps {

// Returns p(x + c), treating p as the polynomial given by its stored
// coefficients. Uses one convolution:
//   [x^j] p(x + c) = (1 / j!) * sum_{k >= 0} p_{j+k} (j+k)! c^k / k!.
// Requires 1, 2, ..., p.size()-1 to be invertible in T. Helper for
// composition().
template <typename T, int MX>
FormalPowerSeries<T, MX> composition_taylor_shift(const FormalPowerSeries<T, MX> &p, const T &c) {
    using FPS = FormalPowerSeries<T, MX>;
    const int n = static_cast<int>(p.size());
    if (n == 0 or c == 0)
        return p;
    std::vector<T> fact(n, T(1));
    for (int i = 1; i < n; i++)
        fact[i] = fact[i - 1] * T(i);
    // rev_scaled[i] = p_{n-1-i} (n-1-i)!,  c_pow[k] = c^k / k!
    FPS rev_scaled(n), c_pow(n);
    T ck = 1;
    for (int i = 0; i < n; i++) {
        rev_scaled[i] = p[n - 1 - i] * fact[n - 1 - i];
        c_pow[i] = ck / fact[i];
        ck *= c;
    }
    // prod[n-1-j] = sum_k p_{j+k} (j+k)! c^k / k!
    const FPS prod = rev_scaled * c_pow;
    FPS res(n);
    for (int j = 0; j < n; j++) {
        const int idx = n - 1 - j;
        if (idx < static_cast<int>(prod.size())) // prod drops trailing zeros
            res[j] = prod[idx] / fact[j];
    }
    res.shrink();
    return res;
}

// Computes f(g(x)) mod x^MX, treating f as the polynomial given by its stored
// coefficients (so g(0) may be non-zero).
//
// Method (Brent-Kung style):
//   1. If c = g(0) != 0, replace f by F(x) = f(x + c) and g by g - c.
//   2. Split g = g1 + x^m g2 with deg g1 < m.
//   3. Evaluate F(g1) by divide and conquer over the coefficients of F,
//      using g1^(2^j) computed once by repeated squaring.
//   4. Add the Taylor terms F^{(k)}(g1) (x^m g2)^k / k!, where
//      F^{(k+1)}(g1) = (F^{(k)}(g1))' / g1'.
template <typename T, int MX>
FormalPowerSeries<T, MX> composition(const FormalPowerSeries<T, MX>& f, FormalPowerSeries<T, MX> g) {
    using FPS = FormalPowerSeries<T, MX>;
    g.shrink(); // canonical form: the emptiness checks below rely on it
    switch (f.size()) {
        case 0:
            return f;
        case 1:
            return {f[0]};
        case 2:
            return f[0] + f[1] * g;
        default:
            break;
    }

    // Step 1: move the constant term of g into f.
    FPS F = f;
    if (g.size() and g[0] != 0) {
        const T c = g[0];
        F = composition_taylor_shift(f, c);
        g[0] = 0;
        g.shrink();
    }
    const int n_f = static_cast<int>(F.size());
    if (n_f == 0)
        return F;

    // Step 2: block size m ~ sqrt(MX / 20), at least 1.
    const int m = std::max(1, static_cast<int>(std::sqrt(static_cast<double>(MX / 20))));
    FPS g1 = g;
    const FPS g2 = g1.separate(m);

    if (g1.size() == 0) {
        // g = x^m g2, so F(g) = sum_i F_i x^{im} g2^i directly.
        FPS res, g2pow = FPS::unit();
        for (int i = 0; i < n_f and i * m < MX; i++) {
            if (i > 0) {
                g2pow *= g2;
                g2pow.strict(MX - i * m);
                if (g2pow.size() == 0)
                    break;
            }
            res += (F[i] * g2pow) << (i * m);
        }
        return res;
    }

    // Step 3: g1_pow2[j] = g1^(2^j) mod x^MX.
    int depth = 0;
    while ((1 << depth) < n_f)
        depth++;
    std::vector<FPS> g1_pow2(depth);
    for (int j = 0; j < depth; j++)
        g1_pow2[j] = j == 0 ? g1 : g1_pow2[j - 1] * g1_pow2[j - 1];
    // rec(l, j) = sum_{t < 2^j} F_{l+t} g1^t, with F_k = 0 for k >= n_f.
    auto rec = [&](auto &&self, int l, int j) -> FPS {
        if (l >= n_f)
            return FPS();
        if (j == 0) {
            FPS leaf = {F[l]};
            leaf.shrink();
            return leaf;
        }
        FPS lo = self(self, l, j - 1);
        const FPS hi = self(self, l + (1 << (j - 1)), j - 1);
        if (hi.size())
            lo += hi * g1_pow2[j - 1];
        return lo;
    };
    FPS res = rec(rec, 0, depth);
    if (g2.size() == 0)
        return res; // g == g1

    // Step 4: Taylor expansion around g1.
    // g1(0) == 0 and g1 != 0, so 1 <= ord < m and g1' = x^{ord-1} u with
    // u(0) != 0. Every (F^{(k)}(g1))' is a multiple of g1', so both are
    // divided by x^{ord-1} (not x^ord) before dividing by u.
    const int ord = g1.order();
    FPS g1_diff = fps::differential(g1);
    g1_diff >>= ord - 1;
    const FPS g1_diff_inv = g1_diff.inv();

    FPS dfg = res; // F^{(i)}(g1)
    FPS g2pow = FPS::unit();
    T factinv = 1;
    for (int i = 1; i * m < MX; i++) {
        const int prec = MX - m * i; // only x^0..x^{prec-1} of this term matter
        dfg = (fps::differential(dfg) >> (ord - 1)) * g1_diff_inv.pre(prec);
        dfg.strict(prec);
        if (dfg.size() == 0)
            break; // all higher derivatives vanish as well
        g2pow *= g2;
        g2pow.strict(prec);
        if (g2pow.size() == 0)
            break; // higher powers of g2 also vanish mod x^prec
        factinv /= T(i);
        res += (factinv * (dfg * g2pow)) << (m * i);
    }
    return res;
}

} // namespace fps
#line 7 "test/library-checker/Polynomial/Composition.test.cpp"

#include <atcoder/convolution>
#include <atcoder/modint>
using namespace atcoder;
using mint = modint998244353;
// Writes the canonical representative of `a` (as returned by val()).
std::ostream &operator<<(std::ostream &os, mint a) {
    const auto value = a.val();
    return os << value;
}
// Reads a (possibly negative) integer and stores it into `a`, reduced by mint.
// `b` is zero-initialised: if the stream is already in a failed state the
// extraction does not write to it, and reading it would otherwise be UB.
std::istream &operator>>(std::istream &is, mint &a) {
    long long b = 0;
    is >> b;
    a = b;
    return is;
}

using FPS = FormalPowerSeries<mint, 8000>;

// Reads N, then N coefficients of f and N coefficients of g, and prints the
// first N coefficients of f(g(x)) separated by spaces.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    FPS f(n), g(n);
    for (int i = 0; i < n; i++)
        std::cin >> f[i];
    for (int i = 0; i < n; i++)
        std::cin >> g[i];
    f = fps::composition(f, g);
    // The result has its trailing zeros removed and may hold fewer than n
    // coefficients; pad it back with zeros so f[i] below stays in bounds.
    f.resize(n);
    for (int i = 0; i < n; i++)
        std::cout << f[i] << "\n "[i + 1 < n];
}
Back to top page