chainerでUNetのようなネットワークをfor文を使って簡単に書く方法
できること
この記事では、次のことができるようになります。
- chainer v2,3,4でfor文を使ってネットワークを書く
はじめに
こんにちは、dhgrsです。社会人になって1年、仕事ではTensorFlowを使うことが多いです。EstimatorやDatasetにEagerなど最近のTensorFlowは高レベルAPIがかなり発展してきていてかなり使いやすくなっています。それでも詳細が分からなくてソースコードを読もうとすると、大規模すぎてかなり辛いです。この辺りchainerは優秀で、ソースコードがとても分かりやすいと実感し、改めてchainer愛が深まっていきました。今回は、そんなchainer愛を発散するために、必要なシーンは結構多いはずなのにあまり情報が出回っていないfor文を使ってネットワークを定義するコツをまとめてみます。今回の記事のプログラムはgithubに公開していますので、参考になったらぜひstarをつけてください。
github.com
注意
この記事は2018/04/01に書いていますが、もしかしたら近いうちにchainerの仕様が変わるかもしれません。詳しくはこの辺りのissueで議論されています。
chainer v6あたりで変更が入るのかな、と思っています。そのため未来の読者のみなさんは注意してください。
for文でネットワークを書く動機
主にGANsでは、ネットワークの構造によって学習が上手くいくかどうかが大きく左右されます。画像のような先行研究が多い分野では、論文や公開実装を参考にすればいいですが、例えば音声など成功例が少ないデータで試す場合はネットワークの構造を簡単に変えてすぐに実験したいシーンがあると思います。他にも新しいアルゴリズムを試すときにも、うまくいくネットワーク構造を色々試したいシーンがあると思います。for文で層数を制御できると、こうした試行錯誤が簡単に行えます。さらには、ネットワーク構造はほぼ同じな別プロジェクトなどに流用しやすいなど、利点は多くあります。
サンプルコード
githubから持ってきています。はてなブログで行指定してgithub上のコードを引用する方法が分からなかったので、ベタ書きしています。そのためgithubの最新版コードと異なっている場合があります。
(記事はこのコミットを元に書いています。)
基本形
class Exp1(chainer.ChainList): def __init__(self, n_layers): super(Exp1, self).__init__() for layer in range(n_layers): self.add_link(L.Linear(None, 2 ** layer)) def __call__(self, x): for link in self.children(): x = link(x) return x
これが基本形です。普段ネットワークを書くときはchainer.Chainを使う方が多いかと思いますが、for文を使うときはchainer.ChainListを使うと良いです。self.add_linkした順にself.children()で呼び出せます。(chainer.Chainだと順番が保持されません。)このコードはかなりシンプルですが、よく見ると、どこにも活性化関数がありません。だからといって、例えばF.relu(link(x))にしてしまうと、出力値にもreluが掛かってしまうので、困るシーンが多いかと思います。次の例で、この問題を解決します。余談ですが、floatへの丸め誤差が非線形関数の代わりになることを示した論文とかありましたよね。
活性化関数を適用する
class Exp2(chainer.ChainList): def __init__(self, n_layers): super(Exp2, self).__init__() for layer in range(n_layers): self.add_link(L.Linear(None, 2 ** layer)) def __call__(self, x): for link in self.children(): pre_activate = link(x) x = F.relu(pre_activate) return pre_activate
活性化関数の適用前後を別々の変数にすれば問題解決です。ただし例えばfully connected->batch normalization->reluを繰り返したいときにこのコードを使うと、fully connectedの後にもreluが適用されてしまいます。この解決法を次の例で見てみましょう。
linkの種類で場合分けする
class Exp3(chainer.ChainList): def __init__(self, n_layers): super(Exp3, self).__init__() for layer in range(n_layers): fc = L.Linear(None, 2 ** layer) self.add_link(fc) fc.name = 'fc{}'.format(layer) if layer != n_layers - 1: norm = L.BatchNormalization(2 ** layer) self.add_link(norm) norm.name = 'norm{}'.format(layer) def __call__(self, x): for link in self.children(): pre_activate = link(x) if 'fc' in link.name: x = pre_activate elif 'norm' in link.name: x = F.relu(pre_activate) return pre_activate
linkにはnameをつけることができます。このnameを利用して場合分けができます。1つ注意が必要なのは、self.add_linkをする際に、裏側で勝手にnameを上書きしてしまいます。そのためnameの定義->self.add_linkとしてしまうと、せっかく定義したnameが無視されてしまいます。この例では、self.add_link->nameの定義としています。
もうひとつ工夫しているのは、出力層にはbatch normalizationを適用したくないので、__init__内で場合分けして出力層ではbatch normalizationを追加しないようにしています。
実用的な例(GANsのgenerator)
class Exp4(chainer.ChainList): def __init__(self, n_layers): super(Exp4, self).__init__() fc = L.Linear(None, 7 * 7 * 4) self.add_link(fc) fc.name = 'fc' for layer in range(n_layers): if layer == n_layers - 1: out_channels = 1 else: out_channels = 2 ** layer conv = L.Deconvolution2D( None, out_channels, ksize=4, stride=2, pad=1) self.add_link(conv) conv.name = 'conv{}'.format(layer) if layer != n_layers - 1: norm = L.BatchNormalization(out_channels) self.add_link(norm) norm.name = 'norm{}'.format(layer) def __call__(self, x): for link in self.children(): x = link(x) if 'fc' in link.name: x = x.reshape(-1, 4, 7, 7) if 'norm' in link.name: x = F.relu(x) return x
実用的な例も見てみます。例えばGANsのgeneratorに使われるようなネットワークです。__init__が少し複雑になってきましたね。convolution層の出力channel数を層によって変えるため、少しややこしくなっています。Exp3と見比べてもらうと、理解しやすいかと思います。では次はUNet(conv-deconv型ネットワーク)の例を見てみます。
UNet(concatなし)
class Exp5Enc(chainer.ChainList): def __init__(self, n_layers): super(Exp5Enc, self).__init__() for layer in range(n_layers): out_channels = 16 * 2 ** layer conv = L.Convolution2D( None, out_channels, ksize=4, stride=2, pad=1) self.add_link(conv) conv.name = 'conv{}'.format(layer) if layer != n_layers - 1: norm = L.BatchNormalization(out_channels) self.add_link(norm) norm.name = 'norm{}'.format(layer) def __call__(self, x): for link in self.children(): x = link(x) if 'norm' in link.name: x = F.relu(x) return x class Exp5Dec(chainer.ChainList): def __init__(self, n_layers): super(Exp5Dec, self).__init__() for layer in range(n_layers): if layer == n_layers - 1: out_channels = 3 else: out_channels = 16 * 2 ** (n_layers - layer - 1) conv = L.Deconvolution2D( None, out_channels, ksize=4, stride=2, pad=1) self.add_link(conv) conv.name = 'conv{}'.format(layer) if layer != n_layers - 1: norm = L.BatchNormalization(out_channels) self.add_link(norm) norm.name = 'norm{}'.format(layer) def __call__(self, x): for link in self.children(): x = link(x) if 'norm' in link.name: x = F.relu(x) return x class Exp5(chainer.Chain): def __init__(self, n_layers): super(Exp5, self).__init__() with self.init_scope(): self.enc = Exp5Enc(n_layers) self.dec = Exp5Dec(n_layers) def __call__(self, x): x = self.enc(x) x = self.dec(x) return x
このくらい複雑になってくると、1つのChainListで表現してもいいのですが、encoderとdecoderに分けるほうがコードが見やすくなると思います。さらにpix2pixのようなUNet generatorとdiscriminatorを扱う手法の場合、encoderはdiscriminatorに流用しやすいという利点もあります。
かなり複雑なネットワークも書けるようになりました。最後は、concatありのUNetを書いてみます。encoderとdecoderの特徴量mapで同じ解像度のものをconcatすることで、encoderで失われがちな細部の情報を残す工夫です。
UNet(concatあり)
class Exp6Enc(chainer.ChainList): def __init__(self, n_layers): super(Exp6Enc, self).__init__() for layer in range(n_layers): out_channels = 16 * 2 ** layer conv = L.Convolution2D( None, out_channels, ksize=4, stride=2, pad=1) self.add_link(conv) conv.name = 'conv{}'.format(layer) if layer != n_layers - 1: norm = L.BatchNormalization(out_channels) self.add_link(norm) norm.name = 'norm{}'.format(layer) def __call__(self, x): features = [] for link in self.children(): x = link(x) if 'norm' in link.name: x = F.relu(x) features.append(x) return x, features class Exp6Dec(chainer.ChainList): def __init__(self, n_layers): super(Exp6Dec, self).__init__() for layer in range(n_layers): if layer == n_layers - 1: out_channels = 3 else: out_channels = 16 * 2 ** (n_layers - layer - 1) conv = L.Deconvolution2D( None, out_channels, ksize=4, stride=2, pad=1) self.add_link(conv) conv.name = 'conv{}'.format(layer) if layer != n_layers - 1: norm = L.BatchNormalization(out_channels) self.add_link(norm) norm.name = 'norm{}'.format(layer) def __call__(self, x, features): for link in self.children(): x = link(x) if 'norm' in link.name: x = F.relu(x) x = F.concat((x, features.pop(-1))) return x class Exp6(chainer.Chain): def __init__(self, n_layers): super(Exp6, self).__init__() with self.init_scope(): self.enc = Exp6Enc(n_layers) self.dec = Exp6Dec(n_layers) def __call__(self, x): x, features = self.enc(x) x = self.dec(x, features) return x
concatするためには、encoderが最終出力だけでなく各層の出力も一緒にdecoderに渡す必要があります。各層の出力をfeaturesというlistで管理することで実現してみました。これでよく使われるようなネットワークは全てfor文で書けるようになったかと思います。
chainerに詳しい人向けの補足
サンプルコードを見ていただいた方から複数、add_linkはdeprecated(非推奨)でないのかとご指摘いただきました。(英語力がないのでdeprecatedとduplicatedが違う単語なことに今気付きました。文脈で使い分けていました...)確かにchainer.Chainではdeprecatedです。しかし、実はChainListではfor文のような用途が想定されているため、add_linkはnot deprecatedです。さらにはこの記事のはじめに紹介したissueでは、add_linkをnot deprecatedに戻そうという議論もあります。もしこうしたほうが良いと思う意見がある方は、issueで議論するのも良いと思います。