【Scikit-learn】ニューラルネットワーク学習モデルのファイル出力・保存

この記事では、Pythonと機械学習ライブラリ「scikit-learn」を用いて、ニューラルネットワーク(NN)で学習したモデルをファイルに出力し、保存する方法とソースコードを解説します。

ニューラルネットワークとは

前回までは、Python + scikit-learnでニューラルネットワーク(パーセプトロン方式)を実装し、学習・予測・識別率の計算を行いました。

前回までの記事
1 【Scikit-learn】ニューラルネットワークで学習・予測
2 【Scikit-learn】ニューラルネットワークの識別率を計算
3 ニューラルネットワークの原理・計算式・特徴

今回は学習データをファイル出力(保存)してみます。

書式

sklearn.externals.joblib.dump(clf, filepath)
パラメータ 説明
clf 学習データ
filepath 出力先のファイルパス

ソースコード

サンプルプログラムのソースコードは下記の通りです。

# -*- coding: utf-8 -*-
import pandas as pd
from sklearn.neural_network import MLPClassifier
from sklearn.externals import joblib

def main():
    # データを取得
    data = pd.read_csv("data.csv", sep=",")

    # ニューラルネットで学習
    clf = MLPClassifier(solver="sgd",random_state=0,max_iter=10000)

    # 学習(説明変数x1, x2、目的変数x3)
    clf.fit(data[['x1', 'x2']], data['x3'])

    # 学習データを元に説明変数x1, x2から目的変数x3を予測
    pred = clf.predict(data[['x1', 'x2']])

    # 結果表示
    print (pred)
    joblib.dump(clf, 'nn.learn')
    
if __name__ == "__main__":
    main()

data.csv

x1,x2,x3
45,17.5,30
38,17.0,25
41,18.5,20
34,18.5,30
59,16.0,45
47,19.0,35
35,19.5,25
43,16.0,35
54,18.0,35
52,19.0,40

実行結果

サンプルプログラムの実行結果は下記の通りです。

【学習ファイル】
nn.learn

関連記事
1 Scikit-learn入門・使い方
2 Scikit-learnをインストールする方法
3 Python入門 基本文法
関連記事