Understanding Knowledge Distillation in Neural Networks
Knowledge Distillation
I. Introduction
Knowledge distillation (KD) is a technique for transferring 'dark knowledge' from a complex model (the teacher) to a simpler model (the student). The teacher is typically large and expressive, while the student is smaller and more compact. Ideally, distillation brings the student's performance as close to the teacher's as possible, so that predictions of similar quality can be obtained with fewer parameters and less computation. In 2015, Hinton et al. first formally defined knowledge distillation in the paper "Distilling the Knowledge in a Neural Network", in which the student network is trained using the soft targets produced by the teacher model.
In large-scale machine learning, we usually distinguish two stages: training and deployment. During training, large amounts of computing resources are available and no real-time response is required, so the model can be trained on large amounts of data to achieve good generalization. Deployment, however, comes with many constraints, such as limited computing resources and strict latency requirements. Knowledge distillation addresses this mismatch well, which makes it an attractive method for model compression.
II. Method
In knowledge distillation, the teacher imparts knowledge to the student by having the student minimize a loss function whose target is the probability distribution predicted by the teacher. This distribution is the output of the teacher model's final softmax layer. In many cases, however, a standard softmax assigns a very large probability to the correct class and probabilities close to 0 to all other classes. Feeding such an output to the student provides little information beyond what the original labels already contain, and so fails to exploit the teacher's strong generalization ability. To solve this problem, Hinton introduced a 'softmax temperature' and modified the softmax function as follows, where $z_i$ is the logit for class $i$ and $q_i$ the resulting probability:
$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
Here T is the temperature parameter. With T = 1 this is the standard softmax function. As T increases, the softmax output becomes smoother (closer to uniform), so that more of the information held by the teacher model can be used. When training the student, the student's softmax uses the same T as the teacher's, and the loss targets the soft targets output by the teacher; we call this loss the 'distillation loss'. Hinton's paper also found that adding the correct labels (hard labels) during training improves the student's predictions. So, alongside the distillation loss, we compute a standard loss against the hard labels at T = 1, which we call the 'student loss'. Following the paper, the distillation term is multiplied by $T^2$, because the gradients produced by the soft targets scale as $1/T^2$; this keeps the relative contribution of the two terms roughly unchanged when T is varied. Writing $p^{T}$ and $q^{T}$ for the teacher's and the student's softened output distributions at temperature T, and $y$ for the hard label, the combined loss is:
$$\mathcal{L} = \alpha\, T^{2}\, \mathrm{KL}\left(p^{T} \,\|\, q^{T}\right) + (1-\alpha)\, \mathrm{CE}\left(y,\, q^{1}\right)$$
The first term is the KL divergence between the teacher's and the student's temperature-softened output distributions (computed from their logits), and the second term is the cross-entropy between the student's ordinary (T = 1) output and the correct labels. T and α are hyperparameters set by hand. The range of T used in Hinton's paper is 1 to 20. When the student model is very small relative to the teacher model, a relatively low T tends to be more effective. One possible explanation is that increasing T puts more information into the soft-label distribution, and a very small student may not have the capacity to capture all of it.
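To make the temperature softmax and the combined loss concrete, here is a minimal, library-free Python sketch for a single example. The logits, the label, and the defaults T = 4 and α = 0.9 are arbitrary assumptions for illustration, not values from Hinton's paper; the T² factor on the distillation term follows the paper's recommendation. The helper names softmax_t, kl_div and kd_loss are hypothetical.

```python
import math

def softmax_t(logits, T=1.0):
    """Temperature softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    m = max(logits)  # shift by the max logit for numerical stability
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_div(p, q):
    """KL(p || q) for discrete distributions; p is the target, q the prediction."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kd_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.9):
    """alpha * T^2 * KL(p^T || q^T) + (1 - alpha) * CE(y, q^1) for one example."""
    p_soft = softmax_t(teacher_logits, T)  # teacher soft targets
    q_soft = softmax_t(student_logits, T)  # student soft predictions
    distillation_loss = kl_div(p_soft, q_soft)
    # Hard-label cross-entropy, using the student's ordinary softmax (T = 1)
    student_loss = -math.log(softmax_t(student_logits, 1.0)[label])
    return alpha * T * T * distillation_loss + (1 - alpha) * student_loss

teacher = [6.0, 2.0, -1.0]  # hypothetical teacher logits
student = [3.0, 1.0, 0.0]   # hypothetical student logits

for T in (1.0, 4.0, 20.0):
    print(f"T={T:>4}: teacher soft targets = {[round(p, 3) for p in softmax_t(teacher, T)]}")
print(f"combined KD loss (label 0): {kd_loss(student, teacher, label=0):.4f}")
```

Printing the teacher's distribution at several temperatures shows it flattening as T grows, while the ranking of the classes stays the same. When the student's logits equal the teacher's, the KL term vanishes and only the weighted student loss remains.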
III. Gradient
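In the paper, Hinton uses a gradient calculation to show why high temperatures work well. Let $z_i$ be the student's logits and $v_i$ the teacher's logits, with $q_i$ and $p_i$ the corresponding softmax probabilities at temperature T. For the soft-target term, i.e. the cross-entropy $C$ between the teacher's soft targets $p$ and the student's soft predictions $q$, the gradient with respect to each student logit $z_i$ is:
$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)$$
When T is large compared with the magnitude of the logits (the high-temperature case), we can use $e^{x} \approx 1 + x$ for small $x$, which gives, with $N$ the number of classes:
$$\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)$$
If we further assume that the logits have zero mean for each example, i.e. $\sum_j z_j = \sum_j v_j = 0$, then:
$$\frac{\partial C}{\partial z_i} \approx \frac{1}{N T^{2}}\left(z_i - v_i\right)$$
Thus, in the high-temperature limit, distillation is equivalent to minimizing $\frac{1}{2}(z_i - v_i)^2$: the student simply learns to match the teacher's logits.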
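The derivation above can be checked numerically. The plain-Python sketch below compares the exact gradient (q_i − p_i)/T with a central finite-difference estimate and with the high-temperature approximation (z_i − v_i)/(NT²). The zero-mean logits are made up for illustration, and the function names are hypothetical.

```python
import math

def softmax_t(logits, T):
    """Temperature softmax, shifted by the max logit for numerical stability."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def soft_ce(z, v, T):
    """C = -sum_i p_i log q_i, with p from teacher logits v and q from student logits z."""
    p, q = softmax_t(v, T), softmax_t(z, T)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def analytic_grad(z, v, T):
    """Exact gradient dC/dz_i = (q_i - p_i) / T."""
    p, q = softmax_t(v, T), softmax_t(z, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

def numeric_grad(z, v, T, eps=1e-6):
    """Central finite-difference estimate of dC/dz_i."""
    grads = []
    for i in range(len(z)):
        up, down = list(z), list(z)
        up[i] += eps
        down[i] -= eps
        grads.append((soft_ce(up, v, T) - soft_ce(down, v, T)) / (2 * eps))
    return grads

def high_t_grad(z, v, T):
    """High-temperature, zero-mean approximation (z_i - v_i) / (N T^2)."""
    n = len(z)
    return [(zi - vi) / (n * T * T) for zi, vi in zip(z, v)]

v = [3.0, 0.0, -3.0]  # hypothetical teacher logits (zero mean)
z = [1.0, 1.0, -2.0]  # hypothetical student logits (zero mean)

for T in (1.0, 5.0, 50.0, 1000.0):
    exact = analytic_grad(z, v, T)
    fd_gap = max(abs(a - d) for a, d in zip(exact, numeric_grad(z, v, T)))
    approx_err = max(abs(a - b) / abs(a) for a, b in zip(exact, high_t_grad(z, v, T)))
    print(f"T={T:>6}: |exact - finite diff| = {fd_gap:.1e}, "
          f"relative error of high-T approximation = {approx_err:.4f}")
```

The exact formula agrees with the finite-difference estimate at every temperature, while the logit-matching approximation is poor at T = 1 and becomes accurate only as T grows large relative to the logits.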
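At lower temperatures, by contrast, distillation pays much less attention to matching logits that are much more negative than the average (the 'negative labels'). This can be an advantage: when the teacher is trained, these logits are almost completely unconstrained by its loss function, so they can be very noisy. On the other hand, very negative logits may still carry useful information about the knowledge the teacher has acquired. Which effect dominates is an empirical question; Hinton et al. report that when the student is much smaller than the teacher, intermediate temperatures work better than very high or very low ones.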
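The effect of temperature on very negative logits can also be seen numerically. In the hypothetical example below, the student matches the teacher on the two larger logits but is off by 6 on the very negative one (−4 instead of −10). The sketch prints the soft-target gradient (q_i − p_i)/T on that logit multiplied by T², as Hinton's paper recommends, i.e. T·(q_i − p_i). The logits and temperatures are arbitrary choices for illustration, and the function names are hypothetical.

```python
import math

def softmax_t(logits, T):
    """Temperature softmax, shifted by the max logit for numerical stability."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def scaled_soft_grad(z, v, T):
    """T^2 * dC/dz_i = T * (q_i - p_i) for student logits z and teacher logits v."""
    q, p = softmax_t(z, T), softmax_t(v, T)
    return [T * (qi - pi) for qi, pi in zip(q, p)]

teacher = [5.0, 0.0, -10.0]  # last logit is far below the average
student = [5.0, 0.0, -4.0]   # student is wrong only on that logit

for T in (1.0, 2.0, 5.0, 20.0):
    g = scaled_soft_grad(student, teacher, T)
    print(f"T={T:>4}: gradient on the very negative logit = {g[2]:.6f}")
```

At T = 1 the large mismatch on the last logit produces almost no gradient, whereas at T = 20 it produces a substantial one. This is the trade-off described above: low temperatures ignore (possibly noisy) negative logits, high temperatures force the student to match them.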
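Finally, here is a simple PyTorch snippet that implements the soft (distillation) loss in KD. It computes the cross-entropy between the teacher's and the student's softened distributions, which differs from the KL divergence only by the teacher's entropy, a constant with respect to the student. F.log_softmax is used instead of t.log(F.softmax(...) + 1e-10) because it is numerically more stable.
import torch as t
import torch.nn.functional as F
# teacher_out, student_out: logits of shape (batch, classes); opt.temperature is T
batch_soft_loss = -t.mean(t.sum(F.softmax(teacher_out / opt.temperature, dim=1) *
                                F.log_softmax(student_out / opt.temperature, dim=1), dim=1))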
IV. KD In My Work
My past research on KD fell mainly into two parts. One part applied KD to existing problems, such as abnormal transactions, OCR and NLP; the other studied new KD algorithms. In the first part, the main difficulties lay in the data and in the interpretability of the algorithm: because the original KD was used in the field of image recognition, it is difficult to transfer it to structured data or natural-language data. As for the second part, KD is currently a very active research topic in the AI industry, but many algorithmic difficulties and bottlenecks still need to be overcome.
- Kunhong Yu
- Mar 28, 2022