Chainerのpredictorで入力を分類してみる

Chainerでは、入力値として画像や数値を与え、 深層学習を行うことで分類問題に対するモデルを作成することができます。ネットワークは、あらかじめ用意されている関数を組み合わせることでほとんど自由に構成することが可能で、Pythonで利用できる深層学習フレームワークとしてはかなり優秀だと思います。

学習モデルを作成するには?


Chainerを用いて学習モデルを生成するには、学習済みの各層の重みやらなにやらを保存する必要があります。便利なことにChainerには、学習後のモデルを簡単に保存できる関数が用意されているため、以下のプログラムを学習後に実行させるように組み込めば、任意の場所に拡張子がnpzのモデルデータを書き出すことができます。第2引数をmodelにすれば、モデルが保存でき、optimizerにすれば様々なネットワークの設定を保存することができます。各モデルは別のプログラムで読み込んで利用することが可能です。

    print('save the model')
    serializers.save_npz('mymodel.npz', model)
    print('save the optimizer')
    serializers.save_npz('mystate.npz', optimizer)

生成されたモデルを読み込んでテストする。

生成されたモデルを使って、学習に利用していない未知のデータに対する結果を予想したい場合があるとおもいます。そのような場合は、モデルを読み込んでpredictorという関数を用いることで、簡単にテストデータの予測を行うことができます。
具体的なプログラムを以下に示します。

・・・・
def main():

    model = DNN_TEST()
    serializers.load_npz("mymodel.npz", model)
    model=L.Classifier(model)

・・・・

    class_name = ['A', 'B']

    num1 = len(Data)
    cou = 1

    print dir(model)
    plt.figure(figsize=(5, 5))
    acc=0
    for i in range(len(Data)):

        x = Data[i]
        t = labelData[i]

        print i
        y = model.predictor(x[None, ...]).data.argmax(axis=1)[0]
        print('predicted_label:', class_name[y])
        print('answer:', class_name[t])
        if class_name[y] == class_name[t]:
            acc=acc+1
        cou += 1
    
    print('acc:', acc)
・・・・

上に示したプログラムでは、モデルの定義部分や、テストデータを配列に格納する部分は端折っていますが、大事な部分を抜き出してみました。まず、serializers.load_npzで生成済みのモデルを読み込みます。ここで読み込むモデルは、先に示したモデルの保存で生成されたものです。また、predictorを使うには、モデルがClassifierでラップされている必要があるので、もし学習時にClassifierを利用せず分類問題を学習していた場合は、以上のプログラムのように書くことで無理やりラップすることができます。筆者の経験上、無理やりClassifierでラップしても精度にあまり影響はなかったと思います。

テスト用のデータをDataという配列に、そのラベルをlabelDataという配列に格納し、model.predictorを使えば実際に予想を行うことができます。結果として出力されるラベルはあらかじめ class_name = [‘A’, ‘B’]のように定義しておく必要があります。
上記のプログラムでは、テストデータのラベルと予想したラベルを比較し、正解を導けたデータ数がターミナルに出力されます。

 

まとめ

いかがだっただろうか。分類問題であれば、ChainerのClassifier関数とpredictor関数を用いることで、簡単にテストデータの予測までできてしまうことを紹介した。学習データの精度とテストデータの精度を比較することは、過学習を起こしているかの判断基準にもなるのでぜひ挑戦してみてほしい。

にほんブログ村 IT技術ブログへ
にほんブログ村
にほんブログ村 IT技術ブログ IT技術メモへ
にほんブログ村

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です