Preface
A PyTorch implementation of a VAE (variational autoencoder).
References
Data
We use the CelebA-HQ dataset for face generation. For simplicity, the images are resized to 64x64; the figure below shows a sample batch of 16 images.
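To make the data pipeline concrete, here is a minimal loading sketch of my own (not from the original post). It assumes the CelebA-HQ images live under a local data/celeba_hq directory with at least one subfolder, as torchvision's ImageFolder requires; the path and transforms are assumptions:

import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

# Assumed layout: data/celeba_hq/<some_subdir>/*.jpg
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # resize to 64x64 as described above
    transforms.ToTensor(),        # HWC uint8 -> CHW float in [0, 1]
])
dataset = torchvision.datasets.ImageFolder('data/celeba_hq', transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)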
Model
Now let's look at the VAE implementation. The encoder and decoder here largely follow [1], using a U-Net-like structure:
import torch
from torch import nn


class VAE(nn.Module):
    """
    VAE for 64x64 face generation. The hidden dimensions can be tuned.
    """

    def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
        super().__init__()

        # Encoder: five stride-2 conv blocks, each halving the spatial size
        prev_channels = 3
        modules = []
        img_length = 64
        for cur_channels in hiddens:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels,
                              cur_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1), nn.BatchNorm2d(cur_channels),
                    nn.ReLU()))
            prev_channels = cur_channels
            img_length //= 2
        self.encoder = nn.Sequential(*modules)

        # Linear heads producing the posterior mean and log-variance
        self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
                                     latent_dim)
        self.var_linear = nn.Linear(prev_channels * img_length * img_length,
                                    latent_dim)
        self.latent_dim = latent_dim

        # Decoder: mirror of the encoder, upsampling back to 64x64
        modules = []
        self.decoder_projection = nn.Linear(
            latent_dim, prev_channels * img_length * img_length)
        self.decoder_input_chw = (prev_channels, img_length, img_length)
        for i in range(len(hiddens) - 1, 0, -1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hiddens[i],
                                       hiddens[i - 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hiddens[i - 1]), nn.ReLU()))
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0],
                                   hiddens[0],
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]), nn.ReLU(),
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU()))
        self.decoder = nn.Sequential(*modules)
The overall structure is fairly straightforward.
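As a quick sanity check (a sketch of my own, not from the original post), the encoder halves the spatial size five times, taking a 64x64 image down to a 2x2 feature map with 256 channels:

import torch

vae = VAE()
x = torch.randn(16, 3, 64, 64)  # a dummy batch of 16 images
print(vae.encoder(x).shape)     # torch.Size([16, 256, 2, 2])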
Next comes the forward computation. Note the reparameterization trick: instead of sampling $\boldsymbol z$ directly from $q_{\boldsymbol\phi}(\boldsymbol z|\boldsymbol x)$, we sample auxiliary noise $\boldsymbol\epsilon$ from a standard Gaussian and compute $\boldsymbol z = \boldsymbol\mu + \boldsymbol\sigma\odot\boldsymbol\epsilon$, so gradients can flow back through $\boldsymbol\mu$ and $\log\boldsymbol\sigma^2$:
def sampling(mean, log_var):
    # Reparameterization trick: sample auxiliary noise from a standard
    # Gaussian, then scale and shift it, so that gradients can flow
    # through mean and log_var.
    eps = torch.randn_like(log_var)
    std = torch.exp(log_var / 2)  # recover the standard deviation
    return eps * std + mean


class VAE(nn.Module):
    """
    VAE for 64x64 face generation. The hidden dimensions can be tuned.
    """
    # __init__ as defined above

    def forward(self, x):
        encoded = self.encoder(x)
        encoded = torch.flatten(encoded, 1)
        # Posterior mean and log-variance
        mean = self.mean_linear(encoded)
        log_var = self.var_linear(encoded)
        z = sampling(mean, log_var)  # reparameterization
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)
        return decoded, mean, log_var
As the reparameterization code shows, we fit $\log \boldsymbol\sigma^{2(i)}$ rather than $\boldsymbol\sigma^{2(i)}$ itself. Since a variance must be non-negative, fitting it directly would require an activation function to enforce the constraint, whereas the log-variance is unconstrained.
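A quick shape check of the full forward pass (again my own sketch): the reconstruction matches the input size, and the posterior parameters have latent_dim entries per sample.

x_hat, mean, log_var = vae(torch.randn(16, 3, 64, 64))
print(x_hat.shape)    # torch.Size([16, 3, 64, 64]), the reconstruction
print(mean.shape)     # torch.Size([16, 128]), posterior mean
print(log_var.shape)  # torch.Size([16, 128]), posterior log-variance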
Loss Function
Finally, we need to define the loss function. Building on the previous post, the learning task is to maximize the ELBO over both the decoder parameters $\boldsymbol\theta$ and the encoder parameters $\boldsymbol\phi$:
$$ \begin{align} \max_{\boldsymbol\theta,\boldsymbol\phi}L&=\max_{\boldsymbol\theta,\boldsymbol\phi}\sum_{\boldsymbol x}ELBO(\boldsymbol x;\boldsymbol\theta,\boldsymbol\phi)\\ &=\max_{\boldsymbol\theta,\boldsymbol\phi}\sum_{\boldsymbol x}\Big[ \mathbb{E}_{\boldsymbol z\sim q_{\boldsymbol\phi}(\boldsymbol z|\boldsymbol x)}\log p_{\boldsymbol\theta}(\boldsymbol x|\boldsymbol z) -KL(q_{\boldsymbol \phi}(\boldsymbol z|\boldsymbol x)\| p_{\boldsymbol \theta}(\boldsymbol z))\Big]\tag{1} \end{align} $$
Negating this, and assuming a Gaussian decoder so that the expected log-likelihood term reduces to a mean squared error, the objective is equivalent to:
$$ \begin{align} &\min_{\boldsymbol\theta,\boldsymbol\phi}\Big[ MSE(\boldsymbol x,\hat{\boldsymbol x}) +KL(q_{\boldsymbol \phi}(\boldsymbol z|\boldsymbol x)\| p_{\boldsymbol \theta}(\boldsymbol z))\Big]\\ &\text{where}\qquad KL=\frac 12 \sum_{i=1}^d \Big(\sigma^{2(i)}+\mu^{2(i)}-\log \sigma^{2(i)}-1 \Big) \end{align}\tag{2} $$
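For reference, the closed-form KL term follows from the KL divergence between two univariate Gaussians, evaluated per latent dimension against the standard normal prior $p_{\boldsymbol\theta}(\boldsymbol z)=\mathcal{N}(\boldsymbol 0,\boldsymbol I)$:
$$ KL\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big) =\log\frac{1}{\sigma}+\frac{\sigma^2+\mu^2}{2}-\frac{1}{2} =\frac{1}{2}\big(\sigma^2+\mu^2-\log\sigma^2-1\big) $$
Summing this over the $d$ latent dimensions gives the expression in (2).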
In code, this becomes:
import torch.nn.functional as F

def loss_fn(x, x_hat, mean, logvar):
    recons_loss = F.mse_loss(x_hat, x)
    # Closed-form KL term: sum over latent dimensions, mean over the batch
    kl_loss = torch.mean(
        -0.5 * torch.sum(1 + logvar - mean ** 2 - torch.exp(logvar), 1), 0)
    # kl_weight is a global hyperparameter; it compensates for mse_loss
    # averaging over all pixels while the KL term sums over dimensions
    loss = recons_loss + kl_loss * kl_weight
    return loss
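With the loss defined, a minimal training loop might look like the sketch below. This is my own outline rather than the original post's code; the learning rate, epoch count, and kl_weight value are all assumed and should be tuned:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
kl_weight = 0.00025  # assumed value; balances reconstruction against KL
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed lr

for epoch in range(10):  # assumed epoch count
    for x, _ in dataloader:  # ImageFolder yields (image, label) pairs
        x = x.to(device)
        x_hat, mean, logvar = model(x)
        loss = loss_fn(x, x_hat, mean, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()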
Finally, once the training framework is in place, we are done. Below are some generated samples; images produced by a VAE are generally somewhat blurry.
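To draw such samples, here is a hedged sketch (the generate helper is my own addition, not part of the original class): sample $\boldsymbol z$ from the standard Gaussian prior and run only the decoder path of forward.

import torch

@torch.no_grad()
def generate(model, n_samples=16, device='cpu'):
    # Hypothetical helper: sample latents from the prior p(z) = N(0, I)
    # and decode them, mirroring the second half of forward().
    z = torch.randn(n_samples, model.latent_dim, device=device)
    x = model.decoder_projection(z)
    x = torch.reshape(x, (-1, *model.decoder_input_chw))
    return model.decoder(x)  # (n_samples, 3, 64, 64)

samples = generate(model, n_samples=16, device=device)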