【Python/keras】学習モデルの構造を可視化

Pythonと深層学習ライブラリ「Keras」で作成した学習モデルを可視化する方法についてまとめました。

【モデルの描画】model.summary()、plot_model

Pythonと深層学習ライブラリ「Keras」で作成した学習モデルの構造を可視化します。

動画解説

本ページの内容は以下動画でも解説しています。

model.summary()

model.summary()は、モデル構造を標準出力します。

plot_model

plot_modelは、モデルの構造を画像で出力できるものです。
plot_modelの引数の設定は以下のとおり。

引数 機能
show_shapes グラフ中に出力のshapeを出力するかどうか(デフォルト:False)
show_layer_names グラフ中にレイヤー名を出力するかどうか (デフォルトはTrue)を制御します.
expand_nested グラフ中にネストしたモデルをクラスタに展開するかどうか (デフォルトはFalse)
dpi 画像のdpi(デフォルトは96)

使用するのに別途モジュール等のインストールが必要です。

pip install pydot
pip install graphviz
brew install graphviz 
【Python/keras】'`pydot` failed to call GraphViz.'エラーなどが出た場合
Pythonと深層学習ライブラリ「Keras」で'pydot failed to call GraphViz.'エラーなどが出た場合の解決方法についてまとめました。

サンプルコード


#### 実行結果

【Python/Keras】モデル構造の可視化
Pythonの機械学習モジュール「Keras」でモデルを可視化する方法をソースコード付きでまとめました。

【その他】Modelクラスのプロパティ・メソッド

Modelクラスのプロパティ・メソッドをいかにまとめました。

プロパティ 概要
model.name modelの名前
model.inputs 入力のTensorクラスのリスト
model.outputs 出力のTensorクラスのリスト
model.layers Layerクラスのリスト
model.trainable modelがtrainableか(boolean)
model.input_shape 入力のshape
model.output_shape 出力のshape
model.weights Variableクラスのリスト
model.trainable_weights 学習可能なVariableクラスのリスト
model.non_trainable_weights 学習不可能なVariableクラスのリスト
メソッド 概要
model.summary(line_length=None, positions=None, print_fn=None) model構造をプリント表示
model.get_layer(name=None, index=None) 指定したLayerクラスを取得
model.get_weights() 全layerの重みのリストを取得
model.set_weights(weight) get_weights()で取得した各layerの重みを設定
model.get_config() モデルのコンフィグを辞書形式で取得
model.from_config(config, custom_objects=None) get_config()で取得したコンフィグからモデルを生成
model.to_json() モデルの構造をjson形式の文字列で取得
model.to_yaml() モデルの構造をYAML形式の文字列で取得
model.save(filepath, overwrite=True, include_optimizer=True) モデルの構造と重みと学習に関する設定や状態を全てHDF5形式でセーブ
model.save_weights(filepath, overwrite=True) モデルの重みをHDF5形式でセーブ
model.load_weights(filepath, by_name=False) save_weights()で保存した重みをモデルにロード
keras.models.load_model(filepath,custom_objects=None,compile=True) save()で保存されたモデルの状態をロード
keras.models.model_from_json(json_str) to_json()で取得したモデルの構造をロード
keras.models.model_from_yaml(yaml_str) to_yaml()で取得したモデルの構造をロード
【Keras入門】ディープラーニングの使い方/サンプル集
Pythonモジュール「Keras」で深層学習(ディープラーニング)を行う方法について入門者向けに使い方を解説します。

コメント

タイトルとURLをコピーしました