Dataset Distillation with Feature Matching through the Wasserstein Metric

Haoyang Liu1, Yijiang Li2, Tiancheng Xing3, Vibhu Dalal1, Luwei Li1, Jingrui He1, Haohan Wang1
1University of Illinois Urbana-Champaign, 2UC San Diego, 3Nanjing University
Distilled Images on ImageNet-1k
Distillation Process

We introduce WMDD (Dataset Distillation with Wasserstein Metric-based Feature Matching), which employs the Wasserstein metric to enhance distribution matching.

Abstract

Dataset Distillation (DD) aims to generate a compact synthetic dataset that enables models to achieve performance comparable to training on the full large dataset, significantly reducing computational costs. Drawing from optimal transport theory, we introduce WMDD (Dataset Distillation with Wasserstein Metric-based Feature Matching), a straightforward yet powerful method that employs the Wasserstein metric to enhance distribution matching. We compute the Wasserstein barycenter of features from a pretrained classifier to capture essential characteristics of the original data distribution. By optimizing synthetic data to align with this barycenter in feature space and leveraging per-class BatchNorm statistics to preserve intra-class variations, WMDD maintains the efficiency of distribution matching approaches while achieving state-of-the-art results across various high-resolution datasets. Our extensive experiments demonstrate WMDD's effectiveness and adaptability, highlighting its potential for advancing machine learning applications at scale.

Wasserstein Distance

Wasserstein Distance: For probability distributions $\mu$ and $\nu$ on a metric space $\Omega$ with distance metric $D$, the $p$-Wasserstein distance is defined as:

$$W_p(\mu, \nu) \;=\; \left( \inf_{\pi \in \Pi(\mu, \nu)} \int_{\Omega \times \Omega} D(x, y)^p \, \mathrm{d}\pi(x, y) \right)^{1/p}$$

where $P(\Omega)$ represents the space of all distributions on $\Omega$. The collection $\Pi(\mu,\nu)$ contains all joint distributions $\pi$ on $\Omega \times \Omega$ with marginals $\mu$ and $\nu$. Intuitively, this distance measures the minimum "work" needed to transform $\mu$ into $\nu$, with the effort of moving each mass unit captured by the $p$-th power of distance $D$.
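As a concrete illustration of the discrete case, the snippet below uses the POT library to evaluate the 2-Wasserstein distance between two toy empirical distributions; the point clouds and weights are made up purely for the example.

import numpy as np
import ot  # POT: Python Optimal Transport

# Two toy empirical distributions mu and nu on R^2 (illustrative data only)
x_mu = np.random.randn(5, 2)
x_nu = np.random.randn(7, 2) + 1.0
a = np.full(5, 1 / 5)           # uniform weights of mu
b = np.full(7, 1 / 7)           # uniform weights of nu

M = ot.dist(x_mu, x_nu)         # pairwise costs D(x, y)^2 (squared Euclidean by default)
w2 = np.sqrt(ot.emd2(a, b, M))  # exact OT cost, then the 1/p-th power with p = 2
print(w2)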

Wasserstein Barycenter: Given $N$ distributions $\{\nu_1,\dots,\nu_N\} \subset \mathbb{P} \subset P(\Omega)$, a Wasserstein barycenter solves:

$$\bar{\mu} \;=\; \arg\min_{\mu \in \mathbb{P}} \; \sum_{i=1}^{N} W_p^p(\mu, \nu_i)$$

where $\mathbb{P}$ represents a subset of distributions in $P(\Omega)$, and $W_p^p(\mu, \nu_i)$ denotes the $p$-Wasserstein distance (raised to power $p$) between $\mu$ and each distribution $\nu_i$. This formulation allows us to find a central, representative distribution that best summarizes a collection of distributions according to their geometric properties.
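To make the barycenter concrete, here is a minimal sketch using POT's free-support barycenter solver; the three Gaussian point clouds and the 10-point support are illustrative placeholders, not data from our experiments.

import numpy as np
import ot

# N = 3 toy empirical distributions, each a 30-point cloud in R^2
measures_locations = [np.random.randn(30, 2) + 3 * i for i in range(3)]
measures_weights = [np.full(30, 1 / 30) for _ in range(3)]

# Barycenter restricted to 10 free support points, initialised at random
X_init = np.random.randn(10, 2)
X_bary = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
print(X_bary.shape)  # (10, 2): a compact distribution summarising the three inputs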

We demonstrate the ability of the Wasserstein barycenter to condense the core characteristics of distributions, and compare its effectiveness with KL divergence and Maximum Mean Discrepancy (MMD).

WMDD

We formulate dataset distillation as the problem of finding a synthetic empirical distribution $\mu_{\mathcal{S}}$—with learnable sample positions and weights—that minimizes the Wasserstein distance to the real data distribution $\mu_{\mathcal{T}}$:

$$\min_{\tilde{\mathbf{X}},\, \mathbf{w}} \; W_p\!\left(\mu_{\mathcal{S}}, \mu_{\mathcal{T}}\right), \qquad \mu_{\mathcal{S}} \;=\; \sum_{j=1}^{m} w_j \, \delta_{\tilde{\mathbf{x}}_j}$$

The optimization alternates between two steps: (1) updating the weights $\mathbf{w}$ by solving the dual of the optimal transport problem, and (2) updating the positions $\tilde{\mathbf{X}}$ using Newton’s method:

$$\tilde{\mathbf{x}}_j \;\leftarrow\; \tilde{\mathbf{x}}_j \;-\; \left(\nabla^2_{\tilde{\mathbf{x}}_j} W_p^p(\mu_{\mathcal{S}}, \mu_{\mathcal{T}})\right)^{-1} \nabla_{\tilde{\mathbf{x}}_j} W_p^p(\mu_{\mathcal{S}}, \mu_{\mathcal{T}})$$

This iterative approach efficiently computes a Wasserstein barycenter, allowing the distilled (synthetic) dataset to closely match the structure and diversity of the original data.
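Below is an illustrative sketch of one such alternation, assuming a squared-Euclidean ground cost (for which a single Newton step on the positions reduces to the barycentric projection of the transport plan) and a simple mirror-descent step on the weights driven by the dual potential; the function names and the step size are placeholders, not the released implementation.

import numpy as np
import ot

def wmdd_style_alternation(X_real, X_syn, w_syn, eta=0.1):
    """One illustrative alternation (not the paper's exact implementation).

    X_real : (n, d) real features, X_syn : (m, d) synthetic positions,
    w_syn  : (m,) synthetic weights on the simplex.
    Assumes a squared-Euclidean ground cost.
    """
    n = X_real.shape[0]
    a = np.full(n, 1.0 / n)

    # Cost matrix and exact OT between the two empirical measures
    M = ot.dist(X_real, X_syn)               # squared Euclidean costs
    pi, log = ot.emd(a, w_syn, M, log=True)  # transport plan + dual potentials

    # (1) weight update: mirror-descent step along the dual potential of w
    w_new = w_syn * np.exp(-eta * log['v'])
    w_new /= w_new.sum()

    # (2) position update: with a quadratic cost, a single Newton step reduces
    #     to the barycentric projection of the transport plan
    X_new = (pi.T @ X_real) / np.maximum(w_syn[:, None], 1e-12)
    return X_new, w_new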

Method Overview

Overall, the real dataset $\mathcal{T}$ and the synthetic dataset $\mathcal{S}$ are passed through the feature network $f$ to obtain features. The features of the real data are used to compute the Wasserstein barycenter, and the synthetic data are optimized so that their features align with this barycenter, using a loss that combines the feature-matching term with a per-class BatchNorm regularization term. This yields high-quality synthetic data for efficient model training.
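A minimal sketch of one synthetic-image update under these assumptions is given below; the two-output forward pass, the per-class statistics containers, and the $\lambda$ weight are illustrative placeholders rather than the actual code.

import torch
import torch.nn.functional as F

def distill_step(feature_net, syn_images, bary_targets, real_bn_stats, lam, optimizer):
    """One illustrative update of the synthetic images (names are placeholders).

    feature_net   : pretrained classifier backbone; assumed to return both features
                    and the batch statistics of its BatchNorm layers
    syn_images    : learnable tensor of synthetic images for one class, (m, C, H, W)
    bary_targets  : (m, d) barycenter features assigned to the synthetic images
    real_bn_stats : list of (mean, var) pairs recorded per class on real data
    lam           : weight of the BatchNorm regularizer
    """
    optimizer.zero_grad()
    feats, syn_bn_stats = feature_net(syn_images)  # hypothetical two-output forward

    # Feature matching: pull each synthetic feature towards its barycenter target
    loss_feat = F.mse_loss(feats, bary_targets)

    # BatchNorm regularizer: keep synthetic activation statistics close to the
    # per-class statistics observed on real data
    loss_bn = sum(F.mse_loss(ms, mr) + F.mse_loss(vs, vr)
                  for (ms, vs), (mr, vr) in zip(syn_bn_stats, real_bn_stats))

    loss = loss_feat + lam * loss_bn
    loss.backward()
    optimizer.step()
    return loss.detach()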

Comparison with Prior Methods

WMDD achieves state-of-the-art performance in most settings across different datasets.

Performance Table
Performance comparison of various dataset distillation methods on different datasets.

Cross-architecture Generalization

Our synthetic data generalize well to other models in the ResNet family, and performance improves with model capacity in settings where the distilled data size is relatively large.

Performance Table
Cross-architecture generalization performance on Tiny-ImageNet and ImageNet-1K in different IPC settings.

Ablation Study

Ablation Study 1
Effect of Wasserstein Barycenter, compared with cross-entropy and MMD.
Ablation Study 2
Effect of different $\lambda$.

Citation

If you find our work useful in your research, please consider citing:

@article{liu2023dataset,
    title={Dataset Distillation via the Wasserstein Metric},
    author={Liu, Haoyang and Li, Yijiang and Xing, Tiancheng and Dalal, Vibhu and Li, Luwei and He, Jingrui and Wang, Haohan},
    journal={arXiv preprint arXiv:2311.18531},
    year={2023}
}