NumPyのブロードキャストのルール


このエントリーをはてなブックマークに追加

NumPyのブロードキャストのルールについて曖昧にしか理解していなかったので調べました。わかってしまえば簡単です。

2つのarrayの間でブロードキャストができるかどうかは、2つのarrayのshapeによってのみ決まります。アルゴリズムは以下です。

  • もし2つのarrayの次元が異なれば、小さい方のarrayの左に1を詰めて、同じ次元にする
  • 2つのarrayのshapeのうち、各次元の値を比較。「すべての次元の値が、「一致」または「どちらか片方が1」である」という条件を満たすならば、ブロードキャスト可能。大きな値

例えば、array1が(2, 3, 4, 1), array2が(4, 8)の場合を考えます。

まず、array2のほうがarray1より次元が小さいので、array2の左に1を詰めて、array2を(1, 1, 4, 8)とします。この結果、array1とarray2は以下のようになります。

  • array1: (2, 3, 4, 1)
  • array2: (1, 1, 4, 8)

つぎに、array1, array2の各次元の値を比較していきます。

ここで、array1とarary2は、上で述べた「すべての次元の値が、「一致」または「どちらか片方が1」である」という条件を満たすので、ブロードキャスト可能です。この場合、各次元の値は数値の大きい方に合わせるので、ブロードキャスト後のshapeは(2, 3, 4, 8)になります。

コード例は以下です。

>>> import numpy as np
>>> a = np.zeros((2, 3, 4, 1))
>>> b = np.zeros((4, 8))
>>> (a + b).shape
(2, 3, 4, 8)

参考URLs