ダブリングの抽象化

競プロでダブリングと呼ばれるアルゴリズムの抽象化を考えてみました。ダブリングのアルゴリズム自体の解説はしません。

扱う問題

有限の状態集合  T と、半群  (S, \cdot) を考えます。2つの写像  t:T\to T v: T\to S とが与えられたとき、任意の  s\in T k\in\mathbb{N} に対して、 v(s)\cdot v(t(s))\cdot v(t^2(s))\cdot\cdots\cdot v(t^{k}(s)) を求めるという問題を扱います。
簡潔にいうと、状態遷移しながら状態がもつ値を拾っていってその積を求めるということです。半群というのは集合の上に二項演算が定義されていて、その演算が結合律を満たすものをいいます。これは、結合律を満たさないとダブリングができないためです。
また、この問題において  |T|=1 の場合が繰り返し二乗法であると見做すこともできます。

実装

template <typename S, int POW_MAX>
struct Doubling {
   private:
    int n;
    vector<int> to;
    vector<S> value;
    function<S(S, S)> op;
    vector<vector<int>> next;
    vector<vector<S>> dp;

    void init() {
        next.resize(POW_MAX, vector<int>(n));
        dp.resize(POW_MAX, vector<S>(n));
        for (int start = 0; start < n; ++start) {
            next[0][start] = to[start];
            dp[0][start] = value[to[start]];
        }
        for (int pow = 1; pow < POW_MAX; ++pow) {
            for (int start = 0; start < n; ++start) {
                next[pow][start] = next[pow - 1][next[pow - 1][start]];
                dp[pow][start] =
                    op(dp[pow - 1][start], dp[pow - 1][next[pow - 1][start]]);
            }
        }
    }

   public:
    Doubling(int n, const vector<int>& to, const vector<S>& value,
             function<S(S, S)> op)
        : n(n), to(to), value(value), op(op) {
        init();
    }
    S prod(int start, long long step) const {
        S prod{value[start]};
        int now{start};
        for (int pow = 0; pow < POW_MAX; ++pow) {
            if (step & (1LL << pow)) {
                prod = op(prod, dp[pow][now]);
                now = next[pow][now];
            }
        }
        return prod;
    }
};

POW_MAXというのはprodに投げるstepの最大値を2冪で指定します。例えば  10^{18} まで取りうるならば60とすればよいです。ダブリングの配列の添字を逆にしたい気持ちになったのですが、添字の順番を逆にすると激遅になります。キャッシュを意識するのは大事ですね。


抽象化パワーのお手並み拝見といきましょう

ABC167 D - Teleporter

ダブリングが使える基本問題ですね。 今回の問題設定に帰着させるには、

  • 状態集合は  [N]=\{1,2,\dots,N\}
  • 半群の台集合は  [N]、演算は  a\cdot b = b
  •  t:[N]\to[N] t(i) = A_i
  •  v:[N]\to[N] v(i) = i

とすればよいです。この演算が結合律 (a\cdot b)\cdot c = a\cdot (b\cdot c) を満たすのは簡単に確認できます。この問題では状態の遷移先だけ考えればいいのですが、冗長になってしまいました。

int main() {
    int N;
    long long K;
    cin >> N >> K;
    vector<int> to(N), value(N);
    for (auto& A : to) {
        cin >> A;
        --A;
    }
    iota(begin(value), end(value), 1);
    auto op = [](int a, int b) { return b; };
    Doubling<int, 60> doubling(N, to, value, op);
    cout << doubling.prod(0, K) << '\n';
}

ABC179 E - Sequence Sum

これは今回の問題設定そのまんまという感じです。 [M]_0=\{0,1,2,\dots,M-1\} と書くことにします。

  • 状態集合は  [M]_0
  • 半群の台集合は  \mathbb{N}、演算は  a\cdot b = a + b \mathbb{N} における標準的な和)
  •  t:[M]_0\to[M]_0 t(i) = i^2 \mathrm{mod} M
  •  v:[M]_0\to \mathbb{N} v(i) = i

とすればよいです。

int main() {
    long long n;
    int x, m;
    cin >> n >> x >> m;
    vector<int> to(m);
    for (long long from = 0; from < m; ++from) {
        to[from] = from * from % m;
    }
    vector<long long> value(m);
    iota(begin(value), end(value), 0LL);
    auto op = [](long long a, long long b) { return a + b; };
    Doubling<long long, 34> doubling(m, to, value, op);
    cout << doubling.prod(x, n - 1) << '\n';
}

ABC175 D - Moving Piece

実装でバグり散らかしそうな問題です。今回の問題設定に帰着させるには、

  • 状態集合は  [N]=\{1,2,\dots,N\}
  • 半群の台集合は  \mathbb{Z}^2、演算は  (a_1,a_2)\cdot(b_1,b_2) = (a_1+b_1,\max(a_2,a_1+b_2))
  •  t:[N]\to[N] t(i)=P_i
  •  v:[N]\to\mathbb{Z}^2 v(i) = (C_i,C_i)

とすればよいです。 (a_1,a_2)\in\mathbb{Z}^2の気持ちとしては、 a_1 がスコアの累積和、 a_2がスコアの累積和の累積最大値を表しています。演算が結合律を満たすことを確かめましょう。

 \begin{align}
( (a_1,a_2)\cdot(b_1,b_2))\cdot(c_1,c_2) &= (a_1+b_1,\max(a_2,a_1+b_2))\cdot(c_1,c_2) \\
&= (a_1+b_1+c_1,\max(\max(a_2,a_1+b_2),a_1+b_1+c_2)) \\
&= (a_1+b_1+c_1,\max(a_2,a_1+b_2,a_1+b_1+c_2))
\end{align}
 \begin{align}

(a_1,a_2)\cdot( (b_1,b_2)\cdot(c_1,c_2)) &= (a_1,a_2)\cdot(b_1+c_1,\max(b_2,b_1+c_2)) \\
&= (a_1+b_1+c_1,\max(a_2,a_1+\max(b_2,b_1+c_2)) \\
&= (a_1+b_1+c_1,\max(a_2,\max(a_1+b_2,a_1+b_1+c_2))) \\
&= (a_1+b_1+c_1,\max(a_2,a_1+b_2,a_1+b_1+c_2))
\end{align}

int main() {
    int N, K;
    cin >> N >> K;
    vector<int> to(N);
    for (auto& P : to) {
        cin >> P;
        --P;
    }
    using S = pair<long long, long long>;
    vector<S> value(N);
    for (auto& val : value) {
        long long C;
        cin >> C;
        val = {C, C};
    }
    auto op = [](S a, S b) -> S {
        return {a.first + b.first, max(a.second, a.first + b.second)};
    };
    Doubling<S, 30> doubling(N, to, value, op);
    long long ans{-(1LL << 62)};
    for (int i = 0; i < N; ++i) {
        ans = max(ans, doubling.prod(i, K - 1).second);
    }
    cout << ans << '\n';
}


ほかに適用できる問題を見つけたら随時追加していきます。


感想

抽象化したデータ構造やアルゴリズムを使って問題を解くと見通しはよくなりますが、持つべきデータと演算を定めるときに工夫が必要なときがあり、往々にしてパズルをするはめになるのがつらそうです。