【Python/Keras】CNNと転移学習・ファインチューニングで画像認識・分類

Pythonの機械学習モジュール「Keras」でCNN(畳み込みニューラルネットワーク)を実装し、転移学習・Fine-tuningで学習して画像認識・分類する方法をソースコード付きでまとめました。

スポンサーリンク

【Keras】CNNとFine-tuningで画像分類器の作成

CNN(畳み込みニューラルネット)を使って画像分類器を作成するには「万単位の膨大な学習用データ画像」と「長い学習時間」「強力なマシンパワー」が必要となります。
しかし、Fine-tuning(ファインチューニングや転移学習)という手法を用いれば、「少量のデータセット」「短い学習時間」「貧弱なマシンパワー」で比較的精度の高い画像分類器を作成できます。

用語 説明
VGG16 1000種類の膨大な画像データセット「ImageNet」で作成された16層の学習済みCNNモデル。
転移学習 既存の学習済モデルの最終出力層を付け替えます。学習済の重みは変更せずに、新規のデータを学習させて効率的に新しいモデルを学習する方法です(出力層のパラメータは変更されます)。1から新規で学習せずに、VGG16やCIFAR-10など、既存の学習済みモデルを活用することで、少数のデータセットで精度の高い学習モデルを作成できます。例えば、(飛行機、自動車、バイク)を分類する学習済みモデルを、(飛行機、バイク)のデータで再学習させ、その2つを分類するモデルを作ったりできます。
ファインチューニング 既存の学習済モデルの最終出力層を付け替えます。また、学習済みの重みのうち、一部(出力層に近い深いところにある畳み込み層:特徴抽出器)を再学習します。浅い層(入力層)に近いほど、輪郭やブロブなど汎用的な特徴が抽出される傾向にあります。一方m深い層ほど学習データに特化した特徴が抽出される傾向があります。そこで、ファインチューニングでは浅い層にある汎用的な畳込み層は学習せず(frozen)、深い層のみを新規のデータで学習させます。1から新規で学習せずに、VGG16やCIFAR-10など、既存の学習済みモデルを活用することで、少数のデータセットで精度の高い学習モデルを作成できます。例えば、(飛行機、自動車、バイク)を分類する学習済みモデルを、(飛行機、トラック、船舶)のデータで再学習させ、それらを分類するモデルなど、既存の学習済モデルと比べて全く異なる学習モデルを作ったりできます。

サンプルコード

サンプルプログラムのソースコードです。

VGG16のデータセットを使ってファインチューニングにより学習を行い、学習結果を保存します。
「ImageDataGenerator.flow_from_directory」指定ディレクトリ内の画像をディレクトリ名でラベル付けしてくれます。
ですので、各ディレクトリには round_grasses などのラベル名を付けた上で画像を格納しています。


【学習用画像の例】
●土鍋の学習用画像(C:\github\sample\python\keras\05_vgg16\ex2_data\train\donabe内に格納)

●マグカップの学習用画像(C:\github\sample\python\keras\05_vgg16\ex2_data\train\magcupに格納)

●やかんの学習用画像(C:\github\sample\python\keras\05_vgg16\ex2_data\train\yakanに格納)

【検証用画像の例】
学習用とは別に検証用に使う土鍋、マグカップ、やかんの画像も用意して保存します。
●土鍋の学習用画像(C:\github\sample\python\keras\05_vgg16\ex2_data\train\donabe内に格納)

●マグカップの学習用画像(C:\github\sample\python\keras\05_vgg16\ex2_data\train\magcupに格納)

●やかんの学習用画像(C:\github\sample\python\keras\05_vgg16\ex2_data\train\yakanに格納)

実行結果


前回、Fine-tuningを使わなかったときよりも改善されました。

【Python/Keras】CNN(畳み込みニューラルネット)で画像の分類
Pythonの機械学習モジュール「Keras」でCNN(畳み込みニューラルネット)を実装し、画像を分類する方法をソースコード付きでまとめました。

続いてある1枚の画像を入力して分類してみます。
試しに、データセットにないやかんの写真を与えてみました。

■作成した分類器で分類


【TensorFlow版Keras入門】ディープラーニングを簡単に学ぶ方法
Pythonモジュール「TensorFlow/Keras」で深層学習(ディープラーニング)を行う方法について入門者向けに使い方を解説します。
Python機械学習
スポンサーリンク

コメント

  1. ys より:

    大変参考になるコード誠にありがとうございます。

    すべてを解説いただいているため、初心者の私には大変助かりました。

    すみません、下記の、ヤカン、土なべ、マグカップの画像データセットの
    ご提供は可能でございますか??

    # 分類するクラス名
    classes = [‘yakan’, ‘donabe’, ‘magcup’]

    • 管理人 より:

      コメントありがとうございます。
      ネット上にある素材を利用しましたので、提供ができません。
      つきましては、ys様ご自身で用意いただけますでしょうか。