キャッサバの葉の病害分類(実践編)#

この課題では,キャッサバの葉画像を対象として,病害の種類や健康状態を画像分類によって判別します.

題材としているのは,2021 年に開催された Kaggle の画像分類コンペティション Cassava Leaf Disease Classification です.
このコンペでは,キャッサバの葉画像から病害の種類を分類するモデルを構築し,その分類精度を競います.
本課題は,画像分類におけるデータ前処理,データ拡張,モデル設計,学習,評価,結果分析といった一連の流れを実践的に学びます.

分類対象#

この課題では,以下の5クラスを分類対象とします.

  1. 健康(healthy)

    健康

  2. モザイク病(mosaic_disease)

    モザイク病

  3. 緑斑ウイルス(green_mottle)

    緑斑ウイルス

  4. 条斑ウイルス(brown_streak_disease)

    条斑ウイルス

  5. 細菌性胴枯れ病(bacterial_blight)

    細菌性胴枯れ病

データセット#

データセットは Cassava Leaf Disease Classification の画像分類コンペのものを使用します.

峰野研究室の学生は,演習時に配布するデータセットを使用してください.
それ以外の方は,以下の Kaggle 公式ページから各自でダウンロードしてください.

ダウンロードしたデータセットは cassava_sample ディレクトリの直下に配置してください.

project_root/
├── intro.md               
├── requirements.txt       
├── summary.md             
├── data/                  
├── notebooks/              
├── cassava_sample/         
│   ├── ★dataset/   # データセットはここに置く
│   ├── dataset.py    
│   ├── main.py
│   ├── model.py
│   ├── train_eval.py
│   ├── utils.py
│   └── visualize.py
└── util.py     

データセットの特徴#

このデータセットには,以下のような特徴があります.

  • クラスごとに画像枚数の偏りがある

    データ分布

  • ノイズを含む画像が多数存在する

    • 対象が画像の中心にない

      中心にない

    • 葉が枯れている

      枯れている葉

    • 画角が遠い

      遠い

    • コントラストが高い

      ハイコントラスト

このような性質があるため,モデルの精度向上には,前処理やデータ拡張,モデル設計の工夫が重要になります.

課題内容#

PyTorch を用いて病害分類モデルを実装し,公開しているサンプルプログラムをベースに性能向上を目指します.

サンプルプログラム#

サンプルプログラムのディレクトリ構成は以下の通りです.

project_root/
├── intro.md               
├── requirements.txt       
├── summary.md             
├── data/                  
├── notebooks/              
├── ★cassava_sample/   # サンプルプログラム
│   ├── dataset/
│   ├── ★dataset.py    
│   ├── ★main.py
│   ├── ★model.py
│   ├── ★train_eval.py
│   ├── ★utils.py
│   └── ★visualize.py
└── util.py     

続いて,各プログラムの概要について解説します.

  • main.py
    学習全体を実行するためのメインスクリプトです.
    データの読み込み,可視化,モデル作成,学習・評価の実行をまとめて行います.

  • dataset.py
    画像データの読み込みと前処理を担当します.
    リサイズ,反転,回転などのデータ拡張を行い,学習用・検証用・テスト用のデータローダを作成します.

  • model.py
    シンプルなCNNモデルを定義しています.
    畳み込み層とプーリング層を重ねて特徴を抽出し,最後に5クラス分類を行います.

  • train_eval.py
    学習と評価の処理を担当します.
    各エポックでの損失・精度を記録し,最良モデルの保存やテストデータでの評価を行います.

  • utils.py
    学習を補助する関数をまとめたファイルです.
    乱数シードの固定,保存先ディレクトリの作成,クラス名の読み込みなどを行います.

  • visualize.py
    学習結果の可視化を行います.
    学習曲線,混同行列,誤分類例などを出力し,モデルの挙動を確認しやすくします.

実行例#

以下のようにプログラムを実行することで,学習を開始できます.

python main.py --data_dir ./dataset --epochs 10 --batch_size 32 --lr 1e-3

評価指標#

本サンプルプログラムでは,主に以下の指標を用いてモデルを評価します.

  • Loss
    予測と正解ラベルのずれを表す指標であり,値が小さいほど予測が正解に近いことを意味します. 訓練データに対する train_loss と,検証データに対する val_loss を確認します.

  • Accuracy
    正しく分類できた割合を表す指標であり,値が大きいほど正しく分類できていることを意味します. train_accval_acc を記録し,検証データにおける精度が最も高かったモデルを最良モデルとして保存します.

  • Confusion Matrix(混同行列)
    各クラスの正分類・誤分類の状況を表形式で確認できます.
    どのクラス同士が混同されやすいかを把握するのに役立ちます.

  • Precision
    あるクラスだと予測したもののうち,実際にそのクラスであった割合を示します.

  • Recall
    実際にそのクラスであるもののうち,正しくそのクラスだと判定できた割合を示します.

  • F1-score
    Precision と Recall のバランスを表す指標です.

結果の出力#

プログラムを実行すると,results ディレクトリに主に以下の結果が保存されます.

  • train_log.csv
    各エポックの loss,accuracy,学習時間

  • learning_curve.png
    学習曲線

  • best_model.pth
    検証精度が最も高かったモデル

  • confusion_matrix.png
    混同行列

  • classification_report.txt
    Precision,Recall,F1-score などをまとめた評価結果

  • misclassified.png
    誤分類例