이 논문은 제프리 힌튼과 제프 딘 등이 참여한 논문으로 지식증류(knowledge distillation) 기법을 다룬 연구이다. 모델 학습할때는 최대한 많은 데이터를 크고 깊은 모델에 학습시키면 좋은 성능의 결과가 나오겠지만, 실제 서비스에 배포할때 큰 모델을 사용하면 비용과 추론시간(latency)의 문제가 생긴다는 문제점을 해결한다.
직관적으로 논문을 설명하자면 학습은 크게, 서비스는 작게 하면 효과적이다. 이에 큰 모델(cumbersome model, 추후에는 teacher model)을 학습해 지식을 증류(distilled model, student model)한다.
이 연구에서는 soft target이라는 개념을 도입한다. 딥러닝 학습과정 중 클래스 분류를 위한 Cross entropy 비용함수를 최적화 할때 정답 클래스가 가진 특징을 학습하여 비슷한 특징을 가진 클래스도 학습할 수 있도록 한다. 하이퍼파라미터 Temperature(T)를 도입해 비용함수의 확률분포를 부드럽게한다(soft). 정답 클래스가 최대화될지라도 logit을 T로 나누어 최소한의 특징정보를 보존하는 것이다.
논문의 실험을 통해 MNIST 데이터셋에서 숫자 2 클래스의 데이터를 학습하면 숫자 3, 7과 비슷한 특징을 어느정도 가지기 때문에, 3과 7도 실제로 추론이 가능한 모습도 보여준다.
작은 모델은 soft target, hard target(실제 정답)을 참고해 학습하는데, 더 적은 데이터로 큰 모델과 유사한 일반화 능력을 갖게 된다.
논문에는 수식이 나와있지 않고 설명이 되어있지만 코드 실습을 하기 위해 chatGPT에 질문하면 해당 수식은 L=(1−α)⋅CE(hard labels)+α⋅KL(soft targets)이다.. 논문의 2장과 2.1장에서 Student 모델이 Teacher모델의 soft target 학습을 유도하는 과정이 설명되었다는 것이다. 논문에서 직접 KLDivergence를 명시하지 않았지만 soft target을 학습하는 과정에서 사실상 사용하고 있다고 한다.. (어렵다, 정보이론 개념 다시공부 필요)
지식 증류를 위한 수식은 아래 그림과 같다. T 값이 1보다 커지면 클래스 i의 확률분포가 더 부드러워진다. 큰 모델과 증류모델을 학습할때 높은 온도(high temperature)를 사용하고 학습이 완료되면 1을 값으로 한다.
전이 데이터셋(transfer set)에 정답 레이블이 포함되면 증류모델이 올바른 정답을 출력하는데 도움이 된다.
그 방법중 하나는 정답레이블을 이용해 소프트타겟을 변형하는 것인데, 저자들은 더 나은 방법인 두개의 목적함수의 가중합을 사용하였다. 첫 번째 목적함수는 cross entropy with 소프트 타겟으로 높은 temperature로 소프트 타겟에 사용된다. 두번째 목적함수는 cross entropy with 정답 레이블(correct labels)으로 같은 로짓을 사용하며 temperature는 1로 고정된다. 저자들은 두번째 목적함수의 비중이 낮을때 가장 좋은 결과를 얻었다고 한다. 소프트 타겟으로 계산된 기울기(gradient)는 1/T^2가 되어, soft, hard 둘다 사용할때는 T의 제곱을 곱해 두 타겟의 상대적인 기여도를 보정하였다.
(?로짓: 딥러닝 모델의 출력층의 결과로 softmax함수에 들어가기 전 클래스별로 각 한개씩 갖고 있는 값, 클래스에 대한 확신의 정도)
수식 2~4를 통해 로짓과 소프트 타겟의 관계에 대해 확인할 수 있다. 크로스 엔트로피 기울기는 각 증류모델의 로짓과 관련이 있는데(z_i), 증류모델의 크로스 엔트로피 결과 q와 큰모델 (cumbersome model, soft target사용)에 대한 크로스엔트로피 결과 p의 차이를 T로 나눈것과 같다(수식 2).
T가 로짓에 비해 크다면 수식 3과 같이 대략적으로 계산할 수 있다. 3의 수식에서 z의 합과 v의 합이 같아져 로짓의 평균이 0이 되면 수식 4처럼 단순화시킬 수 있다. T가 낮으면 증류모델은 로짓을 매칭하는데 평균보다 낮은(음수)인 경우 로짓을 정확히 맞추는데 덜 집중한다. T가 크면 다른 클래스들을 더 집중한다는 것이다.
연구진은 지식 증류의 효과를 보이고자 MNIST, 음성인식(speech recognition), 대규모 데이터(google JFT)에 대한 앙상블 모델에 대해 비교하고 regularizer로써 소프트 타겟을 실험하였다. 전문가 모델(specialist model)을 도입해 버섯, 표고버섯, 팽이버섯 등 혼동되는 클래스를 구분하는 실험도 진행했다.
실험결과는 논문 참조
Reference
https://arxiv.org/abs/1503.02531
'papers' 카테고리의 다른 글
You Only Look Once: Unified, Real-Time Object Detection (0) | 2024.12.18 |
---|---|
Deep Residual Learning for Image Recognition (0) | 2024.12.10 |
Learning Deep Features for Discriminative Localization (1) | 2024.11.30 |
XGBoost: A Scalable Tree Boosting System (3) | 2024.11.07 |
ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION (3) | 2024.10.16 |