モモノキ&ナノネと一緒にRで機械学習を試してみよう
Rで機械学習、決定木・ランダムフォレスト・SVMの決定境界を色分けして塗り潰してみよう
ナノネ、統計解析用フリーソフト『R』の使い方を練習するよ。今回は機械学習で分類された決定境界を色分けして塗り潰してしてみよう。
前回やった決定境界の続き?
うん。前回は境界線を引いたから、今回は色分けして領域を塗り潰すのをやってみよう。
OK、塗り潰すだけね。
今回はimage()というのを使って領域の色分けを試してみるね。
image()の使い方は?
まだよくわからないけど、等高線と似たような感じで使えそうだよ。とりあえず実践。
前回と同じで下のグラフに各モデルの決定境界を追加するよ。モデルも前回と一緒、決定木・ランダムフォレスト・SVMでやってみるよ。
exclude.cols <- names(iris) %in% c('Sepal.Length', 'Sepal.Width')
df.iris = iris[!exclude.cols]
# Irisデータを単純にプロット
par(mar = c(6, 7, 5, 2))
par(mgp = c(4, 1.2, 0.4))
par(lwd = 2)
plot(df.iris$Petal.Length, df.iris$Petal.Width,
main = "Iris",
xlab = "Petal length (cm)",
ylab = "Petal width (cm)",
col = c(2, 3, 4)[df.iris$Species],
pch=c(0, 1, 2)[df.iris$Species],
cex = 1.5,
cex.main = 2,
cex.lab = 2,
cex.axis = 1.5,
xlim = c(0, 7.5),
xaxp = c(0, 7, 7),
ylim = c(0, 3),
yaxp = c(0, 3, 3)
)
par(family = "serif")
par(font = 3)
legend("topleft",
legend = levels(df.iris$Species),
pch = c(0, 1, 2),
col = c(2, 3, 4),
cex = 1.5,
pt.cex = 1.5,
bty = 'n',
inset = c(0.05, 0.03))
まずは決定木の決定境界を塗り潰すよ。
# 決定木
# install.packages("rpart") # 必要に応じてインスール
library(rpart)
# Irisデータ準備(Sepal.LengthとSepal.Widthを除外)
exclude.cols <- names(iris) %in% c('Sepal.Length', 'Sepal.Width')
df.iris = iris[!exclude.cols]
# Irisデータをそのままプロット(x軸Petal.Length、y軸Petal.Width)
par(mar = c(6, 7, 5, 2))
par(mgp = c(4, 1.2, 0.4))
par(lwd = 2)
plot(df.iris$Petal.Length, df.iris$Petal.Width,
main = "Iris (Decision tree)",
xlab = "Petal length (cm)",
ylab = "Petal width (cm)",
col = c(2, 3, 4)[df.iris$Species],
pch=c(0, 1, 2)[df.iris$Species],
cex = 1.5,
cex.main = 2,
cex.lab = 2,
cex.axis = 1.5,
xlim = c(0, 7.5),
xaxp = c(0, 7, 7),
ylim = c(0, 3),
yaxp = c(0, 3, 3),
type = "n" # プロットしない(軸だけ)
)
# モデル作成
set.seed(100)
model.rp = rpart(Species ~ ., data = df.iris)
# 決定境界プロット用のメッシュ作成
px <- seq(0, 7.5, 0.01)
py <- seq(0, 3, 0.01)
pgrid <- expand.grid(px, py)
names(pgrid) <- c("Petal.Length", "Petal.Width")
# モデルでメッシュデータの分類を全て予測
pred.pgrid.rp <- predict(model.rp, pgrid, type="class")
# メッシュデータで塗り潰し
my.colors <- c("#FFCCCC","#CCFFCC","#CCCCFF")
image(px, py, array(as.numeric(pred.pgrid.rp)-1, dim=c(length(px), length(py))),
xlim=c(0, 7.5), ylim=c(0, 3), add=T, col = my.colors)
# メッシュデータで等高線をプロット(決定境界)
contour(px, py, array(as.numeric(pred.pgrid.rp)-1, dim=c(length(px), length(py))),
xlim=c(0, 7.5), ylim=c(0, 3),
col="orange", lwd=2, drawlabels=F, add=T)
# Irisデータプロット
points(df.iris$Petal.Length, df.iris$Petal.Width,
cex = 1.5,
col = c(2, 3, 4)[df.iris$Species],
pch = c(0, 1, 2)[df.iris$Species])
par(family = "serif")
par(font = 3)
legend("topleft",
legend = levels(df.iris$Species),
pch = c(0, 1, 2),
col = c(2, 3, 4),
cex = 1.5,
pt.cex = 1.5,
bty = 'n',
inset = c(0.05, 0.03))
次はランダムフォレストの決定境界を塗り潰すよ。
# ランダムフォレスト
# install.packages("randomForest") # 必要に応じてインストール
library(randomForest)
# Irisデータ準備(Sepal.LengthとSepal.Widthを除外)
exclude.cols <- names(iris) %in% c('Sepal.Length', 'Sepal.Width')
df.iris = iris[!exclude.cols]
# Irisデータをそのままプロット(x軸Petal.Length、y軸Petal.Width)
par(mar = c(6, 7, 5, 2))
par(mgp = c(4, 1.2, 0.4))
par(lwd = 2)
plot(df.iris$Petal.Length, df.iris$Petal.Width,
main = "Iris (Random forest)",
xlab = "Petal length (cm)",
ylab = "Petal width (cm)",
col = c(2, 3, 4)[df.iris$Species],
pch=c(0, 1, 2)[df.iris$Species],
cex = 1.5,
cex.main = 2,
cex.lab = 2,
cex.axis = 1.5,
xlim = c(0, 7.5),
xaxp = c(0, 7, 7),
ylim = c(0, 3),
yaxp = c(0, 3, 3),
type = "n" # プロットしない(軸だけ)
)
# モデル作成
set.seed(100)
model.rf = randomForest(Species ~ ., data = df.iris)
# 決定境界プロット用のメッシュ作成
px <- seq(0, 7.5, 0.01)
py <- seq(0, 3, 0.01)
pgrid <- expand.grid(px, py)
names(pgrid) <- c("Petal.Length", "Petal.Width")
# モデルでメッシュデータの分類を全て予測
pred.pgrid.rf <- predict(model.rf, pgrid)
# メッシュデータで塗り潰し
my.colors <- c("#FFCCCC","#CCFFCC","#CCCCFF")
image(px, py, array(as.numeric(pred.pgrid.rf)-1, dim=c(length(px), length(py))),
xlim=c(0, 7.5), ylim=c(0, 3), add=T, col = my.colors)
# メッシュデータで等高線をプロット(決定境界)
contour(px, py, array(as.numeric(pred.pgrid.rf)-1, dim=c(length(px), length(py))),
xlim=c(0, 7.5), ylim=c(0, 3),
col="orange", lwd=2, drawlabels=F, add=T)
# Irisデータプロット
points(df.iris$Petal.Length, df.iris$Petal.Width,
cex = 1.5,
col = c(2, 3, 4)[df.iris$Species],
pch = c(0, 1, 2)[df.iris$Species])
par(family = "serif")
par(font = 3)
legend("topleft",
legend = levels(df.iris$Species),
pch = c(0, 1, 2),
col = c(2, 3, 4),
cex = 1.5,
pt.cex = 1.5,
bty = 'n',
inset = c(0.05, 0.03))
最後はSVM(サポートベクトルマシン)の決定境界を塗り潰すよ。
# SVM(サポートベクトルマシン)
# install.packages( "kernlab" ) # 必要に応じてインストール
library(kernlab)
# Irisデータ準備(Sepal.LengthとSepal.Widthを除外)
exclude.cols <- names(iris) %in% c('Sepal.Length', 'Sepal.Width')
df.iris = iris[!exclude.cols]
# Irisデータをそのままプロット(x軸Petal.Length、y軸Petal.Width)
par(mar = c(6, 7, 5, 2))
par(mgp = c(4, 1.2, 0.4))
par(lwd = 2)
plot(df.iris$Petal.Length, df.iris$Petal.Width,
main = "Iris (SVM)",
xlab = "Petal length (cm)",
ylab = "Petal width (cm)",
col = c(2, 3, 4)[df.iris$Species],
pch=c(0, 1, 2)[df.iris$Species],
cex = 1.5,
cex.main = 2,
cex.lab = 2,
cex.axis = 1.5,
xlim = c(0, 7.5),
xaxp = c(0, 7, 7),
ylim = c(0, 3),
yaxp = c(0, 3, 3),
type = "n" # プロットしない(軸だけ)
)
# モデル作成
set.seed(100)
model.svm = ksvm(Species ~ ., data = df.iris)
# 決定境界プロット用のメッシュ作成
px <- seq(0, 7.5, 0.01)
py <- seq(0, 3, 0.01)
pgrid <- expand.grid(px, py)
names(pgrid) <- c("Petal.Length", "Petal.Width")
# モデルでメッシュデータの分類を全て予測
pred.pgrid.svm <- predict(model.svm, pgrid)
# メッシュデータで塗り潰し
my.colors <- c("#FFCCCC","#CCFFCC","#CCCCFF")
image(px, py, array(as.numeric(pred.pgrid.svm)-1, dim=c(length(px), length(py))),
xlim=c(0, 7.5), ylim=c(0, 3), add=T, col = my.colors)
# メッシュデータで等高線をプロット(決定境界)
contour(px, py, array(as.numeric(pred.pgrid.svm)-1, dim=c(length(px), length(py))),
xlim=c(0, 7.5), ylim=c(0, 3),
col="orange", lwd=2, drawlabels=F, add=T)
# Irisデータプロット
points(df.iris$Petal.Length, df.iris$Petal.Width,
cex = 1.5,
col = c(2, 3, 4)[df.iris$Species],
pch = c(0, 1, 2)[df.iris$Species])
par(family = "serif")
par(font = 3)
legend("topleft",
legend = levels(df.iris$Species),
pch = c(0, 1, 2),
col = c(2, 3, 4),
cex = 1.5,
pt.cex = 1.5,
bty = 'n',
inset = c(0.05, 0.03))
処理順は、plot()で軸だけ描画、image()で境界を塗り潰す、contour()で境界線を引く、最後にpoints()でデータプロットだね。
描画する順序に気を付けないと、見えなくなっちゃうからね。
グラフ3枚ループでいっぺんに描けそうだけど?
もうちょっとRに慣れたらね。
またね!
0 件のコメント :
コメントを投稿