TResNet: High Performance GPU-Dedicated Architecture
どんなもの?
近年のDeep Learningのモデルは、ResNet50と比較して低いFLOPsで高い精度を出しているが、GPUによる学習・推論速度はResNet50と同等以下なものが多いことを指摘。本論文では、GPUに最適なネットワーク構造と実装の工夫によって、実際の学習・推論速度が高速、かつ高精度なモデルであるTResNetを提案。
ネットワーク構造
ネットワーク構造の工夫は以下の5つ。
- SpaceToDepth Stem
- Anti-Alias Down Sampling
- In-Place Activated BatchNorm
- New Block-type Selection
SpaceToDepth Stem
多くのネットワークでは、最初の数層に解像度を大きく下げる構造(例えば、ResNet50だとconv7x7(stride=2)->maxpoolの部分)が入っており、Stemと呼ばれる。TResNetではStemとしてSpaceToDepth Stemを用いる。SpaceToDepth Stemは、例えば解像度を1/2にする場合は、stride=2で抽出したピクセルをチャネル方向にconcatする。多分コードを見た方がわかりやすい。
class SpaceToDepth(nn.Module): def __init__(self, block_size=4): super().__init__() assert block_size == 4 self.bs = block_size def forward(self, x): N, C, H, W = x.size() x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) return x
図だと以下のような感じになる。
Anti-Alias Down Sampling
Stem以外の解像度を下げる機構としてstride=2のconv3x3の代わりに、Anti-Alias Down Samplingを用いる。Anti-Alias Down Samplingでは以下の図のように、重み固定のBlurフィルタでダウンサンプリングを行うことで、エイリアシングを抑える。これによってネットワークのシフト不変性が向上する。
In-Place Activated BatchNorm
全てのBatchNorm+ActivationをInplace-ABNで置き換える。Inplace-ABNはBN+Activationを一つのinplaceな演算で実現したもの。これによってGPUのメモリ消費量が下がり、より大きなバッチサイズで学習可能になるため、学習速度を高めることできる。
Block-Type Selection
ResNetはconv3x3->conv3x3のBasicBlockと、conv1x1->conv3x3->conv1x1のBottleneckの二つのブロックから構成されている。Bottleneckの方がBasicBlockよりもGPUの使用率が高いが、精度は高くなる。TResNetでは、効率性を高めるために、解像度の高いStage1,2はBasicBlock、Stage3,4はBottleneckを用いる。チャネル数を変えた3モデル(M,L,XL)を用意。
Optimized SE Layers
演算コストを抑えるためにSqueeze-and-excitationモジュールのreductionパラメータと、Blockに挿入する位置を工夫。Bottleneckではreduction=8とするとともに、チャネル数の少ないconv3x3の後段に挿入する。また、チャネル数の大きいstage4では使用しない。
コード最適化
コード最適化の工夫は以下の3つ。どれもPyTorchを使用する前提のもの。
- JIT Compilation
- Inplace Operations
- Fast Global Average Pooling
JIT Compilation
学習の必要ないモジュール(BlurフィルタとSpaceToDepthモジュール)にPyTorchのJITコンパイル(torch.jit.script)を用いることで、学習・推論速度を高める。
Inplace Operations
PyTorchでは、inplaceな演算はテンソルの値をコピーすることなく直接変更するため、GPUのメモリ使用を抑えることができる。TResNetではできる限りの演算をinplaceにする。
Fast Global Average Pooling
PyTorchにはGlobal Average Poolingを行うクラスとしてAdaptiveAvgPool2dとAvgPool2dがあるが、AvgPool2dの方が高速。しかし、さらにPyTorchのview,meanメソッドでテンソルを直接操作した方が、AvgPool2dの5倍高速であることがわかったため、これをFast Global Average Poolingとして使用する。
有効性の検証
ResNet50との比較
入力画像サイズは224x224。TResNet50-Mは学習速度以外はResNet50に優っている。Batch Sizeは2倍になっているため、メモリ使用がかなり抑えられていることがわかる。→メモリの貧弱なGPUでも動かせる。
Ablation Study
どの工夫も精度or速度に寄与している。
High-Resolution Fine-Tuning
入力画像サイズ224x224で学習したモデルを448x448で10epochだけFine-tuningした結果。TResNet-XLは84.3%と高い精度を出している。
EfficientNetとの比較
精度と学習・推論速度についてEfficientNetと比較。TResNetの方が良いトレードオフを示している。
所感
かなり良い性能を示しておりpretrained modelも公開されているので、kaggleなどでも使われるようになっていくかもしれない。