1. Batch normalization
$\gamma,\beta$는 학습중에 지속적으로 업데이트되는 trainable한 변수입니다.
$\mu,\sigma$은 mini batch에서 계산되는 평균값과 표준편차값입니다.
$\mu_{avg}, \sigma_{avg}$은 inference과정에서 사용하기위해 moving average로 축적하는 값입니다.
(Paper: https://arxiv.org/pdf/1502.03167.pdf)
2. Folding(Fusing) batch normalization
<inception block>
일반적으로 위와 같이 convolution layer와 batch normalization, 그리고 activation function을 연달아 사용하는
block형태로 일반화되어 사용하는데 Folding(Fusing) batch normalization은 convolution layer와 batch normalization layer가
결합되어 최적화를 한 layer를 말합니다.
$1.$ $y_{conv}=Wx+b$
- convolution layer는 위의 식으로 표현됩니다.
$2.$ $\hat{x}=\frac{x-\mu}{\sqrt{\sigma^{2}+\epsilon}}$
$3.$ $y=\gamma\hat{x}+\beta$
- batch normalization layer는 위의 식으로 표현됩니다.
$4.$ $y=\frac{\gamma(x-\mu)}{\sqrt{\sigma^{2}+\epsilon}}+\beta$ $\because(2.\cup3.)$
$5.$ $y=\frac{\gamma(Wx+b-\mu)}{\sqrt{\sigma^{2}+\epsilon}}+\beta$ $\because(1.\cup4.)$
- batch normalization layer의 input이 convolution layer의 output이 되기 때문에 두 과정을 결합할 수 있습니다.
$6.$ $y=\frac{\gamma W}{\sqrt{\sigma^{2}+\epsilon}}x+\gamma\frac{b-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta$ $\approx$ $y_{inference}=\frac{\gamma W}{\sqrt{\sigma_{avg}^{2}+\epsilon}}x_{inference}+\gamma\frac{b-\mu_{avg}}{\sqrt{\sigma_{avg}^{2}+\epsilon}}+\beta$ $\approx$ $y_{inference}=\dot{W}x_{inference}$ $+\dot{b}$
- 결국 최종형태의 식은 $y=\dot{W}x$ $+\dot{b}$형태가 되는데 inference과정에서의 $\dot{W}, \dot{b}$값은 모두 상수이기 때문에
시간복잡도 $O(1)$로 굉장히 빨리 계산할 수 있게됩니다.
3. Tensorflow code
tensorflow2.7버젼 기준으로 batch normalization은 위의 API로 정의 되어있습니다.앞서 설명한 fusing관련된 argument가 없어보이지만 실제로는 내부적으로 wrapping 돼 있어
fused = False # or True
와 같이 argument를 추가적으로 명시해줄 수 있습니다.
fused의 코드들을 쭉 따라가다보면 아래와 같은 코드를 볼 수 있게되는데 (https://github.com/tensorflow/tensorflow/blob/v2.7.0/tensorflow/python/ops/nn_impl.py#L1586-L1691)
@tf_export("nn.batch_normalization")
@dispatch.add_dispatch_support
def batch_normalization(x,
mean,
variance,
offset,
scale,
variance_epsilon,
name=None):
...
...
y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training,
name=name)
return y, running_mean, running_var
더이상 fused_batch_norm_v3를 실행하는 gen_nn_ops라는 라이브러리를 따라갈 수 없게됩니다.
그 이유는 bazel빌드를 통해서 만들어지는 파일이기 때문입니다.
실제 코드는 C++로 구현이 되어있으며 gen_nn_ops는 단순히 python으로 wrapping되어있는것 뿐인것이죠.
fused_batch_norm API의 자세한 C++코드를 보고 싶으다면 아래의 링크에서 확인할 수 있습니다.
[Reference]
댓글 없음:
댓글 쓰기