Google Kickstart 2020 Round 1Bの最後の問題、Wandering Robot についてやっと理解できたので解説を書きます。
問題概要
W*Hの盤面が与えられる。盤面には ある矩形1つ分の穴が空いている。ロボットは左上の(1, 1)を出発し、等確率で右か下に移動する。ただし盤面の右端に到達した場合は必ず下に、盤面の下端に到達した場合は必ず右に移動する。ロボットが穴に落ちずに盤面の右下(W, H)に到着できる確率を求めよ。
着想
例えば以下の盤面を考えます。■は穴を表します。
★に着目してください。ロボットが右下に到達するためには、必ず★のどれか一つを通らなければいけません。かつ、★を2つ以上通ることはできません。なので、各★を通る確率を個別に求め、その和をとれば、答が求まりそうです。
各セルを通る確率は、以下のように求められます。さきほどの盤面の場合、3つの★の確率の和 1/16 + 4/16 + 1/8 = 7/16 = 0.4375 が答になります。
上記の確率はいかにも二項分布で一般化できそうな形をしています。しかし、上図であえて空白にしている部分の確率はどうでしょうか。すべての確率を求めた図を以下に示します。ロボットが右端 or 下端に達したときは確率1で移動方向が決定するため、二項分布の法則が崩れます。
なので、以下のような盤面が与えられたとき、★の部分の確率の和を求めるのは一筋縄ではいきません。
工夫1: 仮想盤面の導入
本ラウンド1位のscotwu氏の回答 がそれへの解になります。この回答では、まるでロボットが右端や下端を無視して移動し続けられるかのような仮想盤面を考えます。
xxx
上の仮想盤面において☆の確率を足すと、ちょうど前の盤面の★の確率の和と同じになっていることが分かります!! 仮想盤面では二項分布の法則が崩れないので、確率の計算が容易です。
ただし、穴が盤面の下または右に接するコーナーケースに注意です。以下の例の場合、×印で示した★の確率を足してはいけません。
工夫2: logの計算
本質的にはここまでの説明で解けたようなものですが、実際に確率を算出するためにはもうひと工夫必要です。
ロボットがマス(x, y)を通過する確率は C((x - 1) + (y - 1), (x - 1)) / 2^((x - 1) + (y - 1))
で書けます(ここで、C(n, k)は二項分布)。しかし、分子も分母も巨大な整数になるため愚直に計算すると計算時間がかかりすぎます。
そこで、分子、分母それぞれを底を2とするlogをとることで、計算時間を抑えることができます。C(n, k) = n! / k! (n - k)!
なので、1!, 2!, 3!, ... それぞれについて底を2とするlogの値を前計算で求めておけばOKです。
コード
Python 3です。
#!/usr/bin/env python3 import math # log2(n!)のテーブルを作成 log_table = [0] * 200001 log_value = 0 for i in range(1, 200001): log_value += math.log(i, 2) log_table[i] = log_value def log_of_choose(n, k): """log C(n, k)を求める""" return log_table[n] - log_table[k] - log_table[n - k] def calc_probability(x, y): """(1, 1)から(x, y)に到着する確率を求める """ x -= 1 y -= 1 return log_of_choose(x + y, x) - (x + y) def solve_fast(W, H, L, U, R, D): ans = 0.0 # ◎の部分を下から順に # □□□□□◎ # □□□□◎ # □□□■□ # □□□■□ # □□□□□ if R != W: x = R + 1 y = U - 1 while y >= 1: log2_of_prob = calc_probability(x, y) ans += pow(2, log2_of_prob) x += 1 y -= 1 # ◎の部分を右から順に # □□□□□ # □□□□□ # □□□■□ # □□□■□ # □□◎□□ # ◎ # ◎ if D != H: x = L - 1 y = D + 1 while x >= 1: log2_of_prob = calc_probability(x, y) ans += pow(2, log2_of_prob) x -= 1 y += 1 return ans def solve(): W, H, L, U, R, D = map(int, input().split()) print("{:.20f}".format(solve_fast(W, H, L, U, R, D))) T = int(input()) for testcase in range(T): print("Case #{}: ".format(testcase+1), end="") solve()