Preface

  This post presents a PyTorch implementation of a VAE (variational autoencoder).

References

  1. "抛开数学,轻松学懂 VAE(附 PyTorch 实现)" (Understanding VAEs easily without the math, with a PyTorch implementation)
  2. https://github.com/AntixK/PyTorch-VAE

Data

  We use the CelebA-HQ dataset for face generation. For simplicity, the images are resized to 64×64. The figure below shows one batch (16 images):

[Figure: data samples]

Model

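  Next, the VAE itself. The encoder and decoder largely follow [1]: five stride-2 convolutions downsample the 64×64 input to 2×2, and transposed convolutions mirror them back up to 64×64. The layout is symmetric like a U-Net, although there are no skip connections between encoder and decoder: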

import torch
import torch.nn as nn


class VAE(nn.Module):
    """
    VAE for 64x64 face generation. The hidden dimensions can be tuned.
    """

    def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
        super().__init__()

        # Encoder
        prev_channels = 3
        modules = []
        img_length = 64
        for cur_channels in hiddens:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels,
                              cur_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1), nn.BatchNorm2d(cur_channels),
                    nn.ReLU()))
            prev_channels = cur_channels
            img_length //= 2
        self.encoder = nn.Sequential(*modules)
        # Linear heads for the posterior mean and variance (var_linear actually predicts the log-variance)
        self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
                                     latent_dim)
        self.var_linear = nn.Linear(prev_channels * img_length * img_length,
                                    latent_dim)
        self.latent_dim = latent_dim

        # Decoder
        modules = []
        self.decoder_projection = nn.Linear(
            latent_dim, prev_channels * img_length * img_length)
        self.decoder_input_chw = (prev_channels, img_length, img_length)
        for i in range(len(hiddens) - 1, 0, -1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hiddens[i],
                                       hiddens[i - 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hiddens[i - 1]), nn.ReLU()))
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hiddens[0],
                                   hiddens[0],
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hiddens[0]), nn.ReLU(),
                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU()))
        self.decoder = nn.Sequential(*modules)

  The overall structure is fairly straightforward.
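To see why, with the default settings, the bottleneck linear layers take prev_channels * img_length * img_length = 256 × 2 × 2 inputs and the decoder ends at 64×64 again, one can trace the spatial size using the standard output-length formulas of nn.Conv2d and nn.ConvTranspose2d (dilation 1). A small sketch; `conv_out`, `convT_out` and `trace_shapes` are hypothetical helper names:

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Hypothetical helper: nn.Conv2d output length along one axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def convT_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Hypothetical helper: nn.ConvTranspose2d output length along one axis (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace_shapes(img_length=64, hiddens=(16, 32, 64, 128, 256)):
    """Hypothetical helper: spatial sizes through the encoder and decoder."""
    enc = [img_length]
    for _ in hiddens:  # one stride-2 Conv2d per hidden size
        enc.append(conv_out(enc[-1]))
    dec = [enc[-1]]
    for _ in hiddens:  # len(hiddens) - 1 upsampling blocks + the final ConvTranspose2d
        dec.append(convT_out(dec[-1]))
    flat = hiddens[-1] * enc[-1] * enc[-1]  # input size of mean_linear / var_linear
    return enc, dec, flat


enc, dec, flat = trace_shapes()
print(enc)   # [64, 32, 16, 8, 4, 2]
print(dec)   # [2, 4, 8, 16, 32, 64]
print(flat)  # 1024
```

The last stride-1 Conv2d (kernel 3, padding 1) keeps the size at 64, so the output has the same 3×64×64 shape as the input.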
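  Next comes the forward pass. Note the reparameterization trick: instead of sampling z directly from N(μ, σ²), we sample ε ~ N(0, I) and set z = μ + σ·ε, so the randomness lives in ε and gradients can flow back through μ and σ: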

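def sampling(mean, log_var):
    # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I)
    eps = torch.randn_like(log_var)  # sample from a standard Gaussian; eps itself needs no gradient
    std = torch.exp(log_var / 2)  # recover the standard deviation from the log-variance
    return eps * std + mean

class VAE(nn.Module):
    """
    VAE for 64x64 face generation. The hidden dimensions can be tuned.
    """

    # __init__ as defined above
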
    def forward(self, x):
        encoded = self.encoder(x)
        encoded = torch.flatten(encoded, 1)
        # Posterior mean and log-variance
        mean = self.mean_linear(encoded)
        log_var = self.var_linear(encoded)
        z = sampling(mean, log_var)  # reparameterization
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = self.decoder(x)
        return decoded, mean, log_var

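  The reparameterization code shows that the network predicts $\log \boldsymbol\sigma^{2(i)}$ rather than $\boldsymbol\sigma^{2(i)}$ itself. A variance must be non-negative, so predicting it directly would require an extra activation function to enforce that constraint; the log-variance can be any real number, so a plain linear layer suffices, and the standard deviation is recovered as $\exp(\log\sigma^2/2)$.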
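A quick way to check that the trick behaves as intended is to confirm that gradients reach both `mean` and `log_var`, and that the samples have the requested mean and variance. A minimal standalone sketch (the `sampling` function is restated from above; the sizes and the values μ = 2, σ² = 0.25 are arbitrary):

```python
import math

import torch


def sampling(mean, log_var):
    """Reparameterized sample z = mean + exp(log_var / 2) * eps, with eps ~ N(0, I)."""
    eps = torch.randn_like(log_var)
    std = torch.exp(log_var / 2)
    return eps * std + mean


torch.manual_seed(0)

# 1) Gradients flow back to both parameters of the Gaussian.
mean = torch.zeros(4, 128, requires_grad=True)
log_var = torch.zeros(4, 128, requires_grad=True)
z = sampling(mean, log_var)
z.sum().backward()
mean_grad = mean.grad        # dz/dmean = 1 everywhere
log_var_grad = log_var.grad  # dz/dlog_var = 0.5 * std * eps (random, nonzero)

# 2) Samples follow N(mu, sigma^2) with mu = 2, sigma^2 = 0.25 (sigma = 0.5).
mu = torch.full((200_000,), 2.0)
lv = torch.full((200_000,), math.log(0.25))
samples = sampling(mu, lv)
print(f"sample mean {samples.mean().item():.3f}, sample std {samples.std().item():.3f}")
```

Sampling z directly with something like `torch.normal(mean, std)` would not let gradients reach `mean` and `log_var`; moving the noise into `eps` is what makes the encoder trainable by backpropagation.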

Loss function

  Finally, we need to define the loss function. Building on the previous post, the learning objective is:

$$ \begin{align} \max_{\boldsymbol\theta,\boldsymbol\phi}L_{\boldsymbol\theta,\boldsymbol\phi}&=\max_{\boldsymbol\theta,\boldsymbol\phi}\sum_{\boldsymbol x}ELBO(\boldsymbol x;\boldsymbol\theta,\boldsymbol\phi)\\ &=\max_{\boldsymbol\theta,\boldsymbol\phi}\sum_{\boldsymbol x}\Big[\mathbb{E}_{\boldsymbol z\sim q_{\boldsymbol\phi}(\boldsymbol z|\boldsymbol x)}\log p_{\boldsymbol\theta}(\boldsymbol x|\boldsymbol z) -KL(q_{\boldsymbol \phi}(\boldsymbol z|\boldsymbol x)\| p_{\boldsymbol \theta}(\boldsymbol z))\Big]\tag{1} \end{align} $$

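Negating the objective turns it into a minimization. If the decoder likelihood $p_{\boldsymbol\theta}(\boldsymbol x|\boldsymbol z)$ is a Gaussian with fixed variance centred on the reconstruction $\hat{\boldsymbol x}$, then $-\log p_{\boldsymbol\theta}(\boldsymbol x|\boldsymbol z)$ equals the MSE up to a positive scale factor and an additive constant (the expectation over $\boldsymbol z$ is estimated with a single reparameterized sample). With a diagonal Gaussian posterior $q_{\boldsymbol\phi}(\boldsymbol z|\boldsymbol x)=\mathcal N(\boldsymbol\mu,\operatorname{diag}(\boldsymbol\sigma^2))$ over $d$ latent dimensions and a standard normal prior, the KL term has a closed form, and the problem becomes: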

$$ \begin{align} &\min_{\boldsymbol\theta,\boldsymbol\phi}\sum_{\boldsymbol x}\Bigg[ MSE(\boldsymbol x,\hat{\boldsymbol x}) +KL(q_{\boldsymbol \phi}(\boldsymbol z|\boldsymbol x)\| p_{\boldsymbol \theta}(\boldsymbol z))\Bigg]\\ &\text{where:}\\ &\qquad\qquad KL=\frac 12 \sum_{i=1}^d \Big(\sigma^{2(i)}+\mu^{2(i)}-\log \sigma^{2(i)}-1 \Big)\tag{2} \end{align} $$
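The closed-form KL in (2) follows from a short calculation. For a single latent dimension with $q=\mathcal N(\mu,\sigma^2)$ and the standard normal prior, use $\mathbb E_q[(z-\mu)^2]=\sigma^2$ and $\mathbb E_q[z^2]=\sigma^2+\mu^2$; the $\log 2\pi$ terms of the two densities cancel:

$$ \begin{align} KL\big(\mathcal N(\mu,\sigma^2)\,\|\,\mathcal N(0,1)\big) &= \mathbb E_{z\sim\mathcal N(\mu,\sigma^2)}\Big[\log\mathcal N(z;\mu,\sigma^2)-\log\mathcal N(z;0,1)\Big]\\ &= \mathbb E\Big[-\tfrac12\log\sigma^2-\frac{(z-\mu)^2}{2\sigma^2}+\frac{z^2}{2}\Big]\\ &= -\tfrac12\log\sigma^2-\tfrac12+\tfrac12\big(\sigma^2+\mu^2\big)\\ &= \tfrac12\big(\sigma^2+\mu^2-\log\sigma^2-1\big) \end{align} $$

Because both distributions factorize over the $d$ dimensions, the full KL is the sum of these per-dimension terms, which is the expression in (2).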

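In code this becomes the following. Note that `F.mse_loss` averages over every pixel, while the KL term is summed over the latent dimensions and only averaged over the batch, so the two terms live on different scales; the `kl_weight` factor rebalances them: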

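import torch
import torch.nn.functional as F

def loss_fn(x, x_hat, mean, logvar, kl_weight):
    # Reconstruction term: per-pixel MSE averaged over the whole batch
    recons_loss = F.mse_loss(x_hat, x)
    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)): summed over latent dims, averaged over the batch
    kl_loss = torch.mean(
        -0.5 * torch.sum(1 + logvar - mean ** 2 - torch.exp(logvar), 1), 0)
    loss = recons_loss + kl_loss * kl_weight  # kl_weight: balancing hyperparameter (value not given here)
    return loss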
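As a sanity check that the `kl_loss` line implements the closed form in (2), one can compare it with PyTorch's built-in `torch.distributions.kl_divergence` for two `Normal` distributions. A minimal sketch with random inputs; `kl_closed_form` is a hypothetical helper that copies the expression from `loss_fn`:

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_closed_form(mean, logvar):
    """Hypothetical helper: same expression as kl_loss in loss_fn
    (sum over latent dims, mean over the batch)."""
    return torch.mean(
        -0.5 * torch.sum(1 + logvar - mean ** 2 - torch.exp(logvar), 1), 0)


torch.manual_seed(0)
batch, latent_dim = 8, 128
mean = torch.randn(batch, latent_dim)
logvar = torch.randn(batch, latent_dim)

# Diagonal Gaussian posterior (Normal takes the standard deviation as `scale`)
q = Normal(mean, torch.exp(0.5 * logvar))
# Standard normal prior
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))
reference = kl_divergence(q, p).sum(dim=1).mean()

ours = kl_closed_form(mean, logvar)
print(ours.item(), reference.item())
```

The two values should agree up to floating-point rounding, and the KL is exactly zero when the posterior already equals the prior (mean 0, log-variance 0).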

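  All that remains is to set up the training loop. The figure below shows some generated images; as is typical for VAEs, they tend to be rather blurry.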

[Figure: images generated by the VAE]
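The training loop itself is not shown in the post. A minimal sketch, assuming a model with the `forward` method above (returning `decoded, mean, log_var`), a loader that yields image batches without labels, and any `torch.optim` optimizer; `train_one_epoch` and `generate` are hypothetical names, and `loss_fn` is repeated so the block runs on its own. Generation draws z from the standard normal prior and reuses the `latent_dim`, `decoder_projection`, `decoder_input_chw` and `decoder` attributes set in `__init__`:

```python
import torch
import torch.nn.functional as F


def loss_fn(x, x_hat, mean, logvar, kl_weight):
    recons_loss = F.mse_loss(x_hat, x)
    kl_loss = torch.mean(
        -0.5 * torch.sum(1 + logvar - mean ** 2 - torch.exp(logvar), 1), 0)
    return recons_loss + kl_loss * kl_weight


def train_one_epoch(model, loader, optimizer, kl_weight, device="cpu"):
    """Hypothetical helper: one pass over `loader`, returns the mean loss per batch."""
    model.train()
    total, n_batches = 0.0, 0
    for x in loader:
        x = x.to(device)
        x_hat, mean, log_var = model(x)
        loss = loss_fn(x, x_hat, mean, log_var, kl_weight)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
        n_batches += 1
    return total / max(n_batches, 1)


@torch.no_grad()
def generate(model, n, device="cpu"):
    """Hypothetical helper: decode n latent vectors drawn from the prior N(0, I)."""
    model.eval()  # BatchNorm layers switch to their running statistics
    z = torch.randn(n, model.latent_dim, device=device)
    x = model.decoder_projection(z)
    x = torch.reshape(x, (-1, *model.decoder_input_chw))
    return model.decoder(x)
```

With the VAE from this post one would call something like `train_one_epoch(model, loader, torch.optim.Adam(model.parameters()), kl_weight)` for several epochs and then `generate(model, 16)`; the learning rate, number of epochs and `kl_weight` are hyperparameters the post does not specify.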
