TensorFlow / Kerasにおけるlogitsの意味


このエントリーをはてなブックマークに追加

TensorFlowやKerasを使うと遭遇する logits という用語、ざっと検索してもすぐに意味が出てこなかったので書いておきます。なお、この用語は複数の意味を持つ単語なので注意願います。この記事ではあくまで TensorFlow / Keras でのlogitsの意味を説明しています。

logitsの定義

scikit-learnとTensorFlowによる実践機械学習 (Amazonアフィリエイトリンク) のp267を引用すると、logitsとは、「ソフトマックス活性化関数に通す前のニューラルネットワークの出力」です。

概念図

画像を3クラスに分類する分類器をニューラルネットで設計したときの典型的な例を以下に示します。ニューラルネットに1枚画像を入力し、  \left[4.0, 2.0, 1.0  \right] ^T という出力を得たとします。これはソフトマックス活性化関数に通す前の値であり、logitsと呼ばれるものです。

f:id:minus9d:20201025192350p:plain

このlogitsを、ソフトマックス活性化関数に与えて得られる出力は、確率を表すベクトルです。ソフトマックス活性化関数の定義に従って計算すると、 \left[0.844, 0.114, 0.042  \right] ^T が得られます。確率なので足すと1.0になっています。0.844の部分の計算はPythonだと以下のようになります。

>>> import math
>>> print(math.exp(4.0) / (math.exp(4.0) + math.exp(2.0) + math.exp(1.0)))
0.8437947344813395

この確率を表すベクトルと、正解ベクトル  \left[1.0, 0.0, 0.0  \right] ^T との間の誤差は、通常、2つの確率分布の距離を計算する関数である交差エントロピー誤差を用いて計算されます。交差エントロピー誤差の定義に従って計算すると、 0.170 になります。この計算はPythonだと以下のようになります。

>>> import math
>>> print((- 1.0 * math.log(0.8437947) - 0.0 * math.log(0.11419519) - 0.0 * mat
h.log(0.04201007)))
0.1698460604208926

tensorflowでの関数

さきほどの図と同じ計算をtensorflowの関数を使って行うと以下のようになります。

import tensorflow as tf


logits = [[4.0, 2.0, 1.0]]
labels = [[1.0, 0.0, 0.0]]

probabilites = tf.nn.softmax(logits)
print(probabilites)  # [[0.8437947  0.11419519 0.04201007]]

cross_entropy_loss = tf.losses.categorical_crossentropy(y_true=labels, y_pred=probabilites)
print(cross_entropy_loss)  # [0.16984598]

softmaxを経由せず、いきなり交差エントロピー誤差を計算する方法もあります。1つ目は、softmax_cross_entropy_with_logits()を使う方法です。

cross_entropy_loss1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(cross_entropy_loss1)  # [0.16984604]

2つ目は、categorical_crossentropy()の引数from_logitsをTrueにする方法です。

cross_entropy_loss2 = tf.losses.categorical_crossentropy(y_true=labels, y_pred=logits, from_logits=True)
print(cross_entropy_loss2)  # [0.16984604]

logitsと確率を間違えて使ってしまうと正しい値を計算できないので気をつけましょう。