今日も窓辺でプログラム

外資系企業勤めのエンジニアが勉強した内容をまとめておくブログ

scikit-learnのtrain_test_splitで訓練データとテストデータを分割する

はじめに

scikit-learnのtrain_test_splitという関数を使うと、データセットを訓練データをテストデータに簡単に分割できます。
同じくscikit-learnに付属している数字手書き文字のデータセットを使用した例を紹介します。

importとデータセットの用意

まずはtrain_test_split関数をimportし、説明に使うデータセットを用意します。私はscikit-learnのバージョン0.19.1を使用していますが、以前のバージョンではtrain_test_splitはsklearn.cross_validationにて定義されているので注意してください。

import numpy as np

# train_test_split をインポート
# 古いバージョンのscikit-learnを使用していると 
# sklearn.cross_validation からimportする必要があるようです。
from sklearn.model_selection import train_test_split

# 例として使うデータセット
from sklearn import datasets

# 数字手書き文字のデータセットをロード
digits = datasets.load_digits()

# 1797個のサンプルデータが入っている
print('size:', digits.data.shape[0])
print('data:', digits.data)
print('target:', digits.target)

実行結果はこんな形です。1797個のデータがロードされ、digits.dataに学習に使うデータ、
digits.targetに正解ラベルが格納されている様子が確認できます。

size: 1797
data: [[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ..., 
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]]
target: [0 1 2 ..., 8 9 8]

test_train_splitで分割

train_test_split関数の第1引数に入力データ、第2引数に正解ラベルの配列を渡します。
test_sizeではテストデータのサイズ(割合)を0.0~1.0の実数で指定できます。

例えば20%をテストデータとして取り分けるには、次のようにします。

# 訓練:テスト = 8:2に分割
X_train, X_test, T_train, T_test = train_test_split(
    digits.data, digits.target, test_size=0.2)

訓練データの入力と正解ラベルがX_trainとT_trainに、テストデータの入力と正解ラベルはX_testとT_testに格納されます。
試しに入力データのサイズを見てみると、確かに8:2に分割されています。

# 訓練データとテストデータのサイズ
print('train size:', X_train.shape[0])
print('test size:', X_test.shape[0])

# train size: 1437
# test size: 360

実際に訓練データの中身を表示して確認してみます。

# 訓練データの中身を確認
print('train data:', X_train)
print('train target:', T_train)
train data: [[  0.   0.   6. ...,   3.   0.   0.]
 [  0.   4.  13. ...,   0.   0.   0.]
 [  0.   0.   9. ...,  15.   6.   0.]
 ..., 
 [  0.   1.  11. ...,  10.   1.   0.]
 [  0.   0.   7. ...,   3.   0.   0.]
 [  0.   0.   0. ...,   0.   0.   0.]]
train target: [5 3 3 ..., 3 6 4]

順番がシャッフルされデータが分割されています。もし順番をシャッフルしたくない場合は、

# 訓練:テスト = 8:2に分割
X_train, X_test, T_train, T_test = train_test_split(
    digits.data, digits.target, test_size=0.2, shuffle=False)

とすると順番を保存したままデータセットを分割してくれます。

以上、小粒ネタですがメモ代わりに記事に残しておきました。