ぺーぺーSEのブログ

備忘録・メモ用サイト。

matplotlibで3Dグラフを描画する

matplotlibMATLABライクにグラフを描画できるライブラリ。
基本は以下。

blog.pepese.com

ここでは3Dで表示する方法について記載する。
mpl_toolkits.mplot3dというモジュールを使用する。
使用するデータの準備は以下。

import numpy as np
from sklearn import datasets
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# %pylab inline --no-import-all
 
boston = datasets.load_boston()
rooms = boston.data[:, 5]
criminals = boston.data[:, 0]
house_prices = boston.target

%pylab inline --no-import-all」を使用するとJupyter(IPython) Notebook上にグラフを表示できるが、3D回転してグラフを見たい場合は使用しない。(ここではコメントアウトしている)

3D描画は以下。
Axes3Dというモジュールを使用する。

fig = plt.figure() # figureオブジェクトを取得
ax = fig.gca(projection='3d') # 軸(ax)オブジェクトを3d指定で取得
# 他にも下記のように作成可能
# ax = fig.add_subplot(111,projection='3d')
# ax = Axes3D(fig)

# データセットの描画
ax.scatter3D(rooms, criminals, house_prices)
ax.set_xlabel("x_1:rooms")
ax.set_ylabel("x_2:criminals")
ax.set_zlabel("y:house_prices")

# プロットする各軸の点を作成
x_0 = 1
x_1 = np.arange(3, 10, 1)
x_2 = np.arange(-20, 100, 1)
ax_1, ax_2 = np.meshgrid(x_1, x_2)

# 重回帰を解くと下記の式になる。(ここでは解かない)
y = - 29.30168135 * x_0 + 8.3975317 * ax_1 + 0.2618229 * ax_2

# 平面をワイヤーフレーム形式で描画
ax.plot_wireframe(ax_1, ax_2, y)
# 他にも様々な形式で描画可能
# ax.plot_surface(ax_1, ax_2, y)

plt.show()

その他の描画の方法(plot_xxxxx)については以下参照。

http://matplotlib.org/mpl_toolkits/mplot3d/tutorial.html