제공하는 Quantization aware training기법을 제시한 논문입니다.
대체로 큰 모델의 경우 post training quantization을 한 후에도 정확도는 크게 떨어지지 않지만
가벼운 모델(mobilenet, efficientnet...)은 정확도가 많이 떨어질 수 있는데 QAT로 이러한 문제를 해결할 수 있습니다.
hardware에 좀 더 친화적인(friendly)방식으로 deploy하고자 하는 hardware에 맞추기 때문에
post training quantization적용 후 정확도에 대한 괴리감이 많이 없어지게 됩니다.
해당 논문에서는 이론적인 내용뿐만 아니라 Tensorflow에도 바로 적용할 수 있는 API를 제공해주고 있습니다.
기본적인 개념과 이 논문에서 contribution한 내용들에대해 자세히 알아보도록 하겠습니다.
[Quantization error]
보통 32bit의 모델을 8bit로 quantization을 하곤하는데 정보손실로 인해 quantization error가 발생합니다.
위의 그림이 quantization과 dequantization으로 인한 quantization error를 잘 말해주고 있습니다.
[Limitations of post training quantization]
post training quantization은 말 그대로 학습후에 모델을 quantization하는것을 의미합니다.
해당논문에서는 두 가지 이유로 accuracy drop이 발생한다고 제시하고 있습니다.
1. 첫 번째는 output channel에서 발생하는 서로 다른 범위가 매우 차이나는 것입니다.
예를 들어 weight tensor가 각각 [-2.3, 2.1], [-0.7, 0.5], [-4.7, 3.8]의 범위를 갖고 있다고 했을 때
해당 weight는 각 채널 범위에서 가장 큰 범위인 [-4.7, 3.8]을 기준으로 quantization이 될것입니다.
여기서 눈여겨볼것은 두 번째 채널은 [-0.7, 0.5]의 범위로 상당히 작은 범위이기 때문에
quantization이 될 때 quantization error가 크게 발생하게 됩니다.
[이 문제를 보완한 per-channel quantization이라는 기법도 있으니 관심있으시면 참고바랍니다.]
2. 두 번째는 weight의 가장자리 값(outlier weight values)으로 인해 quantization시 정확도가 떨어진다고 말하고있습니다.
위의 그래프는 weight tensor의 분포를 나타낸 예시인데 노란색으로 표시된 부분처럼 값이 큰 tensor들이 비중을 많이 차지하고 있지 않지만 quantization 범위를 크게 만들어 정확도를 떨어뜨리는 효과를 만들어냅니다.
[Quantization aware training: training process]
Quantization aware training방식은 integer-only hardware의 연산과 quantized된 모델의 형태를 학습과정에서부터
반영한다고 생각하면 이해하기가 쉽습니다.
integer hardware의 환경과 quantization을 simulation했기 때문에 quantization simulation이라고도 부릅니다.
또한 quantization의 효과를 간접적으로 반영하기 때문에 fake quantization이라고도 불립니다.
QAT에서는 위의 그림(b)처럼 학습하는 과정에서 weights와 activation output에 대해서 fake quantization node(wt quant, act quant)를
추가해 quantization효과를 줍니다.
quantization 효과란 32bit의 실수가 8bit의 정수가 되고 다시 32bit 실수로 변환되는
과정에서 발생하는 clamping, rounding효과를 의미합니다.
위의 수식에서 언급하는 값들의 정의는 다음과 같습니다.
$r:$변환하고자 하는 실수
$a, b:$ quantization하고자 하는 실수 범위
$s:$ quantization scale
Example)
$r=1.0, a=-3.2, b=1.3, n=2^{8}(8bit)$
$s(a,b,n)=\frac{1.3-(-3.2)}{2^{8}-1}=0.0176$
$q(r;a,b,n)=\left \lfloor (1.0-(-3.2))/0.0176 \right \rceil *0.0176+(-3.2)=1.0064$
위의 수식에대한 예시를 들어보면, quantization error가 0.0064만큼 발생하는것을 볼 수 있습니다.
즉, QAT는 quantization error가 반영된 값을 학습하는데 의의가 있습니다.
weight quantization과 activation output quantization에서 서로 다른 방법으로 $a, b$를 계산합니다.
- weight quantization: $a=min(w), b=max(w)$, int 8bit quantization 기준으로 뺄셈 최적화를 위해 [-127, 127]로 범위를 정함
- activation quantization: 범위를 어느정도 smooth하게 하기 위해 EMA로 $a, b$를 학습중에 계산함 (범위가 급격하게 바뀔 때 초반 학습 단계에서 activation quantization이 불가능해짐)
[Quantization aware training: inference process]
기존의 fake quantization node (wt quant, act quant)를 제거하고 quantization range 정보를 모델에 quantization할 때 반영합니다.
그렇게 만들어진 quantization 모델을 inference합니다. 이 과정은 target hardware(DSP / EdgeTPU)의 실제 연산과정과 유사합니다.
다음장에서 설명 이어나가겠습니다.
(Reference)