前言
前两天VAR拿到了NeurIPS 2024最佳论文,之前刚放到arXiv上的时候只是简单刷了一下,加了个github star,不过这两天细细读过之后才品出味道来。这个短系列将详细梳理VAR的细节与其关键技术,并展望一些拓展的可能性。本文首先介绍VQ相关。
VQ
首先我们回忆一下AE和VAE。自编码器(Auto-Encoder,AE)是 无监督表征学习 的经典算法,其包含一个编码器(Encoder):$f:=\mathbb{R}^n\to\mathbb{R}^d$;与一个解码器(Decoder):$g:=\mathbb{R}^d\to\mathbb{R}^n$。AE的底层逻辑仍然是神经网络的通用近似定理,因此其天然适合压缩、重建等学习恒等映射的任务。当然它不能做生成,不过如果让encoder得到的latents符合某个分布,那我们可以直接在这个分布上进行采样从而过decoder完成随机图像生成。沿着这个思路便是大名鼎鼎的VAE了。详细可见:
然而实际生成过程中VAE得到的图片普遍比较糊,质量较低。Google的学者认为是VAE的encoder过强,对于任意图像都能encode到正态分布上,而在生成时使用连续变量生成时会有ambiguous的问题,作者称为posterior collapse问题。进一步地,如果使用离散化的latents则更具有灵活性,也符合一般信号的生成过程。因此作者借鉴NLP的思路,把图像编码成离散向量并通过一个embedding作为latents,这个嵌入层在VQ-VAE里叫做embedding space,在后续文章中则被称作codebook。具体如下:
当然从图上也能看出来,VQ-VAE最大的改进便是将图像编码为离散向量,因此本质上还是一个AE,同样面临着难以采样的问题。当然离散空间采样可以用Gumbel-softmax,作者用的生成方法是PixelCNN。不过由于生成方法并不是重点,而且PixelCNN目前也没有什么应用前景了,这里就不赘述了。
接下来我们详细看一下VQ-VAE的细节。我们的目标是让encoder输出一个整数,最直接的方法是对神经网络的logits做一个softmax转为分布,再用argmax得到类别编号。如果我们使用K-way的codebook,即用一个$\boldsymbol z_q\in \mathbb Z^{K},\mathbb Z=\{\cdots,-1,0,1,\cdots\}$去表示图像,那我们就需要重复K次softmax,显然这不是最高效的。因此对encoder的logits、离散变量codebook和decoder的输入embedding需要一种高效的关联方法。
具体而言,一个$\mathbb R^{3\times N\times N}$的图像首先encode得到一个连续向量$z\in\mathbb R^{d}$,而embedding层为:$E=[e_1,\cdots,e_K],e_i\in \mathbb R^d$,这里为了避免显式输入离散编码,一个直接的想法是直接关联$z$和$e_i$,就是直接用最近邻搜索:
$$ \begin{equation}z\to e_k,\quad k = \mathop{\text{argmin}}_j \Vert z - e_j\Vert_2\end{equation} $$
因此我们便可以直接将$e_k$作为$z$的embedding$z_q$,相当于我们把$k$作为了$z$的离散编码。当然实际操作时$z$一般不会之间encode到一维,而是$z\in\mathbb R^{h\times w\times d}$,相应地我们也可以量化得到一个$h\times w$的整数矩阵作为其离散编码。
接下来看它的训练。我们有重建损失:
$$ \mathcal L_{recon}=\Vert x-Decoder(z_q) \Vert_2^2 $$
当然显然注意到$z_q$当中使用了$argmin$,因此loss的梯度是传不到encoder的。文中使用了一种称为straight-through estimator的技术,思路就是前向传播和反向传播的计算可以不对应,具体而言我们定义算子sg (stop gradient):
$$ sg[x]=\begin{cases} x, 前向传播 \\ 0, 反向传播 \end{cases} $$
这个算子在pytorch里就是.detach()
,也即前向传播时$sg$里的值不变;反向传播时,$sg$按值为0求导,即此次计算无梯度。
基于这种运算,我们可以把重建loss写为:
$$ \mathcal L_{recon}=\Vert x-D(z+sg[z_q-z]) \Vert_2^2= \begin{cases} \Vert x-D(z_q) \Vert_2^2, 前向求loss \\ \Vert x-D(z) \Vert_2^2 , 反向传梯度 \end{cases} $$
此外我们的本意是让$z_q$和$z$相接近,而$L_{recon}$则并不能保证这一点,它只管了encoder和decoder而中间的embedding没有得到训练。因此我们可以额外加入$\Vert z_q-z\Vert_2^2$。作者进一步认为,encoder和embedding的学习速度应该不一样快,于是我们可以将其拆成两部分从而完成训练:
$$ \alpha\Vert sg[z]-z_q \Vert_2^2 +\beta \Vert z-sg[z_q] \Vert_3^2 $$
以上介绍了VQ-VAE的基本框架,它实际上是一种离散压缩方法(不知道为什么加上了V,反而有误导性了),整体可以总结为编码+量化+生成+解码的过程,文中编解码都用的是CNN,而量化方法就直接用的最近邻搜索。如果把它放到生成模型的任务下,就会有两个改进方向:
- 更好的生成模型,后续代表作为:VQ-GAN等
- 更好的量化方法,后续代表作为:RQ等
由于生成并不是我们这个系列的重点(毕竟重点是VAR),因此一些很出色的工作比如VQ-GAN、DALLE-1等这里就不介绍了。我们沿着量化方法这条路介绍相关工作,侧重的也是它们的量化压缩的方法。
RQ
在VQ-VAE中,我们使用最近邻的方法将$z\in\mathbb R^{h\times w\times d}$离散编码为了$[K]^{h\times w}$,从而得到了量化结果$z_q\in\mathbb R^{h\times w\times d}$,简化为一维就是$\mathcal Q(z;E):=\mathop{\text{argmin}}_j \Vert z-e_j\Vert_2$。显然这种方法过于直接,一个自然的想法是用$L$个顺序离散编码来表示:$\mathcal {RQ}(z;E,L)=(k_1,\cdots,k_L)\in [K]^L$。第一步还是直接量化$z$;而之后的每一步都量化logits和上一部量化结果的差值:
$$ k_i=\mathcal Q(r_{i-1};E),\quad r_i=r_{i-1} -e_{k_i},\quad r_0=z $$
最后只需要求$L$步量化结果的部分和就得到最终量化结果了:$z_q=\sum_{i=1}^L e_{k_i}$。当然训练除了重建loss$L_{recon}$之外,类似VQ-VAE也有commitment loss:
$$ L_{commit}=\sum_{i=1}^L \Vert z-sg[z_q^{(i)}] \Vert_2^2 $$
其中$z_q^{(i)}=\sum_{l=1}^i e_{k_l}$是第$i$步的部分量化结果。
总体RQ-VAE将视觉tokenize分解为了一个progressive的过程,用一系列token量化残差,这些token可以视作某种深度信息,从而改进生成结果。当然这种深度和真实的spatial空间有何联系没有深入探索,而且用多个token会增加自回归生成时的序列长度。
LFQ
LFQ的全名为Lookup-Free Quantization,顾名思义则是不需要lookup的量化方法。我们在之前介绍的VQ和RQ都需要一个显式的量化过程,即$\mathop{\text{argmin}}_j \Vert z - e_j\Vert_2$,这会带来一定的计算量。同时VQ有一个很严重的问题即:编码表坍缩。当编码表增大时,表现有时反而不如小codebook,而且这种现象越大越明显。归根结底还是argmin的手动梯度过于极端,在codebook过大的情况下只有少部分训练到了。而且过大的编码表也会导致第二阶段进行token生成时的困难。
因此[5]作者的出发点便是如何优化这个argmin操作。从本质来说,argmin无非是将logits和某个embedding快速对应起来,而embedding也不是训练好的,所以用如此精确的index方法从深度学习的角度而言并没有什么意义。换而言之,我只需要从logits中通过一个确定性过程得到一个整数index,再从编码表中取embedding就行了,而这种方法带来的编码损失则可以用更复杂的encoder和decoder来弥补。
LFQ的做法其实也非常简单,对于一个logits$z\in \mathbb R^d$,首先做一个二值量化即过一个符号函数:$sign(z):\mathbb R^d\to \{0,1\}^d$;而后对其中的正数分配一个2的幂再相加。所以通俗而言LFQ就是将logits逐位量化后用2的幂次值做加权和,这里用2的幂而不是直接相加自然是为了保证codebook的大小,避免过度压缩(毕竟文章的出发点就是研究大codebook下较VQ的优越性)。也即:
$$ index(z)=\sum_{i=1}^d {2^{i-1}\mathbb 1\{z_i>0\}},\ z_q=e_{index} $$
我们可以写一个pytroch风格的代码:
import torch
feature_dim = 8 # K
z = torch.randn([feature_dim]) # z
codebook = torch.nn.Embedding(2 ** feature_dim, feature_dim)
mask = 2 ** torch.arange(feature_dim - 1, -1, -1) # K-dim, [8, 4, 2, 1]
threshold_value = torch.ones_like(z)
quantized = torch.where(z > 0, threshold_value, -threshold_value) # sign(z)
indices = torch.sum((quantized > 0).int() * mask.int())
z_q = codebook(indices)
z_q = z + (z_q - z).detach() # STE
由于LFQ不像VQ需要$z$和$z_q$去显式逼近,所以一般而言也不需要commitment loss:$\mathcal L_{commit}=\Vert z-z_q \Vert_2^2$;作者额外引入了一个熵惩罚loss:
$$ \mathcal L_{entropy}=\mathbb E[H(z_q)]-H(\mathbb E[z_q]),\ H(x)=-\sum p(x)\log p(x) $$
由于我们的codebook大小$K$很大,所以LFQ同样面临坍缩问题,而entropy loss便是为了解决这一点。第一项为$\mathbb E[H(z_q)]=mean\Big[\sum_{i=1}^d H(z_{q,i})\Big]$为逐元素熵的平均,最小化即希望量化的token更确定,即希望encoder的logits能够得到置信度高的token;而最大化第二项$H(mean_{i}\Big[z_{q,i} \Big])$,即提高所有元素平均token分布的熵,是希望量化的token在codebook内更平均,避免codebook只有少部分训练到了,二者相当于刻画了量化的准确性和均匀性。
总体而言LFQ还是非常简洁的,虽然原论文有点避重就轻,而且很多细节也没有解释清楚。由于index是二元化的2的幂次值加权和,则有$K\leq \sum_{i}^d 2^{i-1}=2^{d}-1$,一般取$d=\log_2 K$,那么就要求在量化之前先将logits映射到一个较小的空间中。例如$K=262144=2^{18}$时,$z\in\mathbb R^{\log_2 K}=\mathbb R^{18}$,如此会导致一定程度的损失。
当然lookup-free其实是一类不需要argmin的量化方法的统称,只不过LFQ占住了这个名字。类似思路的tokenizer还有直接用round
函数的FSQ;BSQ等。
PQ
PQ (Product Quantization)来源于论文[4],截止现在2024.12.16还在ICLR 25的审稿期,不过从得分 (5, 6, 6, 8)来看还是很有戏的。
Multi-Scale RQ
最后我们就到了VAR使用的multi-scale RQ了。核心思想也非常自然,联想一下画图的过程,可能是先整体构图,再详细线条、上色等...,这样一种coarse-to-fine的范式其实可以认为是深度学习的一种核心思想。
具体来说VAR把图像编码成不同的scale,用所有scale的token来表示。联想一下之前介绍的RQ,通过一系列token不断预测logits和上一次量化之间的残差;而由于VAR在生成时肯定是从小scale开始生成,而每个scale的生成都以比它小的scale的token作为先验,因此在量化时我们也是从小scale开始。
我们定义一下符号,我们首先需要将图像分解为$m$个scale:$\{h_i,w_i\}_{i=1}^m$;VAE的encoder输出的latents为$z\in \mathbb R^{h\times w\times d}$,codebook为$E=[e_1,\cdots,e_K],e_i\in \mathbb R^d$,大小为$K$。我们以前两次量化为例:
- 首先第一次量化时,现将$z$插值到$(h_1,w_1)$得到$z_1\in \mathbb R^{h_1\times w_1\times d}$,文中直接从1*1开始,而后量化$z_1$:$v_1=\mathcal Q(z_1;E)\in [K]^{h_1\times w_1},\ z_{q,1}=e_{v_1}$。
- 而后计算“残差”,由于$z_{q,1}\in \mathbb R^{h_1\times w_1\times d}$和$z$的形状不符,所以我们需要一个卷积层,最终残差为:$z=z-conv_{1}(z_{q,1})$。
- 而后第二次量化,将$z$插值到$(h_2,w_2)$,再量化、计算残差。
算法的流程图如下:
可见这里的“残差”都是logits维度上计算的,而为了匹配不同的scale需要进行不断的插值(是不是也可以换成可学习的卷积?)+卷积。不同scale的encoding有没有更优雅的方式呢?
在具体实验中,$m$取为了10,对于256*256的图像生成,各个scale的分辨率具体为:$h=w=1,2,3,4,5,6,8,10,13,16$,则每个scale下的token分别只有$1,4,9,\cdots,256$,算是非常少的了。codebook大小$K=4096$。在量化算法创新之外,VAE的训练还是用的常规loss,这里就不赘述了。
References
- Neural Discrete Representation Learning, NIPS 2017 - VQ-VAE
Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu @Google, 2017.11| [[arXiv PDF]](https://arxiv.org/pdf/1711.00937) - 轻松理解 VQ-VAE:首个提出 codebook 机制的生成模型
- Autoregressive image generationusing residual quantization, CVPR 2022, -RQ-VAE
Doyup Lee, Chiheon Kim, Saehoon Kim, Minsu Cho, Wook-Shin Han @postech, 2022.3 | [[arXiv PDF]](https://arxiv.org/pdf/2203.01941) - ImageFolder: Autoregressive Image Generation with Folded Tokens, ICLR 2025 Under Review, -ImageFolder
Xiang Li, Kai Qiu, Hao Chen, Jason Kuen, Jiuxiang Gu, Bhiksha Raj, Zhe Lin @CMU&Adobe, 2024.10 | [[arXiv PDF]](https://arxiv.org/pdf/2410.01756) - Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation, ICLR 2024, -MAGVIT-v2
Lijun Yu, José Lezama, Nitesh B. Gundavarapu, Luca Versari, Kihyuk Sohn, David Minnen, Yong Cheng, Vighnesh Birodkar, Agrim Gupta, Xiuye Gu, Alexander G. Hauptmann, Boqing Gong, Ming-Hsuan Yang, Irfan Essa, David A. Ross, Lu Jiang @Google&CMU, 2024.5 | [[arXiv PDF]](https://arxiv.org/pdf/2310.05737) - Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction, NIPS 2024, -VAR
Keyu Tian, Yi Jiang, Zehuan Yuan, Bingyue Peng, Liwei Wang @PKU&ByteDance, 2024.4 | [[arXiv PDF]](https://arxiv.org/pdf/2404.02905)