2022년 1월 7일 금요일

Folding(Fusing) Batch normalization

1. Batch normalization


$\gamma,\beta$는 학습중에 지속적으로 업데이트되는 trainable한 변수입니다.
$\mu,\sigma$은 mini batch에서 계산되는 평균값과 표준편차값입니다.
$\mu_{avg}, \sigma_{avg}$은 inference과정에서 사용하기위해 moving average로 축적하는 값입니다.






2. Folding(Fusing) batch normalization

<inception block>

일반적으로 위와 같이 convolution layer와 batch normalization, 그리고 activation function을 연달아 사용하는 
block형태로 일반화되어 사용하는데 Folding(Fusing) batch normalization은 convolution layer와 batch normalization layer가 
결합되어 최적화를 한 layer를 말합니다.


$1.$   $y_{conv}=Wx+b$
-  convolution layer는 위의 식으로 표현됩니다.

$2.$   $\hat{x}=\frac{x-\mu}{\sqrt{\sigma^{2}+\epsilon}}$
$3.$   $y=\gamma\hat{x}+\beta$
-  batch normalization layer는 위의 식으로 표현됩니다.

$4.$   $y=\frac{\gamma(x-\mu)}{\sqrt{\sigma^{2}+\epsilon}}+\beta$                     $\because(2.\cup3.)$
$5.$   $y=\frac{\gamma(Wx+b-\mu)}{\sqrt{\sigma^{2}+\epsilon}}+\beta$                     $\because(1.\cup4.)$
-  batch normalization layer의 input이 convolution layer의 output이 되기 때문에 두 과정을 결합할 수 있습니다.
 
$6.$   $y=\frac{\gamma W}{\sqrt{\sigma^{2}+\epsilon}}x+\gamma\frac{b-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta$     $\approx$    $y_{inference}=\frac{\gamma W}{\sqrt{\sigma_{avg}^{2}+\epsilon}}x_{inference}+\gamma\frac{b-\mu_{avg}}{\sqrt{\sigma_{avg}^{2}+\epsilon}}+\beta$      $\approx$    $y_{inference}=\dot{W}x_{inference}$ $+\dot{b}$
-  결국 최종형태의 식은 $y=\dot{W}x$ $+\dot{b}$형태가 되는데 inference과정에서의 $\dot{W}, \dot{b}$값은 모두 상수이기 때문에
   시간복잡도 $O(1)$로 굉장히 빨리 계산할 수 있게됩니다.





3. Tensorflow code
tensorflow2.7버젼 기준으로 batch normalization은 위의 API로 정의 되어있습니다.
앞서 설명한 fusing관련된 argument가 없어보이지만 실제로는 내부적으로 wrapping 돼 있어 
fused = False # or True
와 같이 argument를 추가적으로 명시해줄 수 있습니다.


fused의 코드들을 쭉 따라가다보면  아래와 같은 코드를 볼 수 있게되는데 (https://github.com/tensorflow/tensorflow/blob/v2.7.0/tensorflow/python/ops/nn_impl.py#L1586-L1691)
@tf_export("nn.batch_normalization")
@dispatch.add_dispatch_support
def batch_normalization(x,
                        mean,
                        variance,
                        offset,
                        scale,
                        variance_epsilon,
                        name=None):
  ...
  ...
  y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
      x,
      scale,
      offset,
      mean,
      variance,
      epsilon=epsilon,
      exponential_avg_factor=exponential_avg_factor,
      data_format=data_format,
      is_training=is_training,
      name=name)
  return y, running_mean, running_var
더이상 fused_batch_norm_v3를 실행하는 gen_nn_ops라는 라이브러리를 따라갈 수 없게됩니다.
그 이유는 bazel빌드를 통해서 만들어지는 파일이기 때문입니다.

실제 코드는 C++로 구현이 되어있으며 gen_nn_ops는 단순히 python으로 wrapping되어있는것 뿐인것이죠.


fused_batch_norm API의 자세한 C++코드를 보고 싶으다면 아래의 링크에서 확인할 수 있습니다.





[Reference]

댓글 없음:

댓글 쓰기