前言
本文我们来看将VAR分为drafter和refiner两阶段生成的CoDe。来自NUS的《Collaborative Decoding Makes Visual Auto-Regressive Modeling Efficient》。结论是基本保持VAR-d30的1.95FID的效果下,将其计算速度提升至原来的1.7倍而显存降低50%,还是很令人印象深刻的。
Observations
在引出方法之前,作者首先进过实验得到了两个结论:
1. 提升模型规模对大scale的生成作用有限
对VAR分出的10个scale:$h=w=1,2,3,4,5,6,8,10,13,16$中的每个scale$k$,作者使用不同的VAR模型生成第$k$个scale的token map$\boldsymbol R_k$,而用最大的VAR-d30生成其余的token map$\boldsymbol R_1,\cdots,\boldsymbol R_{k-1},\boldsymbol R_{k+1},\boldsymbol R_{10}$。结果如下图:
可见随着模型规模的增大,对小scale生成效果的提升较为明显;而对大scale的效果提升越来越小。
可能有读者会有疑问,为什么不同规模的VAR模型生成结果为什么能够组合。这是由于其共享了同一套tokenizer-decoder,所以我们可以training-free地进行组合。
2. 不同scale关注的信息不同,且它们生成模式是互斥的
下图是前3个和后3个scale的transformer生成token map的傅里叶分析,说明低scale的token包含更多低频成分(例如背景、颜色、布局等);高scale的token包含更多高频成分(例如纹理、细节等),这与我们的直观理解也是相符的。
同时作者进一步做了实验,VAR在训练过程中输入transformer的都是ground truth tokens而非模型生成的tokens,加上了attention mask之后来训练它next-scale prediction的效果,所以在训练阶段,理论上最后三个scale loss的优化可以和之前的无关,作者用相比训练阶段$1\%$的步数只用最后三个scale的交叉熵loss进行微调,结果FID从3.3直接飙升至了21.93,模型很快崩溃;另一方面如果只用前7个scale的loss微调,模型的生成质量会迅速退化到FID=10左右,也会崩掉,这说明高频和低频的建模能力存在相互干扰,影响了参数的有效的利用。
Methodology
那么自然地有了前面两个结论,相信读者也能够很快想到改进策略了:分离前后scale的生成过程。具体而言,文中用原来2B的VAR模型作为drafter,生成低分辨率的前$N$个scale的token map;而由于增大规模对大scale提升有限,所以可以只用0.3B参数的VAR模型作为refiner生成高分辨率的token。整体结构如下:
思路还是比较自然的,最后只需要组合两个部分的feature map就可以得到图片了。我们在Observations1中有提到,这种做法可以做到training-free;作者则进一步地将这两部分VAR分别在各自的scale上进行了微调。对于drafter而言,则仍然用原来生成的token map和GT之间的交叉熵loss;而对于refiner,作者用VAR-d30作为teacher用知识蒸馏构建loss:
$$ \mathcal L_{refiner}=\sum_{k=1}^K (\lambda_{ep}\cdot \boldsymbol 1_{[k\leq N]}+\boldsymbol 1_{[k>N]})KL\Big[p_{\theta_r}(\boldsymbol R_k)\Vert p_{teacher}(\boldsymbol R_k) \Big] $$
其中$\lambda_{ep}$随训练从1降至0,也即KD loss的重点逐渐从所有token map的分布转移至refiner的分布。
Experiments
方法部分就介绍完了,可见还是非常简洁的,效果如何呢?
作者实验了当$N$(前多少个scale用大参数模型生成)分别取为6-9的情况,可见在速度和显存占用上都有明显提升,关键在于将前后scale的生成过程断开后,整个生成的序列长度有明显减少,从而显著减少了attention计算和KV-cache的时空复杂度。
从视觉结果来看,指标的降低也是非常不明显的,甚至挑出来的一些图感觉比vanilla的质量还要高。总而言之,将VAR的生成过程断开确实能够很好提升效率。当时在看VAR的时候就觉得一张图片的token数量是不是太多了:$\sum_{i\in1,2,3,4,5,6,8,10,13,16}i^2=680$,相比其他自回归方法来说多很多。
文中还有其他的一些实验结果,大同小异这里就不放了。补一组论文中没有的:使用在前7个scale specilized fintune过的VAR-d30来生成前七个scale,后3个scale用原始的VAR-d30来生成,fid会从1.95降低到1.89。
References
- Collaborative Decoding Makes Visual Auto-Regressive Modeling Efficient, arXiv 2024 - CoDe
Zigeng Chen , Xinyin Ma , Gongfan Fang , Xinchao Wang @NUS, 2024.11 | [[arXiv PDF]](https://arxiv.org/pdf/2411.17787)