NumPyのブロードキャストのルールについて曖昧にしか理解していなかったので調べました。わかってしまえば簡単です。
2つのarrayの間でブロードキャストができるかどうかは、2つのarrayのshapeによってのみ決まります。アルゴリズムは以下です。
- もし2つのarrayの次元が異なれば、小さい方のarrayの左に1を詰めて、同じ次元にする
- 2つのarrayのshapeのうち、各次元の値を比較。「すべての次元の値が、「一致」または「どちらか片方が1」である」という条件を満たすならば、ブロードキャスト可能。大きな値
例えば、array1が(2, 3, 4, 1), array2が(4, 8)の場合を考えます。
まず、array2のほうがarray1より次元が小さいので、array2の左に1を詰めて、array2を(1, 1, 4, 8)とします。この結果、array1とarray2は以下のようになります。
- array1: (2, 3, 4, 1)
- array2: (1, 1, 4, 8)
つぎに、array1, array2の各次元の値を比較していきます。
ここで、array1とarary2は、上で述べた「すべての次元の値が、「一致」または「どちらか片方が1」である」という条件を満たすので、ブロードキャスト可能です。この場合、各次元の値は数値の大きい方に合わせるので、ブロードキャスト後のshapeは(2, 3, 4, 8)になります。
コード例は以下です。
>>> import numpy as np >>> a = np.zeros((2, 3, 4, 1)) >>> b = np.zeros((4, 8)) >>> (a + b).shape (2, 3, 4, 8)