DeiT-LT: Distillation Strikes Back for Vision Transformer Training on Long-Tailed Datasets

CVPR 2024
¹Vision and AI Lab, Indian Institute of Science; ²Indian Institute of Technology, Kharagpur

TL;DR: We propose DeiT-LT, an out-of-distribution (OOD) distillation framework that distills from CNNs trained via SAM. This leads to (a) local, generalizable features in the early blocks of the ViT, (b) low-rank features across ViT blocks, which improve generalization, and (c) significantly better performance on minority classes than SOTA methods.


Abstract

Vision Transformer (ViT) has emerged as a prominent architecture for various computer vision tasks. In ViT, we divide the input image into patch tokens and process them through a stack of self-attention blocks. However, unlike Convolutional Neural Networks (CNNs), ViT's simple architecture has no informative inductive bias (e.g., locality). This causes ViTs to require a large amount of data for pre-training. Data-efficient approaches such as DeiT have been proposed to train ViTs effectively on balanced data. However, little work has studied ViTs on datasets with long-tailed imbalance. In this work, we introduce DeiT-LT to tackle the problem of training ViTs from scratch on long-tailed datasets. In DeiT-LT, we introduce an efficient and effective way to distill from a CNN via the distillation (DIST) token, using out-of-distribution images and re-weighting the distillation loss to enhance the focus on tail classes. This leads to the learning of local, CNN-like features in early ViT blocks, improving generalization for tail classes. Further, to mitigate overfitting, we propose distilling from flat CNN teachers, which leads to learning low-rank, generalizable features for the DIST token across all ViT blocks. With the proposed DeiT-LT scheme, the DIST token becomes an expert on the tail classes and the classifier (CLS) token becomes an expert on the head classes. These experts effectively learn features related to both the majority and minority classes using a distinct set of tokens within the same ViT architecture. We show the effectiveness of DeiT-LT for training ViTs from scratch on datasets ranging from small-scale CIFAR-10 LT to large-scale iNaturalist-2018.



Proposed Method


  • Distillation via Out-of-Distribution images: We propose to distill knowledge from a CNN teacher through OOD images generated using CutMix and Mixup. The distillation is done via the DIST token as follows:
    $$\mathcal{L}_{dist} = \mathcal{L}_{CE}\big(f^{d}(x), y_{t}\big), \quad y_t = \arg\max_{i} g(x)_{i}$$
    where $f^{d}(x)$ is the prediction of the student's DIST head and $g(x)$ is the CNN teacher's output on the augmented image $x$. The out-of-distribution distillation yields diverse experts, which become even more diverse when deferred re-weighting (DRW) is applied to the distillation loss.

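  • Deferred Re-Weighting (DRW) for distillation: We apply DRW to the distillation loss to encourage the DIST token to focus on the tail classes. This leads to diverse CLS and DIST tokens that specialize in the majority and minority classes, respectively. The overall loss combines the ground-truth loss on the CLS head $f^{c}$ with the re-weighted distillation loss on the DIST head $f^{d}$:
    $$\mathcal{L} = \frac{1}{2}\Big\{\mathcal{L}_{CE}\big(f^{c}(x), y\big) + \mathcal{L}_{DRW}\big(f^{d}(x), y_{t}\big)\Big\}$$
    $$\mathcal{L}_{DRW} = -\,w_{y_t} \log\big(f^{d}(x)_{y_t}\big)$$
    $$w_y = \frac{1}{1 + (e_y - 1)\,\mathbb{1}_{\mathrm{epoch} \geq K}}, \quad \text{where } e_y = \frac{1-\beta^{N_y}}{1-\beta}$$
    Here $N_y$ is the number of training images in class $y$, $\beta \in [0, 1)$ is a hyperparameter, and $K$ is the epoch at which re-weighting starts. Before epoch $K$ every class has weight 1; afterwards $w_y = 1/e_y$, so rarer classes (smaller effective number $e_y$) receive larger weights.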

  • Low-rank features via SAM teachers: To further improve generalization to tail classes, we propose distilling from teachers that have been trained with Sharpness-Aware Minimization (SAM).
    We show the rank of the DIST-token features and find that students distilled from SAM-trained teachers learn lower-rank features than the baselines.

  • Induction of local features: In the early blocks (1 and 2), DeiT-LT has attention heads that, like a CNN, attend locally to the neighborhood of each patch, and hence it learns more local, generalizable features.
    We plot the mean attention distance of the patches in early self-attention blocks 1 (solid) and 2 (dashed) for DeiT-LT and the baselines, and find that DeiT-LT learns highly local, generalizable features.
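
The distillation and DRW bullets combine into a single per-sample objective. The sketch below is a minimal pure-Python illustration, assuming each head returns raw logits for one Mixup-augmented image. The helpers `mixup`, `linear`, `drw_weights`, and `deit_lt_loss`, the toy weight matrices, and the values β = 0.999 and K = 160 are hypothetical choices for illustration, not the authors' implementation. CutMix would paste a rectangular patch from one image into the other instead of blending pixels.

```python
import math
from typing import List, Sequence, Tuple


def mixup(x1: Sequence[float], x2: Sequence[float], lam: float) -> List[float]:
    """Blend two flattened images; the mixed image is typically out of distribution."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]


def linear(weight: Sequence[Sequence[float]], x: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for a network head: logits = W x."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def drw_weights(counts: Sequence[int], beta: float, epoch: int, k: int) -> List[float]:
    """w_y = 1 / {1 + (e_y - 1) * 1[epoch >= K]}, with e_y = (1 - beta^N_y) / (1 - beta)."""
    on = 1.0 if epoch >= k else 0.0
    weights = []
    for n_y in counts:
        e_y = (1.0 - beta ** n_y) / (1.0 - beta)  # effective number of samples
        weights.append(1.0 / (1.0 + (e_y - 1.0) * on))
    return weights


def deit_lt_loss(cls_logits: Sequence[float], dist_logits: Sequence[float],
                 teacher_logits: Sequence[float], y: int,
                 weights: Sequence[float]) -> Tuple[float, int]:
    """Per-sample L = 1/2 * {CE(f^c(x), y) + L_DRW(f^d(x), y_t)}; returns (loss, y_t)."""
    # Hard teacher label on the augmented image: y_t = argmax_i g(x)_i
    y_t = max(range(len(teacher_logits)), key=lambda i: teacher_logits[i])
    ce_cls = -log_softmax(cls_logits)[y]                     # CLS head vs. ground truth
    l_drw = -weights[y_t] * log_softmax(dist_logits)[y_t]    # DIST head vs. teacher label
    return 0.5 * (ce_cls + l_drw), y_t


if __name__ == "__main__":
    counts = [500, 50, 5]  # hypothetical head / mid / tail training-set sizes
    weights = drw_weights(counts, beta=0.999, epoch=200, k=160)
    x = mixup([1.0, 0.0], [0.0, 1.0], lam=0.6)  # toy two-pixel "images"
    teacher = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    cls_head = [[2.0, 0.1], [0.1, 2.0], [0.3, 0.3]]
    dist_head = [[0.5, 0.5], [1.0, 0.2], [0.2, 1.0]]
    # Simplification: the CLS term uses the hard label of the dominant image (y = 0)
    # rather than a Mixup soft label.
    loss, y_t = deit_lt_loss(linear(cls_head, x), linear(dist_head, x),
                             linear(teacher, x), y=0, weights=weights)
    print(f"y_t = {y_t}, w = {[round(w, 4) for w in weights]}, loss = {loss:.4f}")
```

Because the teacher only contributes an argmax label, any CNN teacher (for example one trained with SAM) can be plugged in without matching the student's logit scale.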

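For reference, one SAM update first moves the weights to an approximate worst-case point within a radius r, ε = r ∇L(w) / ‖∇L(w)‖, and then descends from the original weights using the gradient at w + ε. The sketch below applies this to a hypothetical two-parameter quadratic with one sharp and one flat direction; the loss, learning rate, radius, and the names `sam_step` and `quad_grad` are illustrative assumptions. Training a CNN teacher with SAM would apply the same two-step update to all of its parameters.

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_step(w: Vector, grad_fn: Callable[[Vector], Vector], lr: float, radius: float) -> Vector:
    """One Sharpness-Aware Minimization update with plain gradient descent as the base optimizer."""
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # guard against a zero gradient
    eps = [radius * gi / norm for gi in g]  # first-order worst-case perturbation
    w_adv = [wi + ei for wi, ei in zip(w, eps)]
    g_adv = grad_fn(w_adv)  # gradient at the perturbed point
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]  # step from the original weights


def quad_grad(w: Vector) -> Vector:
    """Gradient of the toy loss L(w) = 0.5 * (10 * w0**2 + w1**2): w0 is sharp, w1 is flat."""
    return [10.0 * w[0], w[1]]


if __name__ == "__main__":
    w = [1.0, 1.0]
    for _ in range(50):
        w = sam_step(w, quad_grad, lr=0.05, radius=0.05)
    print(w)
```

Because the step uses the gradient at the perturbed point, directions in which the loss rises quickly within the radius are penalized more, which is what biases SAM toward flat minima.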

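Mean attention distance is commonly computed as the attention-weighted Euclidean distance between a query patch and every key patch, averaged over query patches. A minimal sketch for a single head, assuming a square grid of patches and an attention matrix from which the CLS and DIST tokens have been removed and whose rows are renormalized to sum to 1; `mean_attention_distance` is a hypothetical helper, not taken from the DeiT-LT code. A head that attends only to nearby patches yields a small value, while uniform (global) attention yields a large one.

```python
import math
from typing import Sequence


def mean_attention_distance(attn: Sequence[Sequence[float]], grid: int, patch_size: int) -> float:
    """Attention-weighted query-to-key distance in pixels, averaged over query patches.

    attn[i][j] is one head's attention from patch i to patch j on a grid x grid layout.
    """
    n = grid * grid
    assert len(attn) == n and all(len(row) == n for row in attn)
    coords = [(p // grid, p % grid) for p in range(n)]  # (row, col) of each patch
    total = 0.0
    for i, row in enumerate(attn):
        ri, ci = coords[i]
        for j, a in enumerate(row):
            rj, cj = coords[j]
            total += a * patch_size * math.hypot(ri - rj, ci - cj)
    return total / n


if __name__ == "__main__":
    n = 9  # 3 x 3 patch grid
    local = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]  # attends to itself only
    uniform = [[1.0 / n] * n for _ in range(n)]  # global, content-independent attention
    print(mean_attention_distance(local, 3, 16), mean_attention_distance(uniform, 3, 16))
```
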
Attention Visualisation

Visual comparison of the attention maps from ViT-B, DeiT-III and DeiT-LT (ours) on the ImageNet-LT dataset, computed using the method of Attention Rollout (ACL'2020).
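
Attention Rollout propagates attention through the network: each block's attention is averaged over heads, mixed with the identity (0.5·A + 0.5·I) to account for the residual connection, re-normalized row-wise, and multiplied across blocks. A minimal pure-Python sketch, assuming the per-head attention matrices of every block (including the CLS and DIST tokens) have already been extracted; `attention_rollout` is a hypothetical helper, not part of any library. The row of the CLS or DIST token in the result, reshaped to the patch grid, gives a map of the kind compared above.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def attention_rollout(layers: Sequence[Sequence[Matrix]]) -> Matrix:
    """layers[l][h] is head h's (tokens x tokens) attention matrix in block l."""
    n = len(layers[0][0])
    rollout = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for heads in layers:
        # Average the heads of this block.
        avg = [[sum(h[i][j] for h in heads) / len(heads) for j in range(n)] for i in range(n)]
        # Residual connection: 0.5 * A + 0.5 * I, then re-normalize each row.
        mixed = [[0.5 * avg[i][j] + (0.5 if i == j else 0.0) for j in range(n)] for i in range(n)]
        sums = [sum(row) for row in mixed]
        mixed = [[v / s for v in row] for row, s in zip(mixed, sums)]
        # Compose with the earlier blocks: R_l = A~_l @ R_{l-1}.
        rollout = matmul(mixed, rollout)
    return rollout


if __name__ == "__main__":
    # Two blocks with one head each over three tokens (e.g. CLS plus two patches).
    block = [[0.6, 0.2, 0.2], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
    r = attention_rollout([[block], [block]])
    print([round(v, 3) for v in r[0]])  # CLS row: relevance of each token
```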


Main Results

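1. Results on CIFAR-10 LT and CIFAR-100 LT: Performance (%) of the proposed approach DeiT-LT compared to existing methods on CIFAR-10 LT and CIFAR-100 LT with imbalance ratios $\rho=100$ and $\rho=50$ (ratio of the largest to the smallest class size). Superscripts indicate teacher-student pairs.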

| Method | CIFAR-10 LT, $\rho=100$ | CIFAR-10 LT, $\rho=50$ | CIFAR-100 LT, $\rho=100$ | CIFAR-100 LT, $\rho=50$ |
|---|---|---|---|---|
| ResNet-32 backbone | | | | |
| CB Focal Loss (CVPR'2019) | 74.6 | 79.3 | 38.3 | 46.2 |
| CAM (AAAI'2021) | 80.0 | 83.6 | 47.8 | 51.7 |
| GCL (CVPR'2022) | 82.7 | 85.5 | 48.7 | 53.6 |
| LDAM+DRW+SAM¹ (NeurIPS'2022) | 81.9 | 84.8 | 45.4 | 49.4 |
| PaCo+SAM² (ICCV'2021, NeurIPS'2022) | 86.8 | 88.6 | 52.8 | 56.6 |
| ViT backbone | | | | |
| ViT (ICLR'2021) | 62.6 | 70.1 | 35.0 | 39.0 |
| ViT (cRT) (ICLR'2020) | 68.9 | 74.5 | 38.9 | 42.2 |
| DeiT (ICML'2021) | 70.2 | 77.5 | 31.3 | 39.1 |
| DeiT-III (ECCV'2022) | 59.1 | 68.2 | 38.1 | 44.1 |
| DeiT-LT¹ (ours) | 84.8 | 87.5 | 52.0 | 54.1 |
| DeiT-LT² (ours) | 87.5 | 89.8 | 55.6 | 60.5 |

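The imbalance ratio ρ is the size of the largest class divided by the size of the smallest. CIFAR-LT benchmarks are commonly built by shrinking the per-class training sets exponentially, $n_i = n_{\max}\,\rho^{-i/(C-1)}$ for class index $i = 0, \dots, C-1$; this profile is an assumption here, since the table does not state the construction. A minimal sketch with $n_{\max} = 5000$ (CIFAR-10) or $500$ (CIFAR-100), the per-class training-set sizes of the balanced datasets; `long_tailed_counts` is a hypothetical helper.

```python
from typing import List


def long_tailed_counts(n_max: int, num_classes: int, rho: float) -> List[int]:
    """Per-class training sizes decaying exponentially from n_max to about n_max / rho."""
    return [int(n_max * rho ** (-i / (num_classes - 1))) for i in range(num_classes)]


if __name__ == "__main__":
    for name, n_max, classes in [("CIFAR-10 LT", 5000, 10), ("CIFAR-100 LT", 500, 100)]:
        for rho in (100, 50):
            counts = long_tailed_counts(n_max, classes, rho)
            print(name, rho, counts[0], counts[-1], sum(counts))
```

The resulting counts play the role of $N_y$ in the DRW weights.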
2. Results on ImageNet-LT: Performance (%) of the proposed approach DeiT-LT compared to existing methods on ImageNet-LT. Superscripts indicate teacher-student pairs.

| Method | Overall | Head | Mid | Tail |
|---|---|---|---|---|
| ResNet50 backbone | | | | |
| LDAM (NeurIPS'2019) | 49.8 | 60.4 | 46.9 | 30.7 |
| RIDE (3 exp.) (ICLR'2021) | 54.9 | 66.2 | 51.7 | 34.9 |
| BCL (NeurIPS'2020) | 57.1 | 67.9 | 54.2 | 36.6 |
| LDAM+DRW+SAM¹ (NeurIPS'2022) | 53.1 | 62.0 | 52.1 | 32.8 |
| PaCo+SAM² (ICCV'2021, NeurIPS'2022) | 57.5 | 62.1 | 58.8 | 39.3 |
| ViT backbone | | | | |
| ViT (ICML'2021) | 37.5 | 56.9 | 30.4 | 10.3 |
| DeiT-III (ECCV'2022) | 48.4 | 70.4 | 40.9 | 12.8 |
| DeiT-LT¹ (ours) | 55.6 | 65.2 | 54.0 | 37.1 |
| DeiT-LT² (ours) | 59.1 | 66.6 | 58.3 | 40.0 |

3. Results on iNaturalist-2018: Performance (%) of the proposed approach DeiT-LT compared to existing methods on iNaturalist-2018. Superscripts indicate teacher-student pairs.

| Method | Overall | Head | Mid | Tail |
|---|---|---|---|---|
| ResNet50 backbone | | | | |
| cRT (ICLR'2020) | 65.2 | 69.0 | 66.0 | 63.2 |
| RIDE (3 exp.) (ICLR'2021) | 72.2 | 70.2 | 72.2 | 72.7 |
| CBDENS (BMVC'2021) | 73.6 | 75.9 | 74.7 | 71.5 |
| LDAM+DRW+SAM¹ (NeurIPS'2022) | 70.1 | 64.1 | 70.5 | 71.2 |
| PaCo+SAM² (ICCV'2021, NeurIPS'2022) | 73.4 | 66.3 | 73.6 | 75.2 |
| ViT backbone | | | | |
| ViT (ICML'2021) | 54.2 | 64.3 | 53.9 | 52.1 |
| DeiT-III (ECCV'2022) | 61.0 | 72.9 | 62.8 | 55.8 |
| DeiT-LT¹ (ours) | 72.9 | 69.0 | 73.3 | 73.3 |
| DeiT-LT² (ours) | 75.1 | 70.3 | 75.2 | 76.2 |
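
The Head, Mid, and Tail columns report accuracy on groups of classes binned by training-set size. A minimal sketch, assuming the many-shot (more than 100 training images), medium-shot (20 to 100), and few-shot (fewer than 20) thresholds commonly used for ImageNet-LT; the exact thresholds behind these tables are not stated here, and `group_of` and `split_accuracies` are hypothetical helpers.

```python
from typing import Dict, List, Sequence


def group_of(n_train: int) -> str:
    """Assumed binning: Head > 100, Mid 20 to 100, Tail < 20 training images."""
    if n_train > 100:
        return "Head"
    if n_train >= 20:
        return "Mid"
    return "Tail"


def split_accuracies(train_counts: Sequence[int], correct: Sequence[int],
                     total: Sequence[int]) -> Dict[str, float]:
    """Top-1 accuracy (%) overall and per group from per-class test-set tallies."""
    tallies: Dict[str, List[int]] = {"Head": [0, 0], "Mid": [0, 0], "Tail": [0, 0]}
    for n, c, t in zip(train_counts, correct, total):
        bucket = tallies[group_of(n)]
        bucket[0] += c  # correctly classified test images of the class
        bucket[1] += t  # all test images of the class
    result = {"Overall": 100.0 * sum(correct) / sum(total)}
    for name, (c, t) in tallies.items():
        result[name] = 100.0 * c / t if t else float("nan")
    return result


if __name__ == "__main__":
    # Hypothetical tallies for three classes with 50 test images each.
    print(split_accuracies([500, 50, 5], [45, 30, 10], [50, 50, 50]))
```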

BibTeX

@InProceedings{Rangwani_2024_CVPR,
    author    = {Rangwani, Harsh and Mondal, Pradipto and Mishra, Mayank and Asokan, Ashish Ramayee and Babu, R. Venkatesh},
    title     = {DeiT-LT: Distillation Strikes Back for Vision Transformer Training on Long-Tailed Datasets},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {23396-23406}
}