読者です 読者をやめる 読者になる 読者になる

scikit-learnのSVMを使って多クラス分類を試す

python machine learning

このエントリーをはてなブックマークに追加

Pythonのscikit-learnを勉強中です。今回は、公式ページにある、手書き文字を0から9に分類するコード (Recognizing hand-written digits — scikit-learn 0.16.1 documentation) を読み解いてみます。

準備

Python 3でscikit-learnを使えるようにします。今回は Download Anaconda Python Distribution で配布されているAnacondaを使いました。

コード

手書き文字データセット

scikit-learnに付属する、8x8の白黒画像で表現された手書き文字データセットを使用します。

from sklearn import datasets
digits = datasets.load_digits()

このデータセットには1797個分のデータが含まれています。digits.imagesには、8x8サイズの2次元リストで表現された白黒画像データが1797個分入っています。値域は0-16です。 digits.targetには0から9のラベルが1797個分入っています。

データの変形

8x8の画像データを64x1の1次元リストにフラット化します。

n_samples = len(digits.images)                # サンプル数
data = digits.images.reshape((n_samples, -1)) # 8x8の2次元データを、64次元の1次元ベクトルに変形

分類器の学習

1797個分の半分のデータを使って多クラス分類器を学習します。使用する分類器はsvm.SVCです。このSVCでは、one-vs-one戦略(すべてのクラス間の組について分類器を学習)を採用しているそうです。

カーネルはデフォルト(RBFカーネル)を使っています。

# Create a classifier: a support vector classifier
# SVMを用いて多クラス分類(one-vs-one戦略を使用)
# RBFカーネル(Gaussianカーネル)のハイパーパラメータであるガンマの値を設定
# この値が大きいほど複雑な決定境界となる
classifier = svm.SVC(gamma=0.001)

# We learn the digits on the first half of the digits
# データの半分を学習に使う
classifier.fit(data[:n_samples / 2], digits.target[:n_samples / 2])

分類

学習した分類器を使って、残りのデータセットを分類します。

predicted = classifier.predict(data[n_samples / 2:])

結果の表示

SVCの詳細情報の表示

分類に使用したSVCの詳細情報と、分類結果をprintします。

expected = digits.target[n_samples / 2:]
predicted = classifier.predict(data[n_samples / 2:])
print("Classification report for classifier %s:\n%s\n"
      % (classifier,   # 分類に用いたSVNの詳細情報
         metrics.classification_report(expected, predicted)  # 各ラベルごとの分類結果
     )
)

表示例は以下です。

Classification report for classifier SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
  gamma=0.001, kernel='rbf', max_iter=-1, probability=False,
  random_state=None, shrinking=True, tol=0.001, verbose=False):
             precision    recall  f1-score   support

          0       1.00      0.99      0.99        88
          1       0.99      0.97      0.98        91
          2       0.99      0.99      0.99        86
          3       0.98      0.87      0.92        91
          4       0.99      0.96      0.97        92
          5       0.95      0.97      0.96        91
          6       0.99      0.99      0.99        91
          7       0.96      0.99      0.97        89
          8       0.94      1.00      0.97        88
          9       0.93      0.98      0.95        92

avg / total       0.97      0.97      0.97       899

さらに、confusion matrix(実際のラベルと、分類先ラベルとを一覧の表にしたもの。分類性能が高いほど対角線に値が集まる)も表示できます。

print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

表示例は以下です。

Confusion matrix:
[[87  0  0  0  1  0  0  0  0  0]
 [ 0 88  1  0  0  0  0  0  1  1]
 [ 0  0 85  1  0  0  0  0  0  0]
 [ 0  0  0 79  0  3  0  4  5  0]
 [ 0  0  0  0 88  0  0  0  0  4]
 [ 0  0  0  0  0 88  1  0  0  2]
 [ 0  1  0  0  0  0 90  0  0  0]
 [ 0  0  0  0  0  1  0 88  0  0]
 [ 0  0  0  0  0  0  0  0 88  0]
 [ 0  0  0  1  0  1  0  0  0 90]]

コード全文

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# コピー元:http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#example-classification-plot-digits-classification-py

print(__doc__)

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: BSD 3 clause

# Standard scientific Python imports
import matplotlib.pyplot as plt

# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics

# The digits dataset
digits = datasets.load_digits()

# The data that we are interested in is made of 8x8 images of digits, let's
# have a look at the first 3 images, stored in the `images` attribute of the
# dataset.  If we were working from image files, we could load them using
# pylab.imread.  Note that each image must have the same size. For these
# images, we know which digit they represent: it is given in the 'target' of
# the dataset.
# 手書き数字データセットの学習サンプルを4個描画
images_and_labels = list(zip(digits.images, digits.target))
print("dataset size = ", len(images_and_labels))
for index, (image, label) in enumerate(images_and_labels[:4]):
    plt.subplot(2, 4, index + 1) # 2x4マスの上半分に描画
    plt.axis('off')
    plt.imshow(image,
               cmap=plt.cm.gray_r,     # cmap = color map
               interpolation='nearest'
    )
    plt.title('Training: %i' % label)

# To apply a classifier on this data, we need to flatten the image, to
# turn the data in a (samples, feature) matrix:
n_samples = len(digits.images)                # サンプル数

# もともとのサイズは 1794 x 8 x 8
print("digits.images shape = ", digits.images.shape)

data = digits.images.reshape((n_samples, -1)) # 8x8の部分を64次元の1次元ベクトルに変形

# 変換後のサイズは1794 x 64
# つまり、画像を64次元の1次元ベクトルとした
print("digits.images shape = ", data.shape)

# Create a classifier: a support vector classifier
# SVMを用いて多クラス分類(one-vs-one戦略を使用)
# RBFカーネル(Gaussianカーネル)のハイパーパラメータであるガンマの値を設定
# この値が大きいほど複雑な決定境界となる
classifier = svm.SVC(gamma=0.001)

# We learn the digits on the first half of the digits
# データの半分を学習に使う
classifier.fit(data[:n_samples / 2], digits.target[:n_samples / 2])

# 学習データの一つを見てみる
print("training data example: ")
print(data[0])

# Now predict the value of the digit on the second half:
# 残りの半分に対して、GTと、SVMの予測結果を比較する
expected = digits.target[n_samples / 2:]
predicted = classifier.predict(data[n_samples / 2:])


print("Classification report for classifier %s:\n%s\n"
      % (classifier,   # 分類に用いたSVNの詳細情報
         metrics.classification_report(expected, predicted)  # 各ラベルごとの分類結果
     )
)
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))


# SVMによる分類例を4個描画
images_and_predictions = list(zip(digits.images[n_samples / 2:], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(2, 4, index + 5)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prediction: %i' % prediction)

plt.show()

参考