【Scikit-learn】データセットの種類と呼び出し(読み込み)

この記事では、Scikit-learnのデータセットの種類と呼び出し(読み込み)方について紹介します。

データセットの種類

Scikit-learnには、動作テストに便利なデータセットがいくつかあります。
データセットの内容は以下の通りです。

データの内容 予測対象
load_iris 3種類のアヤメのがく片、花弁の幅および長さ 分類
load_diabetes 糖尿病患者の検査数値と1年後の疾患進行状況 回帰
load_digits 0~9の手書き文字画像(8×8) 分類
load_boston 米国ボストン市郊外における地域別の住宅価格 回帰
load_linnerud 成人男性の生理学的特徴と運動能力 回帰
load_wine 3種類のワインの科学的特徴 分類
load_breast_cancer 乳がんの診断結果 分類

データセットの呼び出し

例として、アヤメのデータセットを呼び出してみます。

from matplotlib import pyplot as plt
from sklearn import datasets # データ・セット

def main():
    # Iris のデータを呼び出す
    iris = datasets.load_iris()
    X = iris.data[:, :2]  # 最初の二次元のみの特徴量を抽出
    Y = iris.target       # 目標値(正解データ)
    # グラフの軸幅
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    # 可視化のベースを作成
    plt.figure(2, figsize=(8, 6))
    plt.clf()
    # 実際にプロット
    plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.grid()
    plt.show()

if __name__ == "__main__":
    main()

関連記事

関連記事
1 Scikit-learnをインストールする方法
2 Scikit-learn入門・使い方
3 機械学習のアルゴリズム入門