Audio-Driven THG (Part 4): A Walkthrough of the Hallo Code

Preface

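This post studies the Hallo code. This is my first time working with Hugging Face's diffusers library, so please excuse any inaccuracies. The code is the version as of the 2024-09-14 changes, which I forked here: https://github.com/HJHGJGHHG/hallo/tree/original. The diffusers version is 0.24.0.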

Inference Framework

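First, let's look at inference, in scripts/inference.py. At L157-L164 the input reference image is processed: besides its pixel array, we also need the face region, the face embedding, and three masks for lip, expression and pose (returned below as lip_mask, face_mask and full_mask). This is done by ImageProcessor in hallo/datasets/image_processor.py. Note that a Gaussian blur is applied to the masks, which to some extent helps training converge faster; in addition, each of the three masks is resized to several resolutions, giving a multi-scale list. The shapes are as follows (T denotes a Tensor):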

with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
    (
        source_image_pixels,       # T(3, 512, 512), normalized
        source_image_face_region,  # T(3, 512, 512), binary 0/1
        source_image_face_emb,     # ndarray(512)
        source_image_full_mask,    # [T(1,4096), T(1,1024), T(1,256), T(1,64)]
        source_image_face_mask,    # [T(1,4096), T(1,1024), T(1,256), T(1,64)]
        source_image_lip_mask,     # [T(1,4096), T(1,1024), T(1,256), T(1,64)]
    ) = image_processor.preprocess(
        source_image_path, save_path, config.face_expand_ratio)
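The blur-plus-pyramid idea can be reproduced roughly as follows. This is a sketch, not Hallo's ImageProcessor: the kernel size (21), sigma (5.0), replicate padding and bilinear resizing are all assumptions; only the target sizes 64, 32, 16 and 8 (i.e. 4096, 1024, 256 and 64 tokens) follow the shapes listed above.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel1d(size: int, sigma: float) -> torch.Tensor:
    # normalized 1D Gaussian; applied along rows and then columns it gives a 2D blur
    x = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    k = torch.exp(-(x ** 2) / (2 * sigma ** 2))
    return k / k.sum()


def blur_and_pyramid(mask: torch.Tensor, sizes=(64, 32, 16, 8), ksize=21, sigma=5.0):
    """mask: (H, W) binary tensor -> list of (1, s*s) soft masks, one per size."""
    k = gaussian_kernel1d(ksize, sigma)
    m = mask.float()[None, None]  # (1, 1, H, W)
    pad = ksize // 2
    # separable blur; replicate padding keeps the borders from being darkened
    m = F.conv2d(F.pad(m, (pad, pad, 0, 0), mode="replicate"), k.view(1, 1, 1, -1))
    m = F.conv2d(F.pad(m, (0, 0, pad, pad), mode="replicate"), k.view(1, 1, -1, 1))
    pyramid = []
    for s in sizes:
        r = F.interpolate(m, size=(s, s), mode="bilinear", align_corners=False)
        pyramid.append(r.view(1, s * s))  # flatten to match the token layout
    return pyramid


mask = torch.zeros(512, 512)
mask[200:320, 180:330] = 1.0  # a hypothetical lip box
masks = blur_and_pyramid(mask)
print([tuple(m.shape) for m in masks])  # [(1, 4096), (1, 1024), (1, 256), (1, 64)]
```

The soft edges mean the mask-weighted attention outputs fade out gradually instead of switching off at a hard boundary.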

Next, the Wav2Vec features of the audio are extracted, with shapes:

(
    audio_emb,     # T(wav2vec length, 12, 768)
    audio_length,  # int
) = ...

Then L196-L221 load the models: the VAE, the reference UNet, the denoising UNet, and the projection models for the image and the audio. The model architectures are covered later.

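At L292 we enter the generation loop. The number of iterations is the audio feature length // the number of frames generated per segment (clip_length). The code is as follows: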

for t in range(times):
    print(f"[{t+1}/{times}]")

    if len(tensor_result) == 0:
        # The first iteration
        motion_zeros = source_image_pixels.repeat(config.data.n_motion_frames, 1, 1, 1)
        motion_zeros = motion_zeros.to(dtype=source_image_pixels.dtype, device=source_image_pixels.device)
        pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
    else:
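        motion_frames = tensor_result[-1][0]  # previous segment's frames: (C, f, H, W)
        motion_frames = motion_frames.permute(1, 0, 2, 3)  # -> (f, C, H, W)
        motion_frames = motion_frames[-config.data.n_motion_frames:]  # keep the last n_motion_frames frames
        motion_frames = motion_frames * 2.0 - 1.0  # map the [0, 1] output back to [-1, 1]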
        motion_frames = motion_frames.to(dtype=source_image_pixels.dtype, device=source_image_pixels.device)
        pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames

    pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

    audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
    audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.to(
            device=net.audioproj.device, dtype=net.audioproj.dtype)
    audio_tensor = net.audioproj(audio_tensor)

    pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=img_size[0],
            height=img_size[1],
            video_length=clip_length,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
            motion_scale=motion_scale,
        )

    tensor_result.append(pipeline_output.videos)

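First, the motion frames used by the temporal layers are prepared. For the first segment they are simply copies of the reference image; here the number of motion frames is 2. For later segments, the last 2 frames generated in the previous segment are used instead, after being mapped back to [-1, 1] with `* 2.0 - 1.0`.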

motion_zeros = source_image_pixels.repeat(config.data.n_motion_frames, 1, 1, 1)  # T(2, C=3, H=512, W=512)
# concat the ref image and the first motion frames
pixel_values_ref_img = torch.cat([source_image_pixels, motion_zeros], dim=0)  # T(1+2, C=3, H=512, W=512)

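With the reference image and the audio embedding prepared, we reach the pipeline call at L323. The returned pipeline_output has a single attribute holding the generated frames: .videos: T(1, C=3, f=16, H, W)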
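The segment-by-segment loop above can be condensed into the following sketch. gen_fn is a hypothetical stand-in for the pipeline call and is assumed to return frames in [0, 1] with shape (1, C, clip_length, H, W), matching the .videos layout; as in the original loop, every segment after the first is conditioned on the last n_motion frames of the previous one.

```python
import torch


def generate_segments(ref, audio_feats, clip_length, n_motion, gen_fn):
    """ref: (1, C, H, W) in [-1, 1]; audio_feats: (T, ...). Returns (1, C, times * clip_length, H, W)."""
    times = audio_feats.shape[0] // clip_length  # leftover audio frames are dropped
    outputs = []
    for t in range(times):
        if not outputs:
            # first segment: the reference image stands in for the motion frames
            motion = ref.repeat(n_motion, 1, 1, 1)
        else:
            prev = outputs[-1][0].permute(1, 0, 2, 3)  # (C, f, H, W) -> (f, C, H, W)
            motion = prev[-n_motion:] * 2.0 - 1.0       # last frames, [0, 1] -> [-1, 1]
        cond = torch.cat([ref, motion], dim=0).unsqueeze(0)  # (1, 1 + n_motion, C, H, W)
        audio = audio_feats[t * clip_length:(t + 1) * clip_length].unsqueeze(0)
        outputs.append(gen_fn(cond, audio))
    return torch.cat(outputs, dim=2)  # concatenate along the frame axis


def dummy_gen(cond, audio):
    # hypothetical generator: repeats the last conditioning frame, mapped to [0, 1]
    frame = (cond[0, -1] + 1) / 2  # (C, H, W)
    return frame[None, :, None].repeat(1, 1, audio.shape[1], 1, 1)


ref = torch.rand(1, 3, 8, 8) * 2 - 1
video = generate_segments(ref, torch.randn(50, 12, 768), clip_length=16, n_motion=2, gen_fn=dummy_gen)
print(video.shape)  # torch.Size([1, 3, 48, 8, 8])
```

With 50 audio frames and clip_length=16, only 3 full segments (48 frames) are produced under floor division.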

FaceAnimatePipeline

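Next, let's study the generation process itself, in animate/face_animate.py. Its parent class is DiffusionPipeline, the core class of diffusers, which essentially bundles models, schedulers and other components into a high-level API. If you have used HF's transformers library, it is roughly analogous to PreTrainedModel, and it also has a from_pretrained method, for example: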

from diffusers import DiffusionPipeline

repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)

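The docstring of FaceAnimatePipeline explains it fairly clearly, and preparing and decoding latents is simple, so let's go straight to __call__ at L249. First the number of denoising steps is set; in practice it is 40: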

# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
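Which 40 timesteps come out depends on the scheduler class and its timestep_spacing, which come from Hallo's config and are not shown here. Purely as an illustration, assuming a DDIM-style scheduler with 1000 training timesteps, the sketch below re-implements in NumPy the "leading" and "trailing" spacing rules of diffusers' DDIMScheduler.set_timesteps; it does not call diffusers.

```python
import numpy as np


def ddim_timesteps(num_inference_steps, num_train_timesteps=1000, spacing="leading", steps_offset=0):
    if spacing == "leading":
        # integer stride starting from 0, shifted by steps_offset
        step_ratio = num_train_timesteps // num_inference_steps
        ts = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
        return ts + steps_offset
    if spacing == "trailing":
        # fractional stride counted down from the last training timestep
        step_ratio = num_train_timesteps / num_inference_steps
        return np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
    raise ValueError(f"unknown spacing: {spacing}")


lead = ddim_timesteps(40)
trail = ddim_timesteps(40, spacing="trailing")
print(lead[:3], lead[-1])    # [975 950 925] 0
print(trail[:3], trail[-1])  # [999 974 949] 24
```

"trailing" always starts at the noisiest training step (999), which matters for models trained with a zero terminal SNR.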

Next comes the face embedding:

# prepare clip image embeddings
clip_image_embeds = face_emb  # T(1, 512)
clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)

encoder_hidden_states = self.image_proj(clip_image_embeds)  # T(1, 4, 768)
uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
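image_proj turns the (1, 512) face embedding into 4 context tokens of width 768. A minimal sketch, assuming an IP-Adapter-style design (a single linear layer producing all tokens at once, then LayerNorm); the class name is hypothetical and the real ImageProjModel in image_proj.py may differ in details:

```python
import torch
from torch import nn


class ImageProjSketch(nn.Module):
    """Hypothetical IP-Adapter-style projection: (B, 512) -> (B, num_tokens, 768)."""

    def __init__(self, emb_dim=512, cross_attention_dim=768, num_tokens=4):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.proj = nn.Linear(emb_dim, num_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, face_emb):
        tokens = self.proj(face_emb).reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)


proj = ImageProjSketch()
face_emb = torch.randn(1, 512)
cond = proj(face_emb)                      # (1, 4, 768)
uncond = proj(torch.zeros_like(face_emb))  # projection of an all-zero embedding
print(cond.shape, uncond.shape)  # torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
```

Note that, as in the code above, the unconditional branch is not a zero tensor: it is whatever the projection produces for a zero face embedding.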

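L300-L313 create the two controllers that inject the reference features layer by layer; they are covered in detail in the model section. Then, at L317, the input latents of the denoising net are prepared. prepare_latents is simple, and the resulting latents have shape [1, 4, f=16, 64, 64].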

Next, the reference image tensor is prepared:

# Prepare ref image latents
ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width)  # (b*f, c, h=512, w=512)
ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
ref_image_latents = ref_image_latents * 0.18215  # (b*f, 4, h=64, w=64)

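Note that Stable Diffusion scales the latents by a scale factor of 0.18215 after the VAE encoder. The reason is that the encoder latents do not have unit variance; to bring their distribution closer to $\mathcal N(\boldsymbol 0, \boldsymbol I)$, they are rescaled by this factor, which is roughly the reciprocal of the latents' standard deviation. See the original LDM paper for details.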
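The LDM recipe amounts to estimating the factor as 1 / std of the encoder latents. The sketch below uses synthetic latents with a made-up std of 5.49 (chosen so that 1/std is about 0.18215) in place of real VAE outputs:

```python
import torch

torch.manual_seed(0)
# synthetic stand-in for VAE encoder latents: (N, 4, 64, 64) with std ~5.49
latents = torch.randn(16, 4, 64, 64) * 5.49

scale_factor = 1.0 / latents.flatten().std()
scaled = latents * scale_factor
print(round(scale_factor.item(), 3))  # close to 0.182
print(round(scaled.std().item(), 3))  # 1.0
```

Decoding has to undo this by dividing by the same factor before calling the VAE decoder.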

Then we enter the denoising loop:

# denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # Forward reference image
        if i == 0:
            self.reference_unet(
                ref_image_latents.repeat(
                    (2 if do_classifier_free_guidance else 1), 1, 1, 1
                ),
                torch.zeros_like(t),
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )
            reference_control_reader.update(reference_control_writer)

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        noise_pred = self.denoising_unet(
            latent_model_input,
            t,
            encoder_hidden_states=encoder_hidden_states,
            mask_cond_fea=face_mask,
            full_mask=pixel_values_full_mask,
            face_mask=pixel_values_face_mask,
            lip_mask=pixel_values_lip_mask,
            audio_embedding=audio_tensor,
            motion_scale=motion_scale,
            return_dict=False,
        )[0]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        # call the callback, if provided
        if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
            progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

    reference_control_reader.clear()
    reference_control_writer.clear()
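The classifier-free guidance step can be isolated as a small function. The batch is assumed to be ordered [unconditional, conditional], matching the order in which noise_pred.chunk(2) is unpacked above; the function name is hypothetical.

```python
import torch


def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """noise_pred: (2B, ...) with the unconditional half first, then the conditional half."""
    noise_uncond, noise_cond = noise_pred.chunk(2)
    # start from the unconditional prediction and move further along the conditional direction
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


uncond = torch.zeros(1, 4, 16, 8, 8)
cond = torch.ones(1, 4, 16, 8, 8)
out = apply_cfg(torch.cat([uncond, cond]), guidance_scale=3.5)
print(out.unique())  # tensor([3.5000])
```

With guidance_scale = 1 this reduces to the conditional prediction and with 0 to the unconditional one; larger values amplify the effect of the face-embedding condition.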

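It comes down to three steps: running the reference UNet (only once, at i == 0), predicting the noise with the denoising UNet, and computing the latents of the next step. After the denoising loop finishes, the final latents are decoded into the video frames: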

# Post-processing
# latents: [1, 4, 16, 64, 64]
images = self.decode_latents(latents)  # (1, c=3, f=16, h=512, w=512)
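decode_latents itself is not shown. A minimal sketch of what such a method typically does, assuming per-frame decoding and a decoder that outputs images in [-1, 1]; decode_fn is a hypothetical stand-in for the VAE decoder. Mapping the result to [0, 1] is consistent with the `* 2.0 - 1.0` that the inference loop applies when reusing output frames as motion frames.

```python
import torch
import torch.nn.functional as F


def decode_latents_sketch(latents, decode_fn, scale_factor=0.18215):
    """latents: (B, 4, f, h, w) -> video in [0, 1] with shape (B, 3, f, 8h, 8w)."""
    b, c, f, h, w = latents.shape
    x = latents / scale_factor  # undo the SD latent scaling
    x = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)  # fold frames into the batch
    frames = torch.cat([decode_fn(x[i:i + 1]) for i in range(b * f)])  # one frame at a time to save memory
    _, c_out, h_out, w_out = frames.shape
    video = frames.reshape(b, f, c_out, h_out, w_out).permute(0, 2, 1, 3, 4)
    return (video / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]


def toy_decoder(z):
    # hypothetical stand-in for the VAE decoder: 3 channels, 8x upsampling, output in [-1, 1]
    return torch.tanh(F.interpolate(z[:, :3], scale_factor=8, mode="nearest"))


video = decode_latents_sketch(torch.randn(1, 4, 16, 8, 8), toy_decoder)
print(video.shape)  # torch.Size([1, 3, 16, 64, 64])
```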

Models

The models directory is organized as follows:

models
    attention.py  # attention implementations
    audio_proj.py  # audio projection module
    face_locator.py  # face mask projection module
    image_proj.py  # face embedding projection module
    motion_module.py  # temporal attention
    mutual_self_attention.py  # reference attention
    resnet.py  # upsample, downsample and ResNet3D components
    transformer_2d.py  # Transformer2DModel used by the reference UNet
    transformer_3d.py  # Transformer3DModel used by the denoising UNet
    unet_2d_blocks.py  # 2D UNet
    unet_2d_condition.py
    unet_3d.py  # 3D UNet
    unet_3d_blocks.py
    wav2vec.py
    __init__.py

The models we use and their classes are:

  • vae: AutoencoderKL
  • image_proj: ImageProjModel (image_proj.py)
  • face_locator: FaceLocator (face_locator.py)
  • audio_proj: AudioProjModel (audio_proj.py)
  • reference_unet: UNet2DConditionModel (unet_2d_condition.py)
  • denoising_unet: UNet3DConditionModel (unet_3d.py)

The vae is loaded directly with diffusers' from_pretrained, and the projection modules are simple, so I will not go into them here. The focus is on the details of the two UNets.

Reference UNet

The reference UNet has the structure below: the classic Stable Diffusion UNet with down_blocks -> mid_block -> up_blocks, where each block is created from its class name:

UNet2DConditionModel(
    (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (time_proj): Timesteps()
    (time_embedding): TimestepEmbedding(
      (linear_1): Linear(in_features=320, out_features=1280, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (down_blocks): ModuleList(
      (0-2): 3 x CrossAttnDownBlock2D(
          (attentions): ModuleList[(0-1): 2 x Transformer2DModel]
          (resnets): ModuleList[(0-1): 2 x ResnetBlock2D]
          (downsamplers): ModuleList[(0): Downsample2D]
      )
      (3): DownBlock2D(
        (resnets): ModuleList[(0-1): 2 x ResnetBlock2D]
      )
    )
    (mid_block): UNetMidBlock2DCrossAttn(
        (attentions): ModuleList[(0): Transformer2DModel]
        (resnets): ModuleList[(0-1): 2 x ResnetBlock2D]
    )
    (up_blocks): ModuleList(
      (0): UpBlock2D(
        (resnets): ModuleList[(0-2): 3 x ResnetBlock2D]
        (upsamplers): ModuleList[(0): Upsample2D]
      )
      (1-3): 3 x CrossAttnUpBlock2D(
        (attentions): ModuleList[(0-2): 3 x Transformer2DModel]
        (resnets): ModuleList[(0-2): 3 x ResnetBlock2D]
        (upsamplers): ModuleList[(0): Upsample2D]
      )
    )
    (conv_norm_out): None
    (conv_act): SiLU()
)

The figure below shows this more clearly:

![Reference UNet (source: [1])](https://hjhgjghhg-1304275113.cos.ap-shanghai.myqcloud.com/usr/img/Virtual%20human/Talking%20Head/Talking%20Head%20Video%20Generation/Audio-Driven%20THG%EF%BC%88%E5%9B%9B%EF%BC%89%EF%BC%9AHallo%E4%BB%A3%E7%A0%81%E8%AF%A6%E8%A7%A3/Reference%20UNet.jpg)

Downsample2D, ResnetBlock2D and Upsample2D are standard UNet components imported directly from diffusers; see the library source for their implementation:

from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D

The core of the reference UNet is therefore Transformer2DModel, so let's turn to transformer_2d.py and study it in detail. Taking reference_unet.mid_block.attentions[0] as an example, here is the overall structure:

reference_unet.mid_block.attentions[0]: Transformer2DModel(
  (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
  (proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  (transformer_blocks): ModuleList(
    (0): BasicTransformerBlock(
      (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (attn1): Attention(
        (to_q): Linear(in_features=1280, out_features=1280, bias=False)
        (to_k): Linear(in_features=1280, out_features=1280, bias=False)
        (to_v): Linear(in_features=1280, out_features=1280, bias=False)
        (to_out): ModuleList(
          (0): Linear(in_features=1280, out_features=1280, bias=True)
          (1): Dropout(p=0.0, inplace=False)
        )
      )
      (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (attn2): Attention(
        (to_q): Linear(in_features=1280, out_features=1280, bias=False)
        (to_k): Linear(in_features=768, out_features=1280, bias=False)
        (to_v): Linear(in_features=768, out_features=1280, bias=False)
        (to_out): ModuleList(
          (0): Linear(in_features=1280, out_features=1280, bias=True)
          (1): Dropout(p=0.0, inplace=False)
        )
      )
      (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (ff): FeedForward(
        (net): ModuleList(
          (0): GEGLU(
            (proj): Linear(in_features=1280, out_features=10240, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=5120, out_features=1280, bias=True)
        )
      )
    )
  )
  (proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)

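So it is simply attention wrapped between convolution layers (a GroupNorm and a 1×1 proj_in before, a 1×1 proj_out after). The attention part is BasicTransformerBlock from attention.py, the most basic design: three norms, two attentions and one FFN. Below is its __init__, simplified by filling in the default values: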

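from typing import Optional

from torch import nn
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import SinusoidalPositionalEmbedding


class BasicTransformerBlock(nn.Module):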
    def __init__(
            self,
            dim: int = 320,
            num_attention_heads: int = 8,
            attention_head_dim: int = 40,
            dropout=0.0,
            cross_attention_dim: Optional[int] = None,
            activation_fn: str = "geglu",
            num_embeds_ada_norm: Optional[int] = None,
            attention_bias: bool = False,
            only_cross_attention: bool = False,
            double_self_attention: bool = False,
            upcast_attention: bool = False,
            norm_elementwise_affine: bool = True,
            # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
            norm_type: str = "layer_norm",
            norm_eps: float = 1e-5,
            final_dropout: bool = False,
            attention_type: str = "default",
            positional_embeddings: Optional[str] = None,
            num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(
                dim, max_seq_length=num_positional_embeddings
            )
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(
            dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
        )

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            self.norm2 = nn.LayerNorm(
                dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        # (with norm_type="layer_norm", the original `if not self.use_ada_layer_norm_single:` guard is always true)
        self.norm3 = nn.LayerNorm(
            dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
        )

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

The attention module here is the one bundled with diffusers:

from diffusers.models.attention import (AdaLayerNorm, AdaLayerNormZero,
                                        Attention, FeedForward)

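The FFN uses GEGLU as its activation, which is popular in large models: it adds a gate on top of the GeLU used in GPT, somewhat in the spirit of T5's GLU and the GAU in FLASH: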

import torch
import torch.nn.functional as F
from torch import nn


class GEGLU(nn.Module):
    r"""
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, *args, **kwargs):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)
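Putting the pieces together, the sketch below traces the data flow of Transformer2DModel with one pre-norm BasicTransformerBlock: GroupNorm and a 1×1 conv, flatten to (B, H·W, C) tokens, self-attention, cross-attention to the condition tokens, a GEGLU feed-forward with a 4× inner width (its input projection is doubled for the gate, which is why the printout shows Linear(1280, 10240) followed by Linear(5120, 1280)), then un-flatten, a 1×1 conv and a residual. For brevity, torch.nn.MultiheadAttention is swapped in for diffusers' Attention (its projections carry biases, unlike the bias=False to_q/to_k/to_v above) and dropout is omitted; the class names are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GEGLUFeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        inner = dim * mult
        self.proj = nn.Linear(dim, inner * 2)  # value half and gate half
        self.out = nn.Linear(inner, dim)

    def forward(self, x):
        h, gate = self.proj(x).chunk(2, dim=-1)
        return self.out(h * F.gelu(gate))


class TinyTransformer2D(nn.Module):
    def __init__(self, dim, heads, cond_dim):
        super().__init__()
        self.norm = nn.GroupNorm(32, dim, eps=1e-6)
        self.proj_in = nn.Conv2d(dim, dim, kernel_size=1)
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.attn2 = nn.MultiheadAttention(dim, heads, kdim=cond_dim, vdim=cond_dim, batch_first=True)
        self.ff = GEGLUFeedForward(dim)
        self.proj_out = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x, cond):
        b, c, h, w = x.shape
        residual = x
        x = self.proj_in(self.norm(x))
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)  # (B, C, H, W) -> (B, HW, C) tokens
        n = self.norm1(x)
        x = x + self.attn1(n, n, n, need_weights=False)[0]        # self-attention
        n = self.norm2(x)
        x = x + self.attn2(n, cond, cond, need_weights=False)[0]  # cross-attention to the condition tokens
        x = x + self.ff(self.norm3(x))                             # GEGLU feed-forward
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        return self.proj_out(x) + residual


block = TinyTransformer2D(dim=64, heads=4, cond_dim=32)
out = block(torch.randn(2, 64, 8, 8), torch.randn(2, 4, 32))
print(out.shape)  # torch.Size([2, 64, 8, 8])
```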

Denoising UNet

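Next, let's look at the composition of the core 3D UNet. Overall it matches the reference UNet, with the 2D components replaced by 3D ones and audio_modules and motion_modules added: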

UNet3DConditionModel(
  (conv_in): InflatedConv3d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0-2): 3 x CrossAttnDownBlock3D(
      (attentions): ModuleList[(0-1): 2 x Transformer3DModel]
      (resnets): ModuleList[(0-1): 2 x ResnetBlock3D]
      (audio_modules): ModuleList[(0-1): 2 x Transformer3DModel]
      (motion_modules): ModuleList[(0-1): 2 x VanillaTemporalModule]
      (downsamplers): ModuleList[(0): Downsample3D]
    )
    (3): DownBlock3D(
      (resnets): ModuleList[(0-1): 2 x ResnetBlock3D]
      (motion_modules): ModuleList[(0-1): 2 x VanillaTemporalModule]
    )
  )
  (mid_block): UNetMidBlock3DCrossAttn(
    (attentions): ModuleList[(0): Transformer3DModel]
    (resnets): ModuleList[(0-1): 2 x ResnetBlock3D]
    (audio_modules): ModuleList[(0): Transformer3DModel]
    (motion_modules): ModuleList[(0): VanillaTemporalModule]
  )
  (up_blocks): ModuleList(
    (0): UpBlock3D(
      (resnets): ModuleList[(0-2): 3 x ResnetBlock3D]
      (motion_modules): ModuleList[(0-2): 3 x VanillaTemporalModule]
      (upsamplers): ModuleList[(0): Upsample3D]
    )
    (1-3): 3 x CrossAttnUpBlock3D(
      (attentions): ModuleList[(0-2): 3 x Transformer3DModel]
      (resnets): ModuleList[(0-2): 3 x ResnetBlock3D]
      (audio_modules): ModuleList[(0-2): 3 x Transformer3DModel]
      (motion_modules): ModuleList[(0-2): 3 x VanillaTemporalModule]
      (upsamplers): ModuleList[(0): Upsample3D]
    )
  )
  (conv_norm_out): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
  (conv_act): SiLU()
  (conv_out): InflatedConv3d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

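Downsample3D, ResnetBlock3D and Upsample3D simply replace the regular convolution with InflatedConv3d, which rearranges the frames into the batch dimension and applies a 2D convolution over C, H, W: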

from einops import rearrange
from torch import nn


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
        return x
A question I have here: why not just use a Conv3D instead? Or is the point to keep the convolution from mixing information across frames?
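One way to look at this question: nn.Conv3d already takes input in the (B, C, F, H, W) layout, and a per-frame 2D convolution is exactly a Conv3d whose kernel depth is 1 (with no temporal padding), so the two compute the same thing. Subclassing nn.Conv2d has the practical advantage that pretrained Stable Diffusion Conv2d weights load without reshaping; this is a likely motivation, not one stated by the authors. The check below copies a Conv2d's weights into a depth-1 Conv3d and compares outputs; InflatedConv3d is re-declared with plain reshapes instead of einops so the block is self-contained.

```python
import torch
from torch import nn


class InflatedConv3d(nn.Conv2d):
    # same idea as above: fold frames into the batch and apply the 2D conv per frame
    def forward(self, x):
        b, c, f, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
        x = super().forward(x)
        return x.reshape(b, f, *x.shape[1:]).permute(0, 2, 1, 3, 4)


torch.manual_seed(0)
conv2d = InflatedConv3d(4, 8, kernel_size=3, padding=1)
conv3d = nn.Conv3d(4, 8, kernel_size=(1, 3, 3), padding=(0, 1, 1))
with torch.no_grad():
    conv3d.weight.copy_(conv2d.weight.unsqueeze(2))  # (out, in, 3, 3) -> (out, in, 1, 3, 3)
    conv3d.bias.copy_(conv2d.bias)

x = torch.randn(2, 4, 5, 16, 16)  # (B, C, F, H, W)
print(torch.allclose(conv2d(x), conv3d(x), atol=1e-5))  # True
```

Either way no information flows across frames here; in this architecture that job is left to the motion modules.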

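Beyond that, the denoising UNet has two core components: Transformer3DModel and the temporal module VanillaTemporalModule. Let's first look at the 3D transformer in transformer_3d.py. Unlike the 2D transformer, L89-L131 choose between two attention blocks depending on the input argument use_audio_module: AudioTemporalBasicTransformerBlock and TemporalBasicTransformerBlock. In the structure printed above, the attentions in each cross-attention block use the latter, while the audio_modules use the former. As the name suggests, the former is the part of the paper where the audio interacts with the three masks.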

Turning to attention.py, let's first look at TemporalBasicTransformerBlock. Its main difference from the BasicTransformerBlock introduced above is that L522-L540 additionally define a temporal attention layer:

# Temp-Attn
# assert unet_use_temporal_attention is not None
if unet_use_temporal_attention is None:
    unet_use_temporal_attention = False
if unet_use_temporal_attention:
    self.attn_temp = Attention(
        query_dim=dim,
        heads=num_attention_heads,
        dim_head=attention_head_dim,
        dropout=dropout,
        bias=attention_bias,
        upcast_attention=upcast_attention,
    )
    nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
    self.norm_temp = (
        AdaLayerNorm(dim, num_embeds_ada_norm)
        if self.use_ada_layer_norm
        else nn.LayerNorm(dim)
    )

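In practice, however, this attn_temp is not used. Comparing the attention blocks at the same position in the two UNets shows that their printed structures are identical, for example: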

print(denoising_unet.down_blocks[1].attentions[0].transformer_blocks[0])
print(reference_unet.down_blocks[1].attentions[0].transformer_blocks[0])

Now let's look at AudioTemporalBasicTransformerBlock. First, L691-L701 define the zero-convolution modules used to inject the three kinds of information:

zero_conv_full = nn.Conv2d(dim, dim, kernel_size=1)
self.zero_conv_full = zero_module(zero_conv_full)

zero_conv_face = nn.Conv2d(dim, dim, kernel_size=1)
self.zero_conv_face = zero_module(zero_conv_face)

zero_conv_lip = nn.Conv2d(dim, dim, kernel_size=1)
self.zero_conv_lip = zero_module(zero_conv_lip)
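zero_module follows the ControlNet trick of zero-initializing a layer so that a newly added branch contributes nothing at the start of training. A minimal sketch of such a helper, written here from scratch (Hallo's own helper may differ in details):

```python
import torch
from torch import nn


def zero_module(module: nn.Module) -> nn.Module:
    # zero every parameter (weights and biases) in place
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


zero_conv = zero_module(nn.Conv2d(320, 320, kernel_size=1))
hidden = torch.randn(2, 320, 16, 16)
branch = zero_conv(torch.randn(2, 320, 16, 16))
print(torch.equal(hidden + branch, hidden))  # True: the residual branch starts as a no-op
```

The branch is not stuck at zero: the gradient of a convolution's output with respect to its weights depends on the input, not on the weights, so the zero-initialized weights still receive non-zero gradients and the branch can learn.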

The first attention, the norms and the final FFN are unchanged, while the second (cross) attention is replaced by three attentions:

self.attn2_0 = Attention(
    query_dim=dim,
    cross_attention_dim=cross_attention_dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    dropout=dropout,
    bias=attention_bias,
    upcast_attention=upcast_attention,
)
self.attn2_1 = Attention(
    query_dim=dim,
    cross_attention_dim=cross_attention_dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    dropout=dropout,
    bias=attention_bias,
    upcast_attention=upcast_attention,
)
self.attn2_2 = Attention(
    query_dim=dim,
    cross_attention_dim=cross_attention_dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    dropout=dropout,
    bias=attention_bias,
    upcast_attention=upcast_attention,
)
self.attn2 = None

In forward, the input first passes through the first norm and attention as usual:

norm_hidden_states = self.norm1(hidden_states)
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

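Then, at L846, after the second norm, each branch runs its own attention; the attention output is multiplied by the corresponding mask and then passed through the zero convolution. Taking the full branch as an example: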

level = self.depth  # position of this layer's transformer module; selects the mask at the matching resolution
full_hidden_states = (
        self.attn2_0(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        ) * full_mask[level][:, :, None]
)
# to apply the 2D conv, unflatten the (bz, sz, c) hidden states into a (bz, c, sqrt(sz), sqrt(sz)) map
bz, sz, c = full_hidden_states.shape
sz_sqrt = int(sz ** 0.5)
full_hidden_states = full_hidden_states.reshape(
    bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
full_hidden_states = self.zero_conv_full(full_hidden_states).permute(0, 2, 3, 1).reshape(bz, -1, c)

The three resulting hidden states are then added to the hidden states from the first attention:

hidden_states = (
                    motion_scale[0] * full_hidden_states +
                    motion_scale[1] * face_hidden_state +
                    motion_scale[2] * lip_hidden_state + hidden_states
                )

Here motion_scale holds the weights of the three parts, all 1.0 by default. The paper also shows some experiments with adjusting these weights.
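The three masked branches and their weighted sum can be condensed into the sketch below. The attention outputs are random placeholders, the masks are random soft masks at a single 16×16 level, and the helper name is hypothetical; only the mask broadcasting, the token-to-map reshape around the 1×1 zero conv, and the motion_scale-weighted residual sum mirror the code above.

```python
import torch
from torch import nn


def masked_zero_conv(attn_out, mask, zero_conv):
    """attn_out: (B, S, C) tokens; mask: (B, S) soft mask; returns (B, S, C)."""
    x = attn_out * mask[:, :, None]  # keep the masked region, broadcast over channels
    bz, sz, c = x.shape
    side = int(sz ** 0.5)  # tokens come from a square latent map
    x = x.reshape(bz, side, side, c).permute(0, 3, 1, 2)  # (B, C, side, side) for the 2D conv
    return zero_conv(x).permute(0, 2, 3, 1).reshape(bz, -1, c)


B, S, C = 2, 16 * 16, 64
hidden = torch.randn(B, S, C)
convs = [nn.Conv2d(C, C, kernel_size=1) for _ in range(3)]  # full, face, lip
for conv in convs:
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
masks = [torch.rand(B, S) for _ in range(3)]
attn_outs = [torch.randn(B, S, C) for _ in range(3)]  # placeholders for the attn2_0/1/2 outputs
motion_scale = [1.0, 1.0, 1.0]

branches = [masked_zero_conv(a, m, conv) for a, m, conv in zip(attn_outs, masks, convs)]
out = sum(s * br for s, br in zip(motion_scale, branches)) + hidden
print(out.shape, torch.equal(out, hidden))  # torch.Size([2, 256, 64]) True
```

Because the convs start at zero, the audio block initially passes hidden_states through unchanged, whatever motion_scale is.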

Temporal Modules

Next, let's study the temporal module, in motion_module.py. We use the following module instance to look at how VanillaTemporalModule is composed:

denoising_unet.down_blocks[1].motion_modules[0]

Its structure is:

VanillaTemporalModule(
  (temporal_transformer): TemporalTransformer3DModel(
    (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
    (proj_in): Linear(in_features=640, out_features=640, bias=True)
    (transformer_blocks): ModuleList(
      (0): TemporalTransformerBlock(
        (attention_blocks): ModuleList(
          (0-1): 2 x VersatileAttention(
            (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): ModuleList(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
            (pos_encoder): PositionalEncoding(
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
        )
        (norms): ModuleList(
          (0-1): 2 x LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
        (ff): FeedForward(
          (net): ModuleList(
            (0): GEGLU(
              (proj): Linear(in_features=640, out_features=5120, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=2560, out_features=640, bias=True)
          )
        )
        (ff_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Linear(in_features=640, out_features=640, bias=True)
  )
)

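As we can see, the overall structure is still consistent with the attention blocks introduced above. Let's go directly to the forward of TemporalTransformer3DModel, the wrapper that holds the norm, proj_in, the TemporalTransformerBlocks and proj_out: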

def forward(self, hidden_states, encoder_hidden_states=None):
    assert (
            hidden_states.dim() == 5
    ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
    video_length = hidden_states.shape[2]  # 18
    hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")  # [bf, c, h, w]

    batch, _, height, weight = hidden_states.shape
    residual = hidden_states

    hidden_states = self.norm(hidden_states)
    inner_dim = hidden_states.shape[1]  # channels
    hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
        batch, height * weight, inner_dim
    )  # [bf, h*w, c]
    hidden_states = self.proj_in(hidden_states)

    # Transformer Blocks
    for block in self.transformer_blocks:
        hidden_states = block(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            video_length=video_length,
        )

    # output
    hidden_states = self.proj_out(hidden_states)
    hidden_states = (
        hidden_states.reshape(batch, height, weight, inner_dim)
        .permute(0, 3, 1, 2)
        .contiguous()
    )

    output = hidden_states + residual
    output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)

    return output

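Here hidden_states has shape [B, channel, L, H, W], where $L=18$ is the sum of the $f=16$ generated frames and the 2 motion frames. The encoder_hidden_states here is the face embedding from L294-L298 of face_animate.py; it is passed in only to keep a uniform interface and is not actually used. Next, let's jump straight to the forward of VersatileAttention: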

def forward(
        self,
        hidden_states,  # [bf, h*w, c]
        encoder_hidden_states=None,  # None
        attention_mask=None,  # None
        video_length=None,  # 18
        **cross_attention_kwargs,  # empty
):
    if self.attention_mode == "Temporal":
        d = hidden_states.shape[1]  # d means HxW
        hidden_states = rearrange(
            hidden_states, "(b f) d c -> (b d) f c", f=video_length
        )  # [b*h*w, f, c]
        # add sinusoidal PE
        if self.pos_encoder is not None:
            hidden_states = self.pos_encoder(hidden_states)

        encoder_hidden_states = (
            repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
            if encoder_hidden_states is not None
            else encoder_hidden_states
        )
    else:
        raise NotImplementedError
    hidden_states = self.processor(
        self,
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        **cross_attention_kwargs,
    )
    if self.attention_mode == "Temporal":
        hidden_states = rearrange(
            hidden_states, "(b d) f c -> (b f) d c", d=d)
    return hidden_states

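Here $\boldsymbol x\in \mathbb R^{b\times c\times f\times h\times w}$, which arrives as $(b\times f)\times (h\times w)\times c$, is first reshaped to $(b\times h\times w)\times f\times c$, and a sinusoidal positional encoding is added over the last two dimensions (frame and channel), so each frame index gets its own encoding. Then the processor attribute inherited from the parent class diffusers.models.attention_processor.Attention performs the attention computation.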
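The key to the temporal attention is this reshape: after (b f) d c -> (b d) f c, every spatial location becomes its own sequence of f frames, so attention mixes information only across time at the same position. The sketch below reproduces that with plain torch reshapes, a standard sinusoidal encoding over the frame index (the PositionalEncoding in motion_module.py may differ, e.g. in its maximum length), and torch.nn.MultiheadAttention as an assumed stand-in for VersatileAttention's processor.

```python
import math

import torch
from torch import nn


def sinusoidal_pe(length, dim):
    # standard Transformer encoding: sin on even channels, cos on odd channels
    pos = torch.arange(length, dtype=torch.float32)[:, None]
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    pe = torch.zeros(length, dim)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe  # (length, dim)


def temporal_self_attention(x, attn, video_length):
    """x: (b*f, d, c) with d = h*w spatial tokens; attention runs over frames only."""
    bf, d, c = x.shape
    b = bf // video_length
    # (b f) d c -> (b d) f c
    x = x.reshape(b, video_length, d, c).permute(0, 2, 1, 3).reshape(b * d, video_length, c)
    x = x + sinusoidal_pe(video_length, c)  # same encoding for every spatial location
    x = attn(x, x, x, need_weights=False)[0]
    # (b d) f c -> (b f) d c
    return x.reshape(b, d, video_length, c).permute(0, 2, 1, 3).reshape(bf, d, c)


torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
f, d = 18, 16  # 16 generated frames + 2 motion frames, 4x4 spatial tokens
x = torch.randn(f, d, 64)
y = temporal_self_attention(x, attn, video_length=f)

x2 = x.clone()
x2[:, 0] += 1.0  # perturb spatial location 0 in every frame
y2 = temporal_self_attention(x2, attn, video_length=f)
print(y.shape)                              # torch.Size([18, 16, 64])
print(torch.allclose(y[:, 1:], y2[:, 1:]))  # True: other locations are unaffected
```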

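The processor can be customized when an Attention object is initialized; otherwise it defaults to AttnProcessor(), defined at L729-L798 of diffusers/models/attention_processor.py (when PyTorch 2.x is available, diffusers picks AttnProcessor2_0 by default instead, which performs the same computation with F.scaled_dot_product_attention). AttnProcessor has only a __call__ method: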

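from typing import Optional

import torch
from diffusers.models.attention_processor import Attention


class AttnProcessor: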
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

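Some norm and shape-checking details are omitted here. As we can see, encoder_hidden_states is what $K, V$ are computed from: if a conditioning signal is passed in, this is cross attention; otherwise, as in our VersatileAttention, it is self attention.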
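The head_to_batch_dim / get_attention_scores / bmm sequence is ordinary scaled dot-product attention with the heads folded into the batch dimension. A minimal from-scratch version (no mask, default scale head_dim ** -0.5, which is what diffusers uses unless a custom scale is passed) can be checked against torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+); the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def head_to_batch(x, heads):
    # (B, L, heads*d) -> (B*heads, L, d)
    b, l, hd = x.shape
    return x.reshape(b, l, heads, hd // heads).permute(0, 2, 1, 3).reshape(b * heads, l, hd // heads)


def batch_to_head(x, heads):
    # (B*heads, L, d) -> (B, L, heads*d)
    bh, l, d = x.shape
    return x.reshape(bh // heads, heads, l, d).permute(0, 2, 1, 3).reshape(bh // heads, l, heads * d)


def attention(q, k, v, heads):
    q, k, v = (head_to_batch(t, heads) for t in (q, k, v))
    scale = q.shape[-1] ** -0.5
    probs = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * scale, dim=-1)  # (B*heads, Lq, Lk)
    return batch_to_head(torch.bmm(probs, v), heads)


def split(t, heads, d):
    # (B, L, heads*d) -> (B, heads, L, d), the layout SDPA expects
    return t.reshape(t.shape[0], -1, heads, d).transpose(1, 2)


torch.manual_seed(0)
B, Lq, Lk, heads, d = 2, 10, 4, 8, 40
q = torch.randn(B, Lq, heads * d)
k = torch.randn(B, Lk, heads * d)  # Lk != Lq: K/V come from another sequence, i.e. cross attention
v = torch.randn(B, Lk, heads * d)

ours = attention(q, k, v, heads)
ref = F.scaled_dot_product_attention(split(q, heads, d), split(k, heads, d), split(v, heads, d))
ref = ref.transpose(1, 2).reshape(B, Lq, heads * d)
print(torch.allclose(ours, ref, atol=1e-5))  # True
```

Passing k = v = q turns the same function into self attention, which is exactly the encoder_hidden_states is None case above.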

Reference Attention

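Finally, let's look at the self attention and reference attention of the denoising UNet. The reference UNet's features have to be attended to at the corresponding layers, and the features of the motion frames also have to be stored; with many layers this easily gets messy and requires a lot of verbose plumbing code. Hallo therefore defines an extra helper module that builds the communication and storage directly into the forward methods of the corresponding attention classes. Let's go through it in detail.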

The communication module ReferenceAttentionControl is defined in mutual_self_attention.py and initialized at L300-L313 of face_animate.py:

reference_control_writer = ReferenceAttentionControl(
    self.reference_unet,
    do_classifier_free_guidance=do_classifier_free_guidance,
    mode="write",
    batch_size=batch_size,
    fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
    self.denoising_unet,
    do_classifier_free_guidance=do_classifier_free_guidance,
    mode="read",
    batch_size=batch_size,
    fusion_blocks="full",
)

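At L390-L399, we see that during initialization ReferenceAttentionControl uses __get__ to bind hacked_basic_transformer_inner_forward to each attention block, replacing the original forward of both block types (BasicTransformerBlock and TemporalBasicTransformerBlock):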

for i, module in enumerate(attn_modules):
    module._original_inner_forward = module.forward
    if isinstance(module, BasicTransformerBlock):
        module.forward = hacked_basic_transformer_inner_forward.__get__(
            module,
            BasicTransformerBlock)
    if isinstance(module, TemporalBasicTransformerBlock):
        module.forward = hacked_basic_transformer_inner_forward.__get__(
            module,
            TemporalBasicTransformerBlock)
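The __get__ trick turns a plain function into a method bound to one particular instance, so assigning it to module.forward overrides the class's forward for that module only: nn.Module.__call__ looks up self.forward, and an instance attribute shadows the class method. A toy version, with a hypothetical hacked_forward that records its input in a bank before calling the saved original:

```python
import torch
from torch import nn


def hacked_forward(self, x):
    # "write" behaviour: stash the input, then run the original computation
    self.bank.append(x.detach().clone())
    return self._original_forward(x)


layer = nn.Linear(4, 4)
other = nn.Linear(4, 4)
layer.bank = []
layer._original_forward = layer.forward                   # keep the bound original
layer.forward = hacked_forward.__get__(layer, nn.Linear)  # bind to this instance only

x = torch.randn(2, 4)
y = layer(x)  # nn.Module.__call__ dispatches to the patched forward
print(len(layer.bank), torch.equal(y, layer._original_forward(x)))  # 1 True
print("forward" in vars(other))  # False: other instances keep the class method
```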

First, let's look at what happens in the reference net, i.e. when MODE == "write":

def hacked_basic_transformer_inner_forward(
        ...
):
    norm_hidden_states = self.norm1(hidden_states)
    # 1. Self-Attention
    # self.only_cross_attention = False
    cross_attention_kwargs = (
        cross_attention_kwargs if cross_attention_kwargs is not None else {}
    )
    if MODE == "write":
        # Reference UNet
        self.bank.append(norm_hidden_states.clone())
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=(
                encoder_hidden_states if self.only_cross_attention else None
            ),
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        
    hidden_states = attn_output + hidden_states
    if self.attn2 is not None:
        norm_hidden_states = self.norm2(hidden_states)

        # 2. Cross-Attention
        tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0]
        attn_output = self.attn2(
            norm_hidden_states,
            # TODO: this repeat needs more careful thought
            encoder_hidden_states=encoder_hidden_states.repeat(
                tmp, 1, 1),
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 3. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)
    ff_output = self.ff(norm_hidden_states)
    hidden_states = ff_output + hidden_states
    return hidden_states
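The TODO on the cross-attention repeat may concern ordering. In the reference UNet, norm_hidden_states holds tmp frames per sample, laid out sample-major as "(b s)". The read branch later unpacks the bank in that layout. torch's x.repeat(tmp, 1, 1) tiles the whole batch of face embeddings rather than repeating each sample in place. The NumPy sketch below illustrates the difference. np.tile and np.repeat stand in for torch's repeat and repeat_interleave, and the sizes are hypothetical.

```python
import numpy as np

B, tmp = 2, 3  # hypothetical: 2 samples, 3 reference-UNet frames per sample
L, C = 1, 1    # one token, one channel, to keep the printout readable

# The face embedding of sample i is filled with the value i.
encoder_hidden_states = np.arange(B, dtype=np.float32).reshape(B, L, C)

# Same ordering as torch's encoder_hidden_states.repeat(tmp, 1, 1): tiles the batch.
tiled = np.tile(encoder_hidden_states, (tmp, 1, 1))
# Same ordering as torch's encoder_hidden_states.repeat_interleave(tmp, dim=0).
interleaved = np.repeat(encoder_hidden_states, tmp, axis=0)

print(tiled[:, 0, 0].tolist())        # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print(interleaved[:, 0, 0].tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

Suppose the hidden states are ordered [s0f0, s0f1, s0f2, s1f0, ...]. Then only the interleaved order pairs each frame with its own sample's embedding. With B = 1 the two orders are identical, so single-sample inference is unaffected.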

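So this is ordinary attention. The only difference is that the second attention, the cross attention, takes the face embedding as encoder_hidden_states. At L224 the normalized hidden states are appended to the bank. Once the reference UNet's forward pass finishes, L395 of face_animate.py copies each bank into the corresponding layer of the denoising UNet: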

reference_control_reader.update(reference_control_writer)

def update(self, writer, dtype=torch.float16):
    if self.reference_attn:
        if self.fusion_blocks == "midup":
            pass  # body of the mid/up-block variant omitted in this excerpt
        elif self.fusion_blocks == "full":
            reader_attn_modules = [
                module
                for module in torch_dfs(self.unet)
                if isinstance(module, TemporalBasicTransformerBlock)
            ]
            writer_attn_modules = [
                module
                for module in torch_dfs(writer.unet)
                if isinstance(module, BasicTransformerBlock)
            ]

        assert len(reader_attn_modules) == len(writer_attn_modules)
        reader_attn_modules = sorted(
            reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
        )
        writer_attn_modules = sorted(
            writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
        )
        for r, w in zip(reader_attn_modules, writer_attn_modules):
            r.bank = [v.clone().to(dtype) for v in w.bank]
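The reader (denoising UNet) and writer (reference UNet) are separate networks, and torch_dfs returns each network's blocks in traversal order. update() sorts both lists by descending norm1.normalized_shape[0], the block's channel width. Python's sort is stable, so blocks of equal width keep their traversal order. The zip therefore pairs blocks by width first, then by position, which assumes the two UNets share the same block layout. Below is a minimal sketch with hypothetical FakeBlock objects and a hypothetical pair_by_width helper.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeBlock:
    """Hypothetical stand-in for a transformer block; `width` plays the role of
    norm1.normalized_shape[0] (the block's channel dimension)."""
    name: str
    width: int
    bank: List[str] = field(default_factory=list)


def pair_by_width(readers: List[FakeBlock], writers: List[FakeBlock]):
    """Pair reader/writer blocks the way update() does and copy the writer banks."""
    assert len(readers) == len(writers)
    # sorted() is stable: blocks of equal width keep their traversal (DFS) order.
    readers = sorted(readers, key=lambda m: -m.width)
    writers = sorted(writers, key=lambda m: -m.width)
    for r, w in zip(readers, writers):
        r.bank = list(w.bank)  # Hallo additionally does v.clone().to(dtype)
    return [(r.name, w.name) for r, w in zip(readers, writers)]


# Toy DFS order: down block (320) -> mid block (1280) -> up block (320)
writers = [FakeBlock("w_down", 320, ["down"]), FakeBlock("w_mid", 1280, ["mid"]),
           FakeBlock("w_up", 320, ["up"])]
readers = [FakeBlock("r_down", 320), FakeBlock("r_mid", 1280), FakeBlock("r_up", 320)]
pairs = pair_by_width(readers, writers)
print(pairs)  # [('r_mid', 'w_mid'), ('r_down', 'w_down'), ('r_up', 'w_up')]
```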

  Next, let's look at the denoising UNet side, i.e. when MODE == "read":

def hacked_basic_transformer_inner_forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        video_length=None,
):
    norm_hidden_states = self.norm1(hidden_states)

    # 1. Self-Attention
    # self.only_cross_attention = False
    cross_attention_kwargs = (
        cross_attention_kwargs if cross_attention_kwargs is not None else {}
    )
    if MODE == "read":
        # denoising net
        bank_fea = [
            rearrange(
                rearrange(
                    d,
                    "(b s) l c -> b s l c",
                    b=norm_hidden_states.shape[0] // video_length,
                )[:, 0, :, :]
                # .unsqueeze(1)
                .repeat(1, video_length, 1, 1),
                "b t l c -> (b t) l c",
            )
            for d in self.bank
        ]
        motion_frames_fea = [rearrange(
            d,
            "(b s) l c -> b s l c",
            b=norm_hidden_states.shape[0] // video_length,
        )[:, 1:, :, :] for d in self.bank]
        modify_norm_hidden_states = torch.cat(
            [norm_hidden_states] + bank_fea, dim=1
        )
        hidden_states_uc = (
                self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=modify_norm_hidden_states,
                    attention_mask=attention_mask,
                )
                + hidden_states
        )
        if do_classifier_free_guidance:
            hidden_states_c = hidden_states_uc.clone()
            _uc_mask = uc_mask.clone()
            if hidden_states.shape[0] != _uc_mask.shape[0]:
                _uc_mask = (
                    torch.Tensor(
                        [1] * (hidden_states.shape[0] // 2)
                        + [0] * (hidden_states.shape[0] // 2)
                    )
                    .to(device)
                    .bool()
                )
            hidden_states_c[_uc_mask] = (
                    self.attn1(
                        norm_hidden_states[_uc_mask],
                        encoder_hidden_states=norm_hidden_states[_uc_mask],
                        attention_mask=attention_mask,
                    )
                    + hidden_states[_uc_mask]
            )
            hidden_states = hidden_states_c.clone()
        else:
            hidden_states = hidden_states_uc

        # self.bank.clear()
        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm2(hidden_states)
            )
            hidden_states = (
                    self.attn2(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        attention_mask=attention_mask,
                    )
                    + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(
            hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(
                hidden_states, "(b f) d c -> (b d) f c", f=video_length
            )
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm_temp(hidden_states)
            )
            hidden_states = (
                    self.attn_temp(norm_hidden_states) + hidden_states
            )
            hidden_states = rearrange(
                hidden_states, "(b d) f c -> (b f) d c", d=d
            )

        return hidden_states, motion_frames_fea
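The two list comprehensions at the top of the read branch only rearrange the bank. The NumPy sketch below mirrors them with plain reshapes instead of einops. The split_bank helper is hypothetical, and b is passed in rather than computed as norm_hidden_states.shape[0] // video_length. One detail of torch's semantics matters here. .repeat(1, video_length, 1, 1) applied to the 3-D slice [:, 0, :, :] first prepends a size-1 axis, which np.tile also does.

```python
import numpy as np


def split_bank(d, b, video_length):
    """NumPy mirror of the bank reshaping at the top of the read branch.

    d: (b * s, l, c) features written by the reference UNet; within each sample,
       frame 0 is the reference image and frames 1: are the motion frames.
    """
    s = d.shape[0] // b
    d4 = d.reshape(b, s, *d.shape[1:])            # "(b s) l c -> b s l c"
    ref = d4[:, 0]                                # [:, 0, :, :] -> (b, l, c)
    # Like torch's .repeat(1, video_length, 1, 1) on a 3-D tensor, np.tile prepends
    # a size-1 axis: (b, l, c) -> (1, b * video_length, l, c)
    tiled = np.tile(ref, (1, video_length, 1, 1))
    bank_fea = tiled.reshape(-1, *ref.shape[1:])  # "b t l c -> (b t) l c"
    motion_frames_fea = d4[:, 1:]                 # [:, 1:, :, :] -> (b, s - 1, l, c)
    return bank_fea, motion_frames_fea


# b = 1 sample, s = 3 (reference image + 2 motion frames), l = 3 tokens, c = 2 channels
d = np.arange(3 * 3 * 2, dtype=np.float32).reshape(3, 3, 2)
bank_fea, motion_frames_fea = split_bank(d, b=1, video_length=4)
print(bank_fea.shape, motion_frames_fea.shape)  # (4, 3, 2) (1, 2, 3, 2)
```

If I read the torch semantics correctly, for b > 1 this yields the tiled order [r0, r1, r0, r1, ...]. That order does not line up with the sample-major "(b f)" layout of the denoising hidden states. The commented-out .unsqueeze(1) would give [r0, r0, ..., r1, r1, ...] instead. For b = 1 the two coincide.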

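Hallo's reference attention here is the same spatial attention as in Animate Anyone.

First, the bank features are reshaped. Frame 0 of each reference-UNet sample (the reference image) is repeated for the video_length frames to form bank_fea. Frames 1: (the motion frames) are split off as motion_frames_fea.

Then L253-255 concatenate the hidden states with bank_fea along the token dimension (dim=1). This concatenation serves as K and V, while the hidden states alone serve as Q. When do_classifier_free_guidance is set, the samples selected by uc_mask (presumably the unconditional half of the batch) are recomputed with plain self-attention that ignores the reference features.

After that, the hidden states go through cross attention (CA) with the face embedding and a feed-forward layer, exactly as in the reference net. Finally, if unet_use_temporal_attention is set, temporal attention runs over the video_length frames at each spatial position. motion_frames_fea is returned alongside the hidden states.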
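Below is a stripped-down NumPy sketch of this spatial reference attention. It is single-head, with no output projection and no residual add; Hallo's attn1 is multi-head and adds hidden_states afterwards. The attention helper, the Wq/Wk/Wv matrices and the sizes are all hypothetical. Concatenating along dim=1 doubles the number of keys and values. The number of queries, and hence the output shape, stays the same, so the block's output keeps the shape of its input.

```python
import numpy as np


def attention(q_in, kv_in, Wq, Wk, Wv):
    """Single-head scaled dot-product attention: rows of q_in attend over rows of kv_in."""
    q, k, v = q_in @ Wq, kv_in @ Wk, kv_in @ Wv
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v


rng = np.random.default_rng(0)
frames, l, c = 4, 6, 8                          # (b t), tokens per frame, channels
hidden = rng.standard_normal((frames, l, c))    # normalized denoising hidden states
bank_fea = rng.standard_normal((frames, l, c))  # reference tokens, already repeated per frame
Wq, Wk, Wv = (rng.standard_normal((c, c)) for _ in range(3))

kv = np.concatenate([hidden, bank_fea], axis=1)  # torch.cat([...], dim=1): (frames, 2l, c)
out = attention(hidden, kv, Wq, Wk, Wv)          # Q from hidden; K, V from hidden + reference
print(kv.shape, out.shape)  # (4, 12, 8) (4, 6, 8)
```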

References

[1] 硬核解读Stable Diffusion(完整版) ("A Hardcore Walkthrough of Stable Diffusion, Full Version")
