Sleep like a pillow

Deep Learning関係の話。

TResNet: High Performance GPU-Dedicated Architecture

arxiv.org

github.com

どんなもの?

 近年のDeep Learningのモデルは、ResNet50と比較して低いFLOPsで高い精度を出しているが、GPUによる学習・推論速度はResNet50と同等以下なものが多いことを指摘。本論文では、GPUに最適なネットワーク構造と実装の工夫によって、実際の学習・推論速度が高速、かつ高精度なモデルであるTResNetを提案。

ネットワーク構造

 ネットワーク構造の工夫は以下の5つ。

  • SpaceToDepth Stem
  • Anti-Alias Down Sampling
  • In-Place Activated BatchNorm
  • New Block-type Selection

SpaceToDepth Stem

 多くのネットワークでは、最初の数層に解像度を大きく下げる構造(例えば、ResNet50だとconv7x7(stride=2)->maxpoolの部分)が入っており、Stemと呼ばれる。TResNetではStemとしてSpaceToDepth Stemを用いる。SpaceToDepth Stemは、例えば解像度を1/2にする場合は、stride=2で抽出したピクセルをチャネル方向にconcatする。多分コードを見た方がわかりやすい。

class SpaceToDepth(nn.Module):
    def __init__(self, block_size=4):
        super().__init__()
        assert block_size == 4
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
        return x

図だと以下のような感じになる。

f:id:uiiurz1:20200712092652p:plain

Anti-Alias Down Sampling

 Stem以外の解像度を下げる機構としてstride=2のconv3x3の代わりに、Anti-Alias Down Samplingを用いる。Anti-Alias Down Samplingでは以下の図のように、重み固定のBlurフィルタでダウンサンプリングを行うことで、エイリアシングを抑える。これによってネットワークのシフト不変性が向上する。

f:id:uiiurz1:20200712093522p:plain

In-Place Activated BatchNorm

 全てのBatchNorm+ActivationをInplace-ABNで置き換える。Inplace-ABNはBN+Activationを一つのinplaceな演算で実現したもの。これによってGPUのメモリ消費量が下がり、より大きなバッチサイズで学習可能になるため、学習速度を高めることできる。

github.com

Block-Type Selection

 ResNetはconv3x3->conv3x3のBasicBlockと、conv1x1->conv3x3->conv1x1のBottleneckの二つのブロックから構成されている。Bottleneckの方がBasicBlockよりもGPUの使用率が高いが、精度は高くなる。TResNetでは、効率性を高めるために、解像度の高いStage1,2はBasicBlock、Stage3,4はBottleneckを用いる。チャネル数を変えた3モデル(M,L,XL)を用意。

f:id:uiiurz1:20200712095022p:plain

Optimized SE Layers

 演算コストを抑えるためにSqueeze-and-excitationモジュールのreductionパラメータと、Blockに挿入する位置を工夫。Bottleneckではreduction=8とするとともに、チャネル数の少ないconv3x3の後段に挿入する。また、チャネル数の大きいstage4では使用しない。

f:id:uiiurz1:20200712095548p:plain

コード最適化

 コード最適化の工夫は以下の3つ。どれもPyTorchを使用する前提のもの。

  • JIT Compilation
  • Inplace Operations
  • Fast Global Average Pooling

JIT Compilation

 学習の必要ないモジュール(BlurフィルタとSpaceToDepthモジュール)にPyTorchのJITコンパイル(torch.jit.script)を用いることで、学習・推論速度を高める。

Inplace Operations

 PyTorchでは、inplaceな演算はテンソルの値をコピーすることなく直接変更するため、GPUのメモリ使用を抑えることができる。TResNetではできる限りの演算をinplaceにする。

Fast Global Average Pooling

 PyTorchにはGlobal Average Poolingを行うクラスとしてAdaptiveAvgPool2dとAvgPool2dがあるが、AvgPool2dの方が高速。しかし、さらにPyTorchのview,meanメソッドでテンソルを直接操作した方が、AvgPool2dの5倍高速であることがわかったため、これをFast Global Average Poolingとして使用する。

有効性の検証

ResNet50との比較

 入力画像サイズは224x224。TResNet50-Mは学習速度以外はResNet50に優っている。Batch Sizeは2倍になっているため、メモリ使用がかなり抑えられていることがわかる。→メモリの貧弱なGPUでも動かせる。

f:id:uiiurz1:20200712104647p:plain

Ablation Study

 どの工夫も精度or速度に寄与している。

f:id:uiiurz1:20200712105147p:plain

High-Resolution Fine-Tuning

 入力画像サイズ224x224で学習したモデルを448x448で10epochだけFine-tuningした結果。TResNet-XLは84.3%と高い精度を出している。

f:id:uiiurz1:20200712105542p:plain

EfficientNetとの比較

 精度と学習・推論速度についてEfficientNetと比較。TResNetの方が良いトレードオフを示している。

f:id:uiiurz1:20200712105811p:plain

f:id:uiiurz1:20200712105825p:plain

所感

 かなり良い性能を示しておりpretrained modelも公開されているので、kaggleなどでも使われるようになっていくかもしれない。