決定木を使ったIrisデータの分類方法
Irisデータ
scikit-learnに付属のIris(アヤメの計測データ)を利用しています。
Irisデータは『setosa』、『versicolor』、『virginica』という3種類の品種のアヤメの”がく片 (Sepal)”と”花弁 (Petal)” の幅および長さを150点計測したデータです。
sepal length(cm) | がく片の長さ |
---|---|
sepal width(cm) | がく片の幅 |
petal length(cm) | 花弁の長さ |
petal width(cm) | 花弁の幅 |
トレーニングデータによる学習(モデル作成)
コードは以下の順で記載しています。- ライブリーのインポート、Irisデータ読み込み
- 訓練データとテストデータの分割
- 訓練データの内容確認、可視化
- モデル設定と訓練データを使った学習
- 学習結果の確認
- 決定木モデルの可視化
% matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris
iris_dataset = load_iris() # Irisデータ読み込み
# 訓練データとテストデータに分割(train:test=70%:30%)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
iris_dataset['data'], iris_dataset['target'],
test_size=0.3, random_state=0)
# Pandasデータフレーム作成(訓練データの内容確認用)
iris_df = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
iris_df['species'] = y_train
# 訓練データの件数確認
print('tarin data',len(iris_df))
print('0:setosa ',len(iris_df[iris_df['species']==0]))
print('1:versicolor ', len(iris_df[iris_df['species']==1]))
print('2:virginica', len(iris_df[iris_df['species']==2]))
# 訓練データをグラフで可視化
import seaborn as sns
sns.pairplot(iris_df, hue='species')
# モデル設定と訓練データを使った学習
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_depth=3) # 決定木モデル(最大深さ3)
clf = clf.fit(X_train, y_train) # 訓練データで学習
# 訓練データでの正解率(学習検証)
from sklearn import metrics
predict_train = clf.predict(X_train)
ac_score = metrics.accuracy_score(y_train, predict_train)
print('train score: {0:.2f}%'.format(ac_score * 100))
# 作成された決定木モデルを可視化
import pydotplus
from sklearn.externals.six import StringIO
from IPython.display import Image
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data,
feature_names=iris_dataset.feature_names,
class_names=iris_dataset.target_names,
filled=True, rounded=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
テストデータによる作成したモデルの検証
コードは以下の順で記載しています。- モデルの作成まで(上のコード抜粋)
- テストデータで分類予測
- モデルの正解率確認
% matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris
iris_dataset = load_iris() # Irisデータ読み込み
# 訓練データとテストデータに分割
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
iris_dataset['data'], iris_dataset['target'],
test_size=0.3, random_state=0)
# モデル設定と訓練データを使った学習
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_depth=3) # 決定木モデル(最大深さ3)
clf = clf.fit(X_train, y_train) # 訓練データで学習
# 作成したモデルでテストデータを分類予測
predict_test = clf.predict(X_test)
print(predict_test) # テストデータの予測結果
print(y_test) # テストデータの正解ラベル
# テストデータでの正解率(モデル検証)
from sklearn import metrics
ac_score = metrics.accuracy_score(y_test, predict_test)
print('test score: {0:.2f}%'.format(ac_score * 100))
今回の結果では、97.78%の正解率でアヤメの種類を識別できています。
0 件のコメント :
コメントを投稿