ChainerのCNNで中間層の可視化をしてみる

お久しぶりです。いろいろとバタバタしていてなかなかブログを更新できていませんでした。久しぶりにPythonの深層学習フレームワークであるChainerについての記事を書いていこうと思います。

CNNで中間層の可視化をする

CNNというのは、Convolution Neural Netの略で、畳み込みネットワークとも呼ばれます。主に画像処理分野で成果を上げているため、画像認識で使われることが多いネットワークです。詳しくは別の記事を参照のこと。

中間層を可視化すると何がいいのかというと、人間にも分かる特徴量として重みを学習している場合、可視化することによって人が認識しやすくなるという点です。実際に人の顔なんかを深層学習すると、各特徴量として顔のパーツなんかが出てきたりします。CNNの場合、畳み込み層でのフィルターを可視化することができます。フィルターは各特徴を学習して最適化されていくので、フィルターを見れば画像中のどんな特徴を学習しているのかわかる可能性があるというわけです。

最新のextend関数を使わない

最新のChainerでは学習サイクルを回す際に、extendという関数を用いることでコードをとても単純に書くことができます。ちなみに、Chainerの最新のサンプルコードはこのextendを使って書かれているため、そのまま使用すると、各学習ループから重みを取ってくることができないため、可視化するのが難しい。

よって少し前のバージョンの書き方をするのが楽かもしれない。今回載せるサンプルでは、while文を使ってエポックを回している。

・kasha.py

# -*- coding: utf-8 -*-
from sklearn.datasets import fetch_mldata
from sklearn.cross_validation import train_test_split
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from chainer.datasets import tuple_dataset
from PIL import Image
import glob
import numpy as np
from chainer import serializers
from chainer import optimizers
from chainer.dataset.convert import concat_examples
import csv


import matplotlib.pyplot as plt
import sys

plt.style.use('ggplot')

# Network definition
class CNN(chainer.Chain):

    def __init__(self, train=True):
        super(CNN, self).__init__(
            conv1=L.Convolution2D(1, 64, 16, stride=4),
            conv2=L.Convolution2D(None, 64,  5, pad=2),
            conv3=L.Convolution2D(None, 8,  3, pad=1),
            l1=L.Linear(128, 2),
        )
        self.train = train

    def __call__(self, x):

        h = F.max_pooling_2d(F.relu(self.conv1(x)), 2)
        h = F.max_pooling_2d(F.relu(self.conv2(h)), 2)
        h = F.max_pooling_2d(F.relu(self.conv3(h)), 2)
        return self.l1(h)


def main():

    # Set up a neural network to train
    model=CNN()
    gpu_id = -1

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    

    #画像のデータセット
    pathsAndLabels = []
    pathsAndLabels.append(np.asarray(["/Users/A/", 0]))
    pathsAndLabels.append(np.asarray(["/Users/B/", 1]))


    # データを混ぜて、trainとtestがちゃんとまばらになるように。1ch
    allData = []
    for pathAndLabel in pathsAndLabels:
        path = pathAndLabel[0]
        label = pathAndLabel[1]
        imagelist = glob.glob(path + "*")
        for imgName in imagelist:
            allData.append([imgName, label])
    allData = np.random.permutation(allData)

    imageData = []
    labelData = []
    for pathAndLabel in allData:
        img = Image.open(pathAndLabel[0])
        #3チャンネルの画像をr,g,bそれぞれの画像に分ける
        r,g,b,a= img.split()
        # print img
        gImgData = np.asarray(np.float32(g)/255.0)
        imageData.append([gImgData])
        labelData.append(np.int32(pathAndLabel[1]))

    threshold = np.int32(len(imageData)/8*7)
    train = tuple_dataset.TupleDataset(imageData[0:threshold], labelData[0:threshold])
    test  = tuple_dataset.TupleDataset(imageData[threshold:],  labelData[threshold:])


    max_epoch = 20
    train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
    test_iter = chainer.iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

    while train_iter.epoch < max_epoch:

        # ---------- 学習の1イテレーション ----------
        train_batch = train_iter.next()
        x, t = concat_examples(train_batch, gpu_id)

        # 予測値の計算
        y = model(x)

        # ロスの計算
        loss = F.softmax_cross_entropy(y, t)

        # 勾配の計算
        model.cleargrads()
        loss.backward()

        # パラメータの更新
        optimizer.update()
        # --------------- ここまで ----------------

        # 1エポック終了ごとにValidationデータに対する予測精度を測って、
        # モデルの汎化性能が向上していることをチェックしよう
        if train_iter.is_new_epoch:  # 1 epochが終わったら

            # ロスの表示
            print('epoch:{:02d} train_loss:{:.04f} '.format(
                train_iter.epoch, float(loss.data)))

            test_losses = []
            test_accuracies = []
            while True:
                test_batch = test_iter.next()
                x_test, t_test = concat_examples(test_batch, gpu_id)

                # テストデータをforward
                y_test = model(x_test)

                # ロスを計算
                loss_test = F.softmax_cross_entropy(y_test, t_test)
                test_losses.append(loss_test.data)

                # 精度を計算
                accuracy = F.accuracy(y_test, t_test)
                accuracy.to_cpu()
                test_accuracies.append(accuracy.data)

                if test_iter.is_new_epoch:
                    test_iter.epoch = 0
                    test_iter.current_position = 0
                    test_iter.is_new_epoch = False
                    test_iter._pushed_position = None
                    break

            print('val_loss:{:.04f} val_accuracy:{:.04f}'.format(
                np.mean(test_losses), np.mean(test_accuracies)))
            l1_W = []
            l2_W = []
            l3_W = []

            l1_W.append(model.conv1.W)
            plt.figure(figsize = (10, 10))
            cnt = 1
            for i in range(64):#np.random.permutation(64)[:64]:
                draw_digit4(l1_W[len(l1_W)-1][i].data, cnt, i)
                cnt += 1
            filename = "Layer_img/greyscale_ver/L1_"+str(train_iter.epoch)+".png"
            plt.savefig(filename)

            l2_W.append(model.conv2.W)
            plt.figure(figsize = (10, 10))
            cnt = 1
            for i in range(64):#np.random.permutation(64)[:64]:
                draw_digit5(l2_W[len(l2_W)-1][i].data, cnt, i)
                cnt += 1
            filename = "Layer_img/greyscale_ver/L2_"+str(train_iter.epoch)+".png"
            plt.savefig(filename)

            l3_W.append(model.conv3.W)
            plt.figure(figsize = (10, 10))
            cnt = 1
            for i in range(8):#np.random.permutation(64)[:64]:
                draw_digit6(l3_W[len(l3_W)-1][i].data, cnt, i)
                cnt += 1
            filename = "Layer_img/greyscale_ver/L3_"+str(train_iter.epoch)+".png"
            plt.savefig(filename)



# 1層目のパラメータwの可視化
def draw_digit4(data, n, i):
    pixel_size = 16
    plt.subplot(10, 10, n)
    Z = data.reshape(pixel_size, pixel_size)
    Z = Z[::-1]
    plt.xlim(0, pixel_size)
    plt.ylim(0, pixel_size)
    plt.imshow(Z)
    plt.title("{0}".format(i), size = 9)
    plt.gray()
    plt.tick_params(labelbottom = "off")
    plt.tick_params(labelleft = "off")

def draw_digit5(data, n, i):
    pixel_size = 40
    plt.subplot(10, 10, n)
    Z = data.reshape(pixel_size, pixel_size)
    Z = Z[::-1]
    plt.xlim(0, pixel_size)
    plt.ylim(0, pixel_size)
    plt.imshow(Z)
    plt.title("{0}".format(i), size = 9)
    plt.gray()
    plt.tick_params(labelbottom = "off")
    plt.tick_params(labelleft = "off")

def draw_digit6(data, n, i):
    pixel_size = 24
    plt.subplot(10, 10, n)
    Z = data.reshape(pixel_size, pixel_size)
    Z = Z[::-1]
    plt.xlim(0, pixel_size)
    plt.ylim(0, pixel_size)
    plt.imshow(Z)
    plt.title("{0}".format(i), size = 9)
    plt.gray()
    plt.tick_params(labelbottom = "off")
    plt.tick_params(labelleft = "off")

if __name__ == '__main__':
    main()


今回は3層の畳み込みニューラルネットを使っている。入力は画像データで、単純にするために入力チャンネルを1にして白黒画像を学習するようにしてある。元のデータがカラーだったので、RGBのGチャンネルを輝度とみなし、グレイースケールにしてある。
ネットワークの定義方法や、各種データセットの作り方は以前の記事を参考のこと。




draw_digitという関数を作り、ここで可視化を行っている。from PIL import Imageなど、ライブラリがいくつか必要なのでその都度インストールすればいいだろう。各層の重みはmodel.conv1.Wなどのように書くことで取ってこれるため、これを配列に格納しプロットしていくイメージだ。

まとめ

ニューラルネットの重みを可視化することができれば、学習サイクルにおける重みの変化遷移を知ることができる。ぜひ挑戦してみてほしい。
 

にほんブログ村 IT技術ブログへ
にほんブログ村
にほんブログ村 IT技術ブログ IT技術メモへ
にほんブログ村

ChainerのCNNで中間層の可視化をしてみる」への1件のフィードバック

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です