matplotlibで3Dグラフを描画する
matplotlibはMATLABライクにグラフを描画できるライブラリ。
基本は以下。
ここでは3Dで表示する方法について記載する。
mpl_toolkits.mplot3dというモジュールを使用する。
使用するデータの準備は以下。
import numpy as np from sklearn import datasets from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D # %pylab inline --no-import-all boston = datasets.load_boston() rooms = boston.data[:, 5] criminals = boston.data[:, 0] house_prices = boston.target
「%pylab inline --no-import-all」を使用するとJupyter(IPython) Notebook上にグラフを表示できるが、3D回転してグラフを見たい場合は使用しない。(ここではコメントアウトしている)
3D描画は以下。
Axes3Dというモジュールを使用する。
fig = plt.figure() # figureオブジェクトを取得 ax = fig.gca(projection='3d') # 軸(ax)オブジェクトを3d指定で取得 # 他にも下記のように作成可能 # ax = fig.add_subplot(111,projection='3d') # ax = Axes3D(fig) # データセットの描画 ax.scatter3D(rooms, criminals, house_prices) ax.set_xlabel("x_1:rooms") ax.set_ylabel("x_2:criminals") ax.set_zlabel("y:house_prices") # プロットする各軸の点を作成 x_0 = 1 x_1 = np.arange(3, 10, 1) x_2 = np.arange(-20, 100, 1) ax_1, ax_2 = np.meshgrid(x_1, x_2) # 重回帰を解くと下記の式になる。(ここでは解かない) y = - 29.30168135 * x_0 + 8.3975317 * ax_1 + 0.2618229 * ax_2 # 平面をワイヤーフレーム形式で描画 ax.plot_wireframe(ax_1, ax_2, y) # 他にも様々な形式で描画可能 # ax.plot_surface(ax_1, ax_2, y) plt.show()
その他の描画の方法(plot_xxxxx)については以下参照。