Mutual Effort for Efficiency: A Similarity-Based Token Pruning for Vision Transformers in Self-Supervised Learning

ICLR 2025
Sheng Li¹*, Qitao Tan²*, Yue Dai¹, Zhenglun Kong³, Tianyu Wang¹, Jun Liu⁴, Ao Li⁵, Ninghao Liu², Yufei Ding⁶, Xulong Tang¹, Geng Yuan²
¹University of Pittsburgh ²University of Georgia ³Harvard University
⁴Northeastern University ⁵University of Arizona ⁶University of California, San Diego

* Equal contribution.

Abstract

Self-supervised learning (SSL) offers a compelling solution to the extensive labeled-data requirements of traditional supervised learning. Given the proven success of Vision Transformers (ViTs) in supervised tasks, there is growing interest in adapting them to SSL frameworks. However, although SSL achieves high accuracy without labeled data, its high computational demands pose substantial challenges, particularly on resource-limited platforms such as edge devices. Recent studies in supervised learning have shown that token pruning can reduce training costs by removing less informative tokens without compromising accuracy. However, SSL's dual-branch encoders make traditional single-branch pruning strategies less effective: these strategies ignore the critical cross-branch similarity information, which reduces accuracy in SSL. To this end, we introduce SimPrune, a novel token pruning strategy designed for ViTs in SSL. SimPrune leverages cross-branch similarity information to prune tokens efficiently while retaining essential semantic information across both branches. Additionally, we incorporate a difficulty-aware pruning strategy to further enhance SimPrune's effectiveness. Experimental results show that our approach reduces training computation while maintaining accuracy. Specifically, it offers 24% savings in training costs compared to the SSL baseline, without sacrificing accuracy.

Background: Discriminative SSL

Discriminative self-supervised learning trains a model by aligning representations from two different augmented views of the same image.

A typical SSL framework uses an online branch and a target branch. The online branch is optimized by back-propagation, while the target branch is either updated with a momentum-based rule (for example, an exponential moving average of the online weights, as in DINO) or not trained at all.

This dual-branch structure is the key difference from supervised learning.

Discriminative self-supervised learning with online and target encoder branches.
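The momentum-based target update can be sketched as an exponential moving average (EMA) of the online weights. This is a minimal sketch, assuming parameters are flat lists of floats and that the target follows the online branch by EMA (one common choice; DINO uses it with a momentum that starts at 0.996 and rises toward 1). The helper name `ema_update` is hypothetical.

```python
def ema_update(target_params, online_params, momentum=0.996):
    """Return new target parameters as an exponential moving average.

    Hypothetical helper: parameters are flat lists of floats here; a real
    framework would update tensors in place without tracking gradients.
    """
    if len(target_params) != len(online_params):
        raise ValueError("both branches must have the same parameter layout")
    # target <- m * target + (1 - m) * online; no gradient flows into the target.
    return [momentum * t + (1.0 - momentum) * o
            for t, o in zip(target_params, online_params)]


# Usage: the target drifts slowly toward the online weights.
target = [1.0, 0.0]
online = [0.0, 1.0]
target = ema_update(target, online, momentum=0.9)
print(target)  # approximately [0.9, 0.1]
```

A momentum close to 1 makes the target a slowly changing average of past online weights, which is what keeps it stable enough to serve as the alignment target.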

Why SSL Needs Cross-Branch Pruning

We first test whether a representative single-branch, attention-based token pruning method is effective for discriminative self-supervised learning.

Method | Training FLOPs | SSL accuracy (%) | SSL drop (points) | Supervised drop (points)
DINO baseline | 100% | 57.16 | - | -
Attention-based pruning, keep rate kr = 0.9 | 87% | 56.65 | -0.51 | -0.10
Attention-based pruning, keep rate kr = 0.8 | 76% | 54.68 | -2.48 | -0.21
Observation: Attention-based pruning transfers poorly to SSL. At 76% training FLOPs, SSL accuracy drops by 2.48 points, far more than the 0.21-point drop in supervised learning.
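A single-branch, attention-based pruning step of the kind tested above can be sketched as follows. This is a simplified illustration, assuming the importance score of each image token is the attention it receives from the [CLS] token (the criterion EViT uses) and that pruned tokens are simply dropped; EViT additionally fuses them into one token. The function name `attention_topk_keep` is hypothetical.

```python
def attention_topk_keep(cls_attention, keep_rate):
    """Single-branch, attention-based token selection (simplified sketch).

    cls_attention[i] is assumed to be the attention that the [CLS] token
    pays to image token i, averaged over heads. The highest-scoring tokens
    are kept and the rest are dropped.
    """
    num_tokens = len(cls_attention)
    num_keep = max(1, int(keep_rate * num_tokens))  # floor, keep at least one
    ranked = sorted(range(num_tokens), key=lambda i: cls_attention[i], reverse=True)
    # Return kept indices in their original spatial order.
    return sorted(ranked[:num_keep])


scores = [0.1, 0.5, 0.2, 0.9, 0.3]
print(attention_topk_keep(scores, keep_rate=0.8))  # [1, 2, 3, 4]
```

The score depends on one branch only; nothing in this selection refers to the other augmented view.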

Reason

Supervised learning uses a single image stream and explicit labels, so token importance can be estimated within one branch. Discriminative SSL instead aligns two augmented views through online and target branches. If pruning is done independently in each branch, corresponding semantic regions can be removed unevenly, weakening the SSL alignment objective.
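The uneven-removal problem can be made concrete with a toy count. In this sketch each branch keeps its own top-scoring tokens, and a hypothetical `pairs` list marks which online and target tokens cover the same image region (in practice such a correspondence is not directly available). A pair counts as broken when only one of its two tokens survives. All names here are illustrative.

```python
def topk_indices(scores, num_keep):
    """Indices of the num_keep highest scores, as a set."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return set(order[:num_keep])


def broken_correspondences(online_scores, target_scores, pairs, keep_rate):
    """Count corresponding token pairs that independent pruning splits up.

    pairs holds hypothetical (online_index, target_index) tuples for tokens
    that cover the same region in the two views. Each branch prunes on its
    own scores; a pair is broken when exactly one of its tokens is kept.
    """
    kept_online = topk_indices(online_scores, max(1, int(keep_rate * len(online_scores))))
    kept_target = topk_indices(target_scores, max(1, int(keep_rate * len(target_scores))))
    return sum((i in kept_online) != (j in kept_target) for i, j in pairs)


# Each branch ranks its own tokens, so the same region can survive in one view only.
online = [0.9, 0.8, 0.1, 0.2]
target = [0.1, 0.9, 0.8, 0.2]
same_region = [(0, 0), (1, 1), (2, 2), (3, 3)]
print(broken_correspondences(online, target, same_region, keep_rate=0.5))  # 2
```

Here the online branch keeps regions 0 and 1 while the target branch keeps regions 1 and 2, so two of the four regions appear in one view but not the other.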

Design

Transformer Block Integration

SimPrune is placed inside the Transformer block after the attention module and before the MLP.

SimPrune is inserted into selected Transformer blocks, such as the 4th, 7th, and 10th blocks, and works directly with standard SSL dual-branch training.

The online and target SSL branches keep the same backbone structure, while SimPrune reduces the token sequence that later layers need to process.

SimPrune inserted into Transformer blocks in the online and target SSL branches.
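A back-of-envelope estimate shows why pruning at blocks 4, 7, and 10 lands near the training FLOPs reported in the tables. The sketch assumes a 12-block encoder, the same keep rate at each pruning block, and a per-token block cost split roughly 1/3 to the attention sub-layer's projections and 2/3 to the MLP (MLP ratio 4); it ignores the quadratic attention term, patch embedding, and heads. The function `relative_training_cost` and its defaults are our own illustrative choices.

```python
def relative_training_cost(keep_rate, depth=12, prune_blocks=(4, 7, 10), attn_share=1 / 3):
    """Approximate cost of a pruned ViT encoder relative to the unpruned one.

    Illustrative assumptions: every block costs the same per token; the
    attention sub-layer takes attn_share of that cost and the MLP the rest;
    only linear-in-token costs are counted.
    """
    tokens = 1.0  # fraction of the original token count still present
    total = 0.0
    for block in range(1, depth + 1):
        total += attn_share * tokens  # attention sees the incoming tokens
        if block in prune_blocks:
            tokens *= keep_rate  # pruning sits between attention and MLP
        total += (1.0 - attn_share) * tokens  # MLP sees the pruned sequence
    return total / depth


for kr in (0.9, 0.8):
    print(kr, round(relative_training_cost(kr), 3))
# 0.9 0.867
# 0.8 0.752
```

The estimates, about 87% at keep rate 0.9 and 75% at 0.8, sit close to the 87% and 76% training FLOPs reported above. Under these assumptions the ratio is the same for both branches and for the backward pass, since every counted term scales linearly with the token count.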

Cross-Branch Similarity Pruning

  1. Match tokens across branches

     For each online-branch image token, SimPrune finds the most similar target-branch token using cosine similarity.

  2. Sort pairs by similarity

     The matched token pairs are ordered from high to low similarity, forming the basis for pruning decisions.

  3. Prune token pairs together

     Tokens carrying related semantic information are retained or removed together, avoiding asymmetric pruning between views.

Token matching, sorting, and pair-wise pruning process of SimPrune.
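The three steps can be sketched in plain Python. This is a minimal illustration, assuming tokens are given as lists of floats, only image tokens (no [CLS]) are passed in, and the highest-similarity pairs are kept, as at the start of training. The helper names are hypothetical. Because several online tokens may match the same target token, the sketch de-duplicates target indices; that is our assumption, not a documented detail of SimPrune.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def match_and_rank(online_tokens, target_tokens):
    """Steps 1 and 2: match each online token to its most similar target token,
    then sort the (online_idx, target_idx, similarity) pairs from high to low."""
    pairs = []
    for i, u in enumerate(online_tokens):
        sims = [cosine(u, v) for v in target_tokens]
        j = max(range(len(sims)), key=sims.__getitem__)
        pairs.append((i, j, sims[j]))
    return sorted(pairs, key=lambda p: p[2], reverse=True)


def prune_pairs(online_tokens, target_tokens, keep_rate):
    """Step 3: keep or drop each matched pair as a unit (most similar pairs kept)."""
    ranked = match_and_rank(online_tokens, target_tokens)
    num_keep = max(1, int(keep_rate * len(ranked)))
    kept = ranked[:num_keep]
    keep_online = sorted(i for i, _, _ in kept)
    keep_target = sorted({j for _, j, _ in kept})  # assumption: de-duplicate
    return keep_online, keep_target


online = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
target = [[0.0, 2.0], [3.0, 0.0], [-1.0, -1.0]]
print(prune_pairs(online, target, keep_rate=0.7))  # ([0, 1], [0, 1])
```

Online token 2 matches target token 1 with similarity of about 0.71, the lowest of the three pairs, so it is the pair removed; both branches then keep the tokens of the two well-aligned pairs.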

Difficulty-Aware Sliding Window

Early SSL training benefits from an easier task, so SimPrune initially retains the token pairs with higher similarity and prunes the more dissimilar ones. As training progresses, the retained window shifts toward lower-similarity pairs, gradually increasing task difficulty and encouraging stronger representations.

Difficulty-aware sliding window that shifts over token-pair similarity during training.
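The sliding window can be expressed as a start offset into the similarity-sorted pair list. The source does not specify the schedule, so this sketch assumes a linear shift with training progress; `window_bounds` and its `max_shift` parameter are hypothetical.

```python
def window_bounds(num_pairs, num_keep, progress, max_shift=1.0):
    """Start/end positions of the kept window over pairs sorted by similarity.

    progress runs from 0.0 (start of training) to 1.0 (end). The linear
    schedule and max_shift are illustrative assumptions: the source only
    states that the window moves toward lower-similarity pairs over time.
    """
    if not 0 < num_keep <= num_pairs:
        raise ValueError("num_keep must be between 1 and num_pairs")
    if not 0.0 <= max_shift <= 1.0:
        raise ValueError("max_shift must be in [0, 1]")
    progress = min(max(progress, 0.0), 1.0)  # clamp to the training range
    max_start = num_pairs - num_keep  # furthest the window can slide
    start = int(round(progress * max_shift * max_start))
    return start, start + num_keep


# 10 matched pairs, keep 6: the window slides from the most similar pairs
# toward less similar ones as training advances.
for p in (0.0, 0.5, 1.0):
    print(p, window_bounds(10, 6, p))
# 0.0 (0, 6)
# 0.5 (2, 8)
# 1.0 (4, 10)
```

With `max_shift=1.0` the final window keeps the least similar pairs; a smaller value stops the shift partway.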

Main Results

ImageNet-1k experiments use DINO with DeiT encoders for 300 epochs. Results below highlight the keep-rate 0.8 setting from the main comparison.

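Encoder | Method | Accuracy (%) | Training FLOPs | Training time
DeiT-T | DINO | 55.71 | 100% | 100%
DeiT-T | EViT | 53.13 | 76% | 88%
DeiT-T | ToMe | 53.65 | 76% | 86%
DeiT-T | SimPrune (Ours) | 55.66 | 76% | 86%
DeiT-S | DINO | 62.49 | 100% | 100%
DeiT-S | EViT | 60.06 | 76% | 87%
DeiT-S | ToMe | 60.34 | 76% | 87%
DeiT-S | SimPrune (Ours) | 62.31 | 76% | 85%
DeiT-B | DINO | 64.56 | 100% | 100%
DeiT-B | EViT | 62.27 | 76% | 89%
DeiT-B | ToMe | 61.58 | 76% | 88%
DeiT-B | SimPrune (Ours) | 64.21 | 76% | 89%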

At keep rate 0.8, SimPrune reduces training FLOPs by 24% and training time by 11% to 15% (about 13% on average), while staying within 0.35 points of the unpruned DINO baseline's accuracy.
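The per-encoder figures behind these averages can be recomputed from the table. The values below are copied from the table; only the variable names are ours.

```python
# (encoder, DINO accuracy, SimPrune accuracy, SimPrune training time in % of DINO)
rows = [
    ("DeiT-T", 55.71, 55.66, 86),
    ("DeiT-S", 62.49, 62.31, 85),
    ("DeiT-B", 64.56, 64.21, 89),
]

time_savings = [100 - t for _, _, _, t in rows]
mean_time_saving = sum(time_savings) / len(time_savings)
# Accuracy gap in points, rounded to the table's two decimals.
accuracy_gaps = {enc: round(base - ours, 2) for enc, base, ours, _ in rows}

print(time_savings)                # [14, 15, 11]
print(round(mean_time_saving, 2))  # 13.33
print(accuracy_gaps)               # {'DeiT-T': 0.05, 'DeiT-S': 0.18, 'DeiT-B': 0.35}
```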

Beyond Image Classification

SimPrune also transfers to fine-grained datasets and dense downstream tasks after SSL pretraining.

Method | Stanford Cars accuracy (%) | FGVC Aircraft accuracy (%) | FLOPs
DINO | 52.87 | 55.70 | 100%
EViT | 50.72 | 53.81 | 76%
ToMe | 49.56 | 53.15 | 76%
SimPrune (Ours) | 52.61 | 55.48 | 76%

Fine-grained datasets with DeiT-S and keep rate 0.8.

Method | MS COCO APb (box AP) | ADE20K mIoU | FLOPs
DINO | 49.8 | 34.3 | 100%
EViT | 48.0 | 31.2 | 76%
ToMe | 46.8 | 31.7 | 76%
SimPrune (Ours) | 49.7 | 34.0 | 76%

Downstream object detection and semantic segmentation with DeiT-B.

Pruning Visualization

Visualization comparing EViT attention-based pruning and SimPrune token pruning outcomes.
Black regions are pruned tokens. SimPrune prunes matched token pairs so semantically corresponding regions across two augmented views are retained or removed together, reducing cross-view misalignment during SSL training.