Pythonのscikit-learnを勉強中です。今回は、公式ページにある、手書き文字を0から9に分類するコード (Recognizing hand-written digits — scikit-learn 0.16.1 documentation) を読み解いてみます。
準備
Python 3でscikit-learnを使えるようにします。今回は Download Anaconda Python Distribution で配布されているAnacondaを使いました。
コード
手書き文字データセット
scikit-learnに付属する、8x8の白黒画像で表現された手書き文字データセットを使用します。
from sklearn import datasets digits = datasets.load_digits()
このデータセットには1797個分のデータが含まれています。digits.imagesには、8x8サイズの2次元リストで表現された白黒画像データが1797個分入っています。値域は0-16です。 digits.targetには0から9のラベルが1797個分入っています。
データの変形
8x8の画像データを64x1の1次元リストにフラット化します。
n_samples = len(digits.images) # サンプル数 data = digits.images.reshape((n_samples, -1)) # 8x8の2次元データを、64次元の1次元ベクトルに変形
分類器の学習
1797個分の半分のデータを使って多クラス分類器を学習します。使用する分類器はsvm.SVCです。このSVCでは、one-vs-one戦略(すべてのクラス間の組について分類器を学習)を採用しているそうです。
# Create a classifier: a support vector classifier # SVMを用いて多クラス分類(one-vs-one戦略を使用) # RBFカーネル(Gaussianカーネル)のハイパーパラメータであるガンマの値を設定 # この値が大きいほど複雑な決定境界となる classifier = svm.SVC(gamma=0.001) # We learn the digits on the first half of the digits # データの半分を学習に使う classifier.fit(data[:n_samples / 2], digits.target[:n_samples / 2])
分類
学習した分類器を使って、残りのデータセットを分類します。
predicted = classifier.predict(data[n_samples / 2:])
結果の表示
SVCの詳細情報の表示
分類に使用したSVCの詳細情報と、分類結果をprintします。
expected = digits.target[n_samples / 2:] predicted = classifier.predict(data[n_samples / 2:]) print("Classification report for classifier %s:\n%s\n" % (classifier, # 分類に用いたSVNの詳細情報 metrics.classification_report(expected, predicted) # 各ラベルごとの分類結果 ) )
表示例は以下です。
Classification report for classifier SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.001, kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False): precision recall f1-score support 0 1.00 0.99 0.99 88 1 0.99 0.97 0.98 91 2 0.99 0.99 0.99 86 3 0.98 0.87 0.92 91 4 0.99 0.96 0.97 92 5 0.95 0.97 0.96 91 6 0.99 0.99 0.99 91 7 0.96 0.99 0.97 89 8 0.94 1.00 0.97 88 9 0.93 0.98 0.95 92 avg / total 0.97 0.97 0.97 899
さらに、confusion matrix(実際のラベルと、分類先ラベルとを一覧の表にしたもの。分類性能が高いほど対角線に値が集まる)も表示できます。
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))
表示例は以下です。
Confusion matrix: [[87 0 0 0 1 0 0 0 0 0] [ 0 88 1 0 0 0 0 0 1 1] [ 0 0 85 1 0 0 0 0 0 0] [ 0 0 0 79 0 3 0 4 5 0] [ 0 0 0 0 88 0 0 0 0 4] [ 0 0 0 0 0 88 1 0 0 2] [ 0 1 0 0 0 0 90 0 0 0] [ 0 0 0 0 0 1 0 88 0 0] [ 0 0 0 0 0 0 0 0 88 0] [ 0 0 0 1 0 1 0 0 0 90]]
コード全文
#!/usr/bin/env python3 # -*- coding: utf-8 -*- # コピー元:http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#example-classification-plot-digits-classification-py print(__doc__) # Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org> # License: BSD 3 clause # Standard scientific Python imports import matplotlib.pyplot as plt # Import datasets, classifiers and performance metrics from sklearn import datasets, svm, metrics # The digits dataset digits = datasets.load_digits() # The data that we are interested in is made of 8x8 images of digits, let's # have a look at the first 3 images, stored in the `images` attribute of the # dataset. If we were working from image files, we could load them using # pylab.imread. Note that each image must have the same size. For these # images, we know which digit they represent: it is given in the 'target' of # the dataset. # 手書き数字データセットの学習サンプルを4個描画 images_and_labels = list(zip(digits.images, digits.target)) print("dataset size = ", len(images_and_labels)) for index, (image, label) in enumerate(images_and_labels[:4]): plt.subplot(2, 4, index + 1) # 2x4マスの上半分に描画 plt.axis('off') plt.imshow(image, cmap=plt.cm.gray_r, # cmap = color map interpolation='nearest' ) plt.title('Training: %i' % label) # To apply a classifier on this data, we need to flatten the image, to # turn the data in a (samples, feature) matrix: n_samples = len(digits.images) # サンプル数 # もともとのサイズは 1794 x 8 x 8 print("digits.images shape = ", digits.images.shape) data = digits.images.reshape((n_samples, -1)) # 8x8の部分を64次元の1次元ベクトルに変形 # 変換後のサイズは1794 x 64 # つまり、画像を64次元の1次元ベクトルとした print("digits.images shape = ", data.shape) # Create a classifier: a support vector classifier # SVMを用いて多クラス分類(one-vs-one戦略を使用) # RBFカーネル(Gaussianカーネル)のハイパーパラメータであるガンマの値を設定 # この値が大きいほど複雑な決定境界となる classifier = svm.SVC(gamma=0.001) # We learn the digits on the first half of the digits # データの半分を学習に使う classifier.fit(data[:n_samples / 2], digits.target[:n_samples / 2]) # 学習データの一つを見てみる print("training data example: ") print(data[0]) # Now predict the value of the digit on the second half: # 残りの半分に対して、GTと、SVMの予測結果を比較する expected = digits.target[n_samples / 2:] predicted = classifier.predict(data[n_samples / 2:]) print("Classification report for classifier %s:\n%s\n" % (classifier, # 分類に用いたSVNの詳細情報 metrics.classification_report(expected, predicted) # 各ラベルごとの分類結果 ) ) print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted)) # SVMによる分類例を4個描画 images_and_predictions = list(zip(digits.images[n_samples / 2:], predicted)) for index, (image, prediction) in enumerate(images_and_predictions[:4]): plt.subplot(2, 4, index + 5) plt.axis('off') plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Prediction: %i' % prediction) plt.show()