Given an input image \(\mathbf x \) and text conditioning \(\mathbf c\), we use a diffusion model to choose the class that best fits this image. Our approach, Diffusion Classifier, is theoretically motivated through the variational view of diffusion models and uses the ELBO to approximate \(\log p_{\theta}(\mathbf x|\mathbf c).\) Diffusion Classifier chooses the conditioning \(\mathbf c\) that best predicts the noise added to the input image. Diffusion Classifier can be used to extract a zero-shot classifier from a text-to-image model (like Stable Diffusion) and a standard classifier from a class-conditional model (like DiT) without any additional training.
The recent wave of large-scale text-to-image diffusion models has dramatically advanced our ability to generate images from text. These models can generate realistic images for a staggering variety of prompts and exhibit impressive compositional generalization abilities. Almost all use cases thus far have focused solely on sampling; however, diffusion models can also provide conditional density estimates, which are useful for tasks beyond image generation.
In this paper, we show that the density estimates from large-scale text-to-image diffusion models like Stable Diffusion can be leveraged to perform zero-shot classification without any additional training. Our generative approach to classification, which we call Diffusion Classifier, attains strong results on a variety of benchmarks and outperforms alternative methods of extracting knowledge from diffusion models. Although a gap remains between generative and discriminative approaches on zero-shot recognition tasks, we find that our diffusion-based approach has stronger multimodal relational reasoning abilities than competing discriminative approaches.
Finally, we use Diffusion Classifier to extract standard classifiers from class-conditional diffusion models trained on ImageNet. Even though these diffusion models are trained with weak augmentations and no regularization, we find that they approach the performance of SOTA discriminative ImageNet classifiers. Overall, our strong generalization and robustness results represent an encouraging step toward using generative over discriminative models for downstream tasks.
In general, classification using a conditional generative model can be done by using Bayes' theorem on the model predictions and the prior \(p(\mathbf{c})\) over labels \(\{\mathbf{c}_i\}\):
\begin{equation} p_\theta(\mathbf{c}_i \mid \mathbf{x}) = \frac{p(\mathbf{c}_i)\ p_\theta(\mathbf{x} \mid \mathbf{c}_i)}{\sum_j p(\mathbf{c}_j)\ p_\theta(\mathbf{x} \mid \mathbf{c}_j)} \label{eq:bayes} \end{equation}
A uniform prior over \(\{\mathbf{c}_i\}\) (i.e., \(p(\mathbf{c}_i) = \frac{1}{N}\)) is natural and causes all of the \(p(\mathbf{c}_i)\) terms to cancel. For diffusion models, computing \(\log p_\theta(\mathbf{x}\mid \mathbf{c})\) is intractable, so we approximate it with the ELBO (see paper §3.1), from which we have dropped constant and weighting terms:
\begin{align} \text{ELBO} \approx - \mathbb{E}_{t, \epsilon}[\|\epsilon - \epsilon_\theta(\mathbf{x}_t, \mathbf{c}_i)\|^2] \label{eq:elbo} \end{align}
We plug the modified ELBO Eq. \ref{eq:elbo} into Eq. \ref{eq:bayes} to obtain the posterior over \(\{\mathbf{c}_i\}_{i=1}^N\):
\begin{align} p_\theta(\mathbf{c}_i \mid \mathbf{x}) &\approx \frac{\exp\{- \mathbb{E}_{t, \epsilon}[\|\epsilon - \epsilon_\theta(\mathbf{x}_t, \mathbf{c}_i)\|^2]\}}{\sum_j \exp\{- \mathbb{E}_{t, \epsilon}[\|\epsilon - \epsilon_\theta(\mathbf{x}_t, \mathbf{c}_j)\|^2]\}} \label{eq:posterior} \end{align}
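Concretely, this posterior is just a softmax over negative expected denoising errors. A minimal sketch in Python (the per-class errors here are made-up numbers for illustration, not real model outputs):

```python
import math

def posterior_from_errors(errors):
    """Softmax over negative denoising errors.

    errors[i] is a Monte Carlo estimate of
    E_{t,eps} ||eps - eps_theta(x_t, c_i)||^2 for class c_i.
    Shifting by the minimum error keeps exp() numerically stable
    and leaves the softmax output unchanged.
    """
    m = min(errors)
    weights = [math.exp(-(e - m)) for e in errors]
    total = sum(weights)
    return [w / total for w in weights]

# Hypothetical per-class errors: the class with the SMALLEST
# denoising error receives the highest posterior probability.
probs = posterior_from_errors([410.2, 395.7, 402.1])
```

Note that only the *differences* between per-class errors matter, which is why constant terms can be dropped from the ELBO without changing the classifier's decision.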
We compute an unbiased Monte Carlo estimate of each expectation by sampling \(S\) pairs \((t_i, \epsilon_i)\), with \(t_i\) drawn uniformly from \(\{1, \dots, 1000\}\) and \(\epsilon_i \sim \mathcal{N}(0, I)\), and computing
\begin{align} \frac{1}{S}\sum_{i=1}^S \left\|\epsilon_i - \epsilon_\theta(\sqrt{\bar \alpha_{t_i}}\mathbf{x} + \sqrt{1-\bar\alpha_{t_i}} \epsilon_i, \mathbf{c}_j)\right\|^2 \label{eq:monte_carlo} \end{align}
By plugging Eq. \ref{eq:monte_carlo} into Eq. \ref{eq:posterior}, we can extract a classifier from any conditional diffusion model. This method, which we call Diffusion Classifier, is a powerful, hyperparameter-free approach that leverages pretrained diffusion models for classification without any additional training. Diffusion Classifier can be used to extract a zero-shot classifier from a text-to-image model like Stable Diffusion, to extract a standard classifier from a class-conditional diffusion model like DiT, and so on.
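The whole procedure can be sketched end to end in NumPy. This is a toy illustration, not the real system: `eps_theta` stands in for the pretrained diffusion network (here a class-specific bias so that one class fits the image best), and the \(\bar\alpha_t\) schedule is an illustrative linear ramp rather than Stable Diffusion's actual schedule:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the conditional noise predictor eps_theta(x_t, t, c).
# In practice this is the diffusion U-Net; here each class shifts the
# prediction by a class-specific bias, so class 0 fits the noise exactly.
CLASS_BIAS = {0: 0.0, 1: 0.5, 2: -0.5}
def eps_theta(x_t, t, c, eps_true):
    return eps_true + CLASS_BIAS[c]

# Illustrative alpha-bar schedule over T = 1000 timesteps.
T = 1000
alpha_bar = np.linspace(0.9999, 0.01, T)

def classify(x, classes, n_samples=64):
    """Diffusion Classifier: pick the class whose conditioning best
    predicts the noise added to x."""
    # Reusing the same (t_i, eps_i) pairs for every class reduces the
    # variance of the differences between per-class error estimates.
    ts = rng.integers(0, T, size=n_samples)
    epss = rng.standard_normal((n_samples,) + x.shape)
    errors = []
    for c in classes:
        err = 0.0
        for t, eps in zip(ts, epss):
            # Forward-noise the image: x_t = sqrt(ab)*x + sqrt(1-ab)*eps
            x_t = np.sqrt(alpha_bar[t]) * x + np.sqrt(1 - alpha_bar[t]) * eps
            err += np.sum((eps - eps_theta(x_t, t, c, eps)) ** 2)
        errors.append(err / n_samples)
    # argmin over errors = argmax of the softmax posterior
    return int(np.argmin(errors))

x = rng.standard_normal((4, 4))
pred = classify(x, classes=[0, 1, 2])  # class 0 has zero denoising error
```

In practice, the number of classes and the cost of each network evaluation dominate the runtime, so sample-efficient evaluation strategies matter for large label spaces.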
We build Diffusion Classifier on top of Stable Diffusion, a text-to-image latent diffusion model trained on a filtered subset of LAION-5B. Our zero-shot classification method is competitive with CLIP and significantly outperforms the zero-shot diffusion model baseline that trains a classifier on synthetic SD data. It also generally outperforms the baseline trained on Stable Diffusion features, especially on complex datasets like ImageNet. This is especially impressive since the "SD Features" baseline uses the entire training set to train a classifier.
| Method | Zero-shot? | Food | CIFAR10 | FGVC | Pets | Flowers | STL10 | ImageNet | ObjectNet |
|---|---|---|---|---|---|---|---|---|---|
| Synthetic SD Data | ✓ | 12.6 | 35.3 | 9.4 | 31.3 | 22.1 | 38.0 | 18.9 | 5.2 |
| SD Features | ✗ | 73.0 | 84.0 | 35.2 | 75.9 | 70.0 | 87.2 | 56.6 | 10.2 |
| Diffusion Classifier | ✓ | 77.9 | 76.3 | 24.3 | 85.7 | 56.8 | 94.2 | 58.4 | 38.3 |
| CLIP ResNet50 | ✓ | 81.1 | 75.6 | 19.3 | 85.4 | 65.9 | 94.3 | 58.2 | 40.0 |
| OpenCLIP ViT-H/14 | ✓ | 92.7 | 97.3 | 42.3 | 94.6 | 79.9 | 98.3 | 76.8 | 69.2 |
We compare our zero-shot Diffusion Classifier method to CLIP and OpenCLIP on Winoground, a popular benchmark for evaluating the visuo-linguistic compositional reasoning abilities of vision-language models. This benchmark tests whether models can match captions to the correct images when certain entities are swapped in the captions.
Diffusion Classifier clearly outperforms both CLIP models when object swaps are involved ("Object" and "Both" below). Since Stable Diffusion uses the same text encoder as OpenCLIP ViT-H/14, this improvement must come from better cross-modal binding of concepts to images. Overall, we find it surprising that Stable Diffusion, trained with only sampling in mind, can be repurposed into such a good classifier and reasoner.
| Model | Object | Relation | Both | Average |
|---|---|---|---|---|
| Random Chance | 25.0 | 25.0 | 25.0 | 25.0 |
| CLIP ViT-L/14 | 27.0 | 25.8 | 57.7 | 28.2 |
| OpenCLIP ViT-H/14 | 39.0 | 26.6 | 57.7 | 33.0 |
| Diffusion Classifier | 41.8 | 25.3 | 69.2 | 34.0 |
We use Diffusion Classifier to obtain a standard 1000-way classifier on ImageNet from a pretrained Diffusion Transformer (DiT) model. DiT is a class-conditional diffusion model trained solely on ImageNet-1k, with only random horizontal flips and no regularization. We compare Diffusion Classifier in this setting to strong discriminative classifiers like ResNet-50 and ViT-B/16 in the table below.
| Method | IN (ID) | IN-v2 (OOD) | IN-A (OOD) | ObjectNet (OOD) |
|---|---|---|---|---|
| ResNet-18 | 74.1 | 57.3 | 15.0 | 26.6 |
| ResNet-34 | 78.1 | 59.8 | 10.5 | 31.6 |
| ResNet-50 | 79.7 | 61.6 | 9.8 | 35.6 |
| ResNet-101 | 82.2 | 63.2 | 19.5 | 38.2 |
| ViT-L/32 | 79.0 | 61.6 | 26.3 | 29.9 |
| ViT-L/16 | 81.0 | 66.6 | 25.6 | 36.7 |
| ViT-B/16 | 83.4 | 66.6 | 30.1 | 37.8 |
| Diffusion Classifier | 78.9 | 62.1 | 22.6 | 32.3 |
Despite the fact that the discriminative models are trained with highly tuned learning rate schedules, augmentation strategies, and regularizers, Diffusion Classifier outperforms many of them on ImageNet, both in-distribution and out-of-distribution. To the best of our knowledge, ours is among the first generative modeling approaches to be competitive with SOTA discriminative classifiers on ImageNet.
@misc{li2023diffusion,
title={Your Diffusion Model is Secretly a Zero-Shot Classifier},
author={Alexander C. Li and Mihir Prabhudesai and Shivam Duggal and Ellis Brown and Deepak Pathak},
year={2023},
eprint={2303.16203},
archivePrefix={arXiv},
primaryClass={cs.LG}
}