Pythonでロジスティック回帰
分類問題をPythonとロジスティック回帰で解いてみる。
データセットは下記のIrisデータセットを使用する。
線形回帰(単回帰)は下記。
線形回帰(重回帰)は下記。
ライブラリ、データセットのロードまでは以下。
import numpy as np from sklearn import datasets from matplotlib import pyplot as plt from sklearn import linear_model %pylab inline --no-import-all iris = datasets.load_iris()
データの様子が見たい場合は以下。
features = iris.data target = iris.target target_names = iris.target_names labels = target_names[target] setosa_petal_length = features[labels == 'setosa', 2] setosa_petal_width = features[labels == 'setosa', 3] setosa = np.c_[setosa_petal_length, setosa_petal_width] versicolor_petal_length = features[labels == 'versicolor', 2] versicolor_petal_width = features[labels == 'versicolor', 3] versicolor = np.c_[versicolor_petal_length, versicolor_petal_width] virginica_petal_length = features[labels == 'virginica', 2] virginica_petal_width = features[labels == 'virginica', 3] virginica = np.c_[virginica_petal_length, virginica_petal_width] plt.scatter(setosa[:, 0], setosa[:, 1], color='red') plt.scatter(versicolor[:, 0], versicolor[:, 1], color='blue') plt.scatter(virginica[:, 0], virginica[:, 1], color='green') plt.show()
ここでは、花びらの長さ(petal length)と花びらの幅(petal width)から、3種のアヤメ(setosa、versicolor、virginica)を分類する。
ライブラリで解く
scikit-learnで解く
「sklearn.linear_model.logistic.LogisticRegression」を使用して解く。
X = iris.data[:, [2, 3]] # irisのデータセットの第3, 4カラム(花びらの長さと幅) y = iris.target # irisのそれぞれのデータごとのラベル、アヤメの種類で0、1、2が格納されている lr = linear_model.LogisticRegression(C=1e5) # データセットから学習 lr.fit(X, y)
学習したモデル(lr)のpredict関数に花びらの長さ・幅を入力するとアヤメの種類を返してくれる。
predict関数の引数はnp.arrayでもよい。
# 花びらの長さ(x_1と置く)の値域の最小値-0.5、最大値+0.5を求める x_1_min, x_1_max = X[:, 0].min() - .5, X[:, 0].max() + .5 # 花びらの幅(x_2と置く)の値域の最小値-0.5、最大値+0.5を求める x_2_min, x_2_max = X[:, 1].min() - .5, X[:, 1].max() + .5 # 上記のx_1、x_2の値域から0.1刻みでプロット点を生成する ax_1, ax_2 = np.meshgrid(np.arange(x_1_min, x_1_max, 0.1), np.arange(x_2_min, x_2_max, 0.1)) # 上記で作成した(x_1, x_2)プロット点をモデルに代入し、アヤメの種類(0or1or2)を得る Z = lr.predict(np.c_[ax_1.ravel(), ax_2.ravel()]) # プロット点と同じ行列(行列×列)に整形する Z = Z.reshape(ax_1.shape)
描画は下記。
# 分類結果の描画 plt.pcolormesh(ax_1, ax_2, Z, cmap=plt.cm.Paired) # 線形回帰の結果のように直線(曲線)を描画するのではなく、 # グラフ上の点に対してアヤメの種類(0、1、2)によって色を付ける # 結果、境界が見える # データセットの描画 features = iris.data target = iris.target target_names = iris.target_names labels = target_names[target] setosa_petal_length = features[labels == 'setosa', 2] setosa_petal_width = features[labels == 'setosa', 3] setosa = np.c_[setosa_petal_length, setosa_petal_width] versicolor_petal_length = features[labels == 'versicolor', 2] versicolor_petal_width = features[labels == 'versicolor', 3] versicolor = np.c_[versicolor_petal_length, versicolor_petal_width] virginica_petal_length = features[labels == 'virginica', 2] virginica_petal_width = features[labels == 'virginica', 3] virginica = np.c_[virginica_petal_length, virginica_petal_width] plt.scatter(setosa[:, 0], setosa[:, 1], color='red') plt.scatter(versicolor[:, 0], versicolor[:, 1], color='blue') plt.scatter(virginica[:, 0], virginica[:, 1], color='green') plt.show()