Audio-Driven THG (4): A Walkthrough of the Hallo Code
Preface
This post works through the Hallo codebase. Since this is my first time touching Hugging Face's diffusers library, please bear with any inaccuracies in the description. The code is the version updated on 2024-09-14, which I forked here: https://github.com/HJHGJGHHG/hallo/tree/original. The diffusers version is 0.24.0.
Inference Framework
Let us start with inference, in scripts/inference.py. At L157-L164 the input reference image is processed: besides its pixel array (numpy), we also need the face region, the face embedding, and the three masks for lip, expression, and pose (lip/face/full in the code). This uses hallo/datasets/image_processor.py/ImageProcessor. Notably, a Gaussian blur is applied to the masks, which to some extent speeds up training convergence; in addition, each mask is resized to several resolutions to form a multi-scale list (a rough sketch of this construction follows the shape listing below). The shapes are as follows (T denotes Tensor):
with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
    (
        source_image_pixels,       # T(3, 512, 512), normalized
        source_image_face_region,  # T(3, 512, 512), 0/1
        source_image_face_emb,     # ndarray(512)
        source_image_full_mask,    # [T(1,4096), T(1,1024), T(1,256), T(1,64)]
        source_image_face_mask,    # [T(1,4096), T(1,1024), T(1,256), T(1,64)]
        source_image_lip_mask,     # [T(1,4096), T(1,1024), T(1,256), T(1,64)]
    ) = image_processor.preprocess(
        source_image_path, save_path, config.face_expand_ratio)
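For intuition, here is a rough sketch of how such a multi-scale mask list could be built (Gaussian blur followed by resizing to the four latent resolutions 64/32/16/8); the exact kernel size and interpolation used in ImageProcessor may differ:
import cv2
import torch

def build_multiscale_mask(mask_np):  # mask_np: (512, 512) uint8, values in {0, 255}
    blurred = cv2.GaussianBlur(mask_np, (51, 51), 0)   # soften the hard 0/1 boundary
    out = []
    for s in (64, 32, 16, 8):                          # latent resolutions of the UNet levels
        m = cv2.resize(blurred, (s, s), interpolation=cv2.INTER_AREA)
        out.append(torch.from_numpy(m).float().div(255.0).reshape(1, -1))
    return out  # [(1, 4096), (1, 1024), (1, 256), (1, 64)]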
Next, the Wav2Vec features of the audio are extracted, with shapes:
audio_emb, audio_length = ...
# audio_emb:    T(wav2vec length, 12, 768)
# audio_length: int
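As a rough illustration of where the (length, 12, 768) shape could come from, one can stack the hidden states of the 12 transformer layers of a wav2vec2-base model; the checkpoint name below is only an assumption, and the repo's wav2vec.py wrapper additionally handles resampling and aligning the features to the video frame rate, which is omitted here:
import torch
from transformers import Wav2Vec2Model

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()
wav = torch.randn(1, 16000 * 4)  # 4 seconds of 16 kHz audio (toy input)
with torch.no_grad():
    out = model(wav, output_hidden_states=True)
# out.hidden_states is a tuple of 13 tensors (embedding output + 12 layers), each (1, L, 768)
audio_emb = torch.stack(out.hidden_states[1:], dim=2)  # (1, L, 12, 768)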
Then L196-L221 loads the models: the VAE, the reference UNet, the denoising UNet, and the image and audio projection models. We will cover the model architectures later. At L292 we enter the generation loop, whose iteration count is the audio feature length floor-divided by the number of frames generated per segment. The code is as follows:
for t in range(times):
print(f"[{t+1}/{times}]")
if len(tensor_result) == 0:
# The first iteration
motion_zeros = source_image_pixels.repeat(config.data.n_motion_frames, 1, 1, 1)
motion_zeros = motion_zeros.to(dtype=source_image_pixels.dtype, device=source_image_pixels.device)
pixel_values_ref_img = torch.cat(
[source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
else:
motion_frames = tensor_result[-1][0]
motion_frames = motion_frames.permute(1, 0, 2, 3)
motion_frames = motion_frames[0-config.data.n_motion_frames:]
motion_frames = motion_frames * 2.0 - 1.0
motion_frames = motion_frames.to(dtype=source_image_pixels.dtype, device=source_image_pixels.device)
pixel_values_ref_img = torch.cat(
[source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
audio_tensor = audio_emb[
t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
]
audio_tensor = audio_tensor.unsqueeze(0)
audio_tensor = audio_tensor.to(
device=net.audioproj.device, dtype=net.audioproj.dtype)
audio_tensor = net.audioproj(audio_tensor)
pipeline_output = pipeline(
ref_image=pixel_values_ref_img,
audio_tensor=audio_tensor,
face_emb=source_image_face_emb,
face_mask=source_image_face_region,
pixel_values_full_mask=source_image_full_mask,
pixel_values_face_mask=source_image_face_mask,
pixel_values_lip_mask=source_image_lip_mask,
width=img_size[0],
height=img_size[1],
video_length=clip_length,
num_inference_steps=config.inference_steps,
guidance_scale=config.cfg_scale,
generator=generator,
motion_scale=motion_scale,
)
tensor_result.append(pipeline_output.videos)
First, the motion frames needed by the temporal layers are prepared. For the first segment the reference image is simply repeated; the number of motion frames here is 2.
motion_zeros = source_image_pixels.repeat(config.data.n_motion_frames, 1, 1, 1) # T(2, C=3, H=512, W=512)
pixel_values_ref_img = torch.cat([source_image_pixels, motion_zeros], dim=0)
# concat the ref image and the first motion frames
With the reference image and the audio embedding prepared, we reach L323. The resulting pipeline_output has a single attribute containing the output frames: .videos: T(1, C=3, f=16, H, W).
FaceAnimatePipeline
Next we study the generation process itself, in animate/face_animate.py. The parent class is diffusers' core DiffusionPipeline class, which mainly bundles the models, the scheduler, and other components into a high-level API. If you have used Hugging Face's transformers library, it is roughly analogous to PreTrainedModel, and it likewise has a from_pretrained method, for example:
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
The basics of FaceAnimatePipeline are explained fairly well in its docstring, and preparing and decoding the latents is straightforward, so let us go straight to __call__ at L249. First the number of denoising steps is set (40 in practice):
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
Next comes the face embedding:
# prepare clip image embeddings
clip_image_embeds = face_emb # T(1, 512)
clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
encoder_hidden_states = self.image_proj(clip_image_embeds) # T(1, 4, 768)
uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
At L300-L313 two controllers are created that inject the reference features layer by layer; we will come back to them in the model section. Then at L317 the input latents of the denoising net are prepared; the prepare_latents method is simple, and the resulting latents have shape [1, 4, f=16, 64, 64] (a minimal sketch follows below).
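A minimal sketch of what prepare_latents boils down to under the shapes used here; the real method also handles dtype, device, and optionally passed-in latents:
import torch

def prepare_latents_sketch(scheduler, batch_size=1, num_channels=4, video_length=16,
                           height=512, width=512, vae_scale_factor=8, generator=None):
    shape = (batch_size, num_channels, video_length,
             height // vae_scale_factor, width // vae_scale_factor)
    latents = torch.randn(shape, generator=generator)
    # scale the initial noise by the standard deviation required by the scheduler
    return latents * scheduler.init_noise_sigma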
Next, the tensor of the reference image is prepared:
# Prepare ref image latents
ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height)
ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h=64, w=64)
Note that Stable Diffusion rescales the latent after the VAE encoder with a scale factor of 0.18215. The latents coming out of the VAE encoder have a somewhat different distribution, and this factor rescales them to be closer to $\mathcal N(\boldsymbol 0, \boldsymbol I)$; see the LDM paper for details.
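In equations, with $\mathrm{Enc}$/$\mathrm{Dec}$ denoting the VAE encoder/decoder, encoding applies the factor and decoding undoes it:
$$\boldsymbol z = 0.18215 \cdot \mathrm{Enc}(\boldsymbol x), \qquad \hat{\boldsymbol x} = \mathrm{Dec}(\boldsymbol z / 0.18215)$$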
Then we enter the denoising loop:
# denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Forward reference image
if i == 0:
self.reference_unet(
ref_image_latents.repeat(
(2 if do_classifier_free_guidance else 1), 1, 1, 1
),
torch.zeros_like(t),
encoder_hidden_states=encoder_hidden_states,
return_dict=False,
)
reference_control_reader.update(reference_control_writer)
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
noise_pred = self.denoising_unet(
latent_model_input,
t,
encoder_hidden_states=encoder_hidden_states,
mask_cond_fea=face_mask,
full_mask=pixel_values_full_mask,
face_mask=pixel_values_face_mask,
lip_mask=pixel_values_lip_mask,
audio_embedding=audio_tensor,
motion_scale=motion_scale,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
reference_control_reader.clear()
reference_control_writer.clear()
There are essentially three steps: running the reference UNet, predicting the noise with the denoising UNet, and computing the latents of the next step. After the denoising loop, the final latents are decoded into video frames:
# Post-processing
# latents: [1, 4, 16, 64, 64]
images = self.decode_latents(latents) # (1, c=3, f=16, h=512, w=512)
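A minimal sketch of what decode_latents does, assuming the usual diffusers AutoencoderKL conventions (the repo's implementation may differ in details such as batching the decode):
import torch
from einops import rearrange

def decode_latents_sketch(vae, latents):
    # latents: (b, 4, f, h/8, w/8)
    b, _, f, _, _ = latents.shape
    latents = 1 / 0.18215 * latents                      # undo the LDM scaling
    latents = rearrange(latents, "b c f h w -> (b f) c h w")
    frames = vae.decode(latents).sample                  # (b*f, 3, h, w), in [-1, 1]
    frames = (frames / 2 + 0.5).clamp(0, 1)              # back to [0, 1]
    return rearrange(frames, "(b f) c h w -> b c f h w", f=f)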
Models
The models directory is structured as follows:
models
attention.py # attention implementations
audio_proj.py # audio projection module
face_locator.py # face-mask projection module
image_proj.py # face-embedding projection module
motion_module.py # temporal attention modules
mutual_self_attention.py # reference attention
resnet.py # upsample, downsample and ResNet3D components
transformer_2d.py # Transformer2DModel for the reference UNet
transformer_3d.py # Transformer3DModel for the denoising UNet
unet_2d_blocks.py # 2D UNet
unet_2d_condition.py
unet_3d.py # 3D UNet
unet_3d_blocks.py
wav2vec.py
__init__.py
The models we use, and their corresponding classes, are:
- vae:AutoencoderKL
- image_proj:ImageProjModel (image_proj.py)
- face_locator:FaceLocator (face_locator.py)
- audio_proj:AudioProjModel (audio_proj.py)
- reference_unet:UNet2DConditionModel (unet_2d_condition.py)
- denoising_unet:UNet3DConditionModel (unet_3d.py)
The vae is loaded directly with diffusers' from_pretrained, and the several projection modules are fairly simple, so we will not go through them in detail (a plausible sketch of the image projection is given below for reference). We focus on the details of the two UNets.
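For reference, a plausible sketch of the face-embedding projection, mapping the (b, 512) face embedding to the (b, 4, 768) tokens seen in the pipeline above; the actual ImageProjModel in image_proj.py may differ in details such as normalization:
import torch
from torch import nn

class ImageProjSketch(nn.Module):
    def __init__(self, embed_dim=512, cross_attention_dim=768, num_tokens=4):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.proj = nn.Linear(embed_dim, num_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, face_emb):               # (b, 512)
        x = self.proj(face_emb)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(x)                    # (b, 4, 768)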
Reference UNet
The reference UNet has the classic Stable Diffusion UNet structure, consisting of down_blocks -> mid_block -> up_blocks, each created from its class name:
UNet2DConditionModel(
(conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding(
(linear_1): Linear(in_features=320, out_features=1280, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=1280, out_features=1280, bias=True)
)
(down_blocks): ModuleList(
(0-2): 3 x CrossAttnDownBlock2D(
(attentions): ModuleList[(0-1): 2 x Transformer2DModel]
(resnets): ModuleList[(0-1): 2 x ResnetBlock2D]
(downsamplers): ModuleList[(0): Downsample2D]
)
(3): DownBlock2D(
(resnets): ModuleList[(0-1): 2 x ResnetBlock2D]
)
)
(mid_block): UNetMidBlock2DCrossAttn(
(attentions): ModuleList[(0): Transformer2DModel]
(resnets): ModuleList[(0-1): 2 x ResnetBlock2D]
)
(up_blocks): ModuleList(
(0): UpBlock2D(
(resnets): ModuleList[(0-2): 3 x ResnetBlock2D]
(upsamplers): ModuleList[(0): Upsample2D]
)
(1-3): 3 x CrossAttnUpBlock2D(
(attentions): ModuleList[(0-2): 3 x Transformer2DModel]
(resnets): ModuleList[(0-2): 3 x ResnetBlock2D]
(upsamplers): ModuleList[(0): Upsample2D]
)
)
(conv_norm_out): None
(conv_act): SiLU()
)
The diagram below makes this clearer:
Downsample2D, ResnetBlock2D and Upsample2D are standard UNet components imported directly from diffusers; see the library source for their implementations:
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
So the core of the reference UNet is the Transformer2DModel, which we now examine in transformer_2d.py. Taking reference_unet.mid_block.attentions[0] as an example, the overall structure is:
reference_unet.mid_block.attentions[0]: Transformer2DModel(
(norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
(proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=1280, out_features=1280, bias=False)
(to_v): Linear(in_features=1280, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=1280, out_features=1280, bias=False)
(to_k): Linear(in_features=768, out_features=1280, bias=False)
(to_v): Linear(in_features=768, out_features=1280, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=1280, out_features=1280, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): Linear(in_features=1280, out_features=10240, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=5120, out_features=1280, bias=True)
)
)
)
)
(proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)
This is essentially an attention stack sandwiched between two 1x1 convolutional projections. It uses the BasicTransformerBlock from attention.py: the most basic combination of three norms, two attentions, and one FFN. Below is its __init__, simplified with default values:
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int = 320,
num_attention_heads: int = 8,
attention_head_dim: int = 40,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
# 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_type: str = "layer_norm",
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(
dim, max_seq_length=num_positional_embeddings
)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.norm2 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=(
cross_attention_dim if not double_self_attention else None
),
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
The attention modules here are the ones shipped with diffusers:
from diffusers.models.attention import (AdaLayerNorm, AdaLayerNormZero,
Attention, FeedForward)
The activation in the FFN is swapped for GEGLU, which has become popular in large models: it adds gating on top of the GELU used in GPT, in a spirit similar to the GLU variants in T5 and the GAU in FLASH:
class GEGLU(nn.Module):
r"""
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states, *args, **kwargs):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
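Written out, with both projections coming from the single proj layer that maps to dim_out * 2 and is then chunked in half:
$$\mathrm{GEGLU}(\boldsymbol x) = (\boldsymbol x\boldsymbol W_1 + \boldsymbol b_1) \odot \mathrm{GELU}(\boldsymbol x\boldsymbol W_2 + \boldsymbol b_2)$$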
Denoising UNet
Next we look at the composition of the core 3D UNet. It largely mirrors the reference UNet, with the 2D blocks replaced by 3D ones:
UNet3DConditionModel(
(conv_in): InflatedConv3d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding(
(linear_1): Linear(in_features=320, out_features=1280, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=1280, out_features=1280, bias=True)
)
(down_blocks): ModuleList(
(0-2): 3 x CrossAttnDownBlock3D(
(attentions): ModuleList[(0-1): 2 x Transformer3DModel]
(resnets): ModuleList[(0-1): 2 x ResnetBlock3D]
(audio_modules): ModuleList[(0-1): 2 x Transformer3DModel]
(motion_modules): ModuleList[(0-1): 2 x VanillaTemporalModule]
(downsamplers): ModuleList[(0): Downsample3D]
)
(3): DownBlock3D(
(resnets): ModuleList[(0-1): 2 x ResnetBlock3D]
(motion_modules): ModuleList[(0-1): 2 x VanillaTemporalModule]
)
)
(mid_block): UNetMidBlock3DCrossAttn(
(attentions): ModuleList[(0): Transformer3DModel]
(resnets): ModuleList[(0-1): 2 x ResnetBlock3D]
(audio_modules): ModuleList[(0): Transformer3DModel]
(motion_modules): ModuleList[(0): VanillaTemporalModule]
)
(up_blocks): ModuleList(
(0): UpBlock3D(
(resnets): ModuleList[(0-2): 3 x ResnetBlock3D]
(motion_modules): ModuleList[(0-2): 3 x VanillaTemporalModule]
(upsamplers): ModuleList[(0): Upsample3D]
)
(1-3): 3 x CrossAttnUpBlock3D(
(attentions): ModuleList[(0-2): 3 x Transformer3DModel]
(resnets): ModuleList[(0-2): 3 x ResnetBlock3D]
(audio_modules): ModuleList[(0-2): 3 x Transformer3DModel]
(motion_modules): ModuleList[(0-2): 3 x VanillaTemporalModule]
(upsamplers): ModuleList[(0): Upsample3D]
)
)
(conv_norm_out): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
(conv_act): SiLU()
(conv_out): InflatedConv3d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
Downsample3D, ResnetBlock3D and Upsample3D simply replace the regular convolutions with InflatedConv3d, which rearranges the tensor and applies a 2D convolution over C, H, W:
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
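A quick shape check with toy tensors (assuming the class above is in scope):
import torch

conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)
x = torch.randn(1, 4, 16, 64, 64)   # (b, c, f, h, w)
y = conv(x)                         # (1, 320, 16, 64, 64): each frame is convolved independently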
One question I have here: why not just apply a Conv3d to the full $B\times C\times F\times H\times W$ tensor instead? Or is the point precisely to keep the convolutions from introducing cross-frame interactions?
Beyond that, the denoising UNet has two core components: Transformer3DModel and the temporal module VanillaTemporalModule. Let us first look at the 3D transformer in transformer_3d.py. Note that, unlike the 2D transformer, L89-L131 chooses between two different attention blocks depending on the input argument use_audio_module: AudioTemporalBasicTransformerBlock and TemporalBasicTransformerBlock. In the structure printed above, the attentions inside each cross-attention block use the latter, while the audio_modules use the former. As the name suggests, the former implements the interaction between audio and the three masks described in the paper.
Turning to attention.py, we first look at TemporalBasicTransformerBlock. Its biggest difference from the BasicTransformerBlock introduced above is an extra temporal attention layer defined at L522-L540:
# Temp-Attn
# assert unet_use_temporal_attention is not None
if unet_use_temporal_attention is None:
unet_use_temporal_attention = False
if unet_use_temporal_attention:
self.attn_temp = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
In practice, however, this attn_temp is never used: comparing the attention modules at the same positions in the two UNets shows they are identical, for example:
print(denoising_unet.down_blocks[1].attentions[0].transformer_blocks[0])
print(reference_unet.down_blocks[1].attentions[0].transformer_blocks[0])
Now let us look at AudioTemporalBasicTransformerBlock. First, L691-L701 defines the zero-convolution modules that inject the three kinds of information:
zero_conv_full = nn.Conv2d(dim, dim, kernel_size=1)
self.zero_conv_full = zero_module(zero_conv_full)
zero_conv_face = nn.Conv2d(dim, dim, kernel_size=1)
self.zero_conv_face = zero_module(zero_conv_face)
zero_conv_lip = nn.Conv2d(dim, dim, kernel_size=1)
self.zero_conv_lip = zero_module(zero_conv_lip)
The first attention, the norms, and the final FFN are unchanged, while the second cross attention is replaced with three attentions:
self.attn2_0 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.attn2_1 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.attn2_2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.attn2 = None
In forward, the hidden states first go through the first norm and attention as usual:
norm_hidden_states = self.norm1(hidden_states)
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
Then, at L846, after the second norm each branch runs its own attention, multiplies the result by the corresponding mask, and passes it through its zero convolution. Taking the full mask as an example:
level = self.depth # position of this block within the transformer stack, used to pick the mask at the matching scale
full_hidden_states = (
self.attn2_0(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
) * full_mask[level][:, :, None]
)
# reshape the (bz, sz, c) hidden states into a 2D grid so the Conv2d zero convolution can be applied
bz, sz, c = full_hidden_states.shape
sz_sqrt = int(sz ** 0.5)
full_hidden_states = full_hidden_states.reshape(
bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
full_hidden_states = self.zero_conv_full(full_hidden_states).permute(0, 2, 3, 1).reshape(bz, -1, c)
The three resulting hidden states are then added to the hidden states produced by the first attention:
hidden_states = (
motion_scale[0] * full_hidden_states +
motion_scale[1] * face_hidden_state +
motion_scale[2] * lip_hidden_state + hidden_states
)
Here motion_scale holds the weights of the three parts, all 1.0 by default. The paper also reports experiments on tuning these weights.
Temporal Modules
Next we study the temporal modules. In motion_module.py, let us look at the composition of VanillaTemporalModule through the following model instance:
denoising_unet.down_blocks[1].motion_modules[0]
Its structure is:
VanillaTemporalModule(
(temporal_transformer): TemporalTransformer3DModel(
(norm): GroupNorm(32, 640, eps=1e-06, affine=True)
(proj_in): Linear(in_features=640, out_features=640, bias=True)
(transformer_blocks): ModuleList(
(0): TemporalTransformerBlock(
(attention_blocks): ModuleList(
(0-1): 2 x VersatileAttention(
(Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
(to_q): Linear(in_features=640, out_features=640, bias=False)
(to_k): Linear(in_features=640, out_features=640, bias=False)
(to_v): Linear(in_features=640, out_features=640, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=640, out_features=640, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
(pos_encoder): PositionalEncoding(
(dropout): Dropout(p=0.0, inplace=False)
)
)
)
(norms): ModuleList(
(0-1): 2 x LayerNorm((640,), eps=1e-05, elementwise_affine=True)
)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): Linear(in_features=640, out_features=5120, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=2560, out_features=640, bias=True)
)
)
(ff_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
)
)
(proj_out): Linear(in_features=640, out_features=640, bias=True)
)
)
The overall structure is consistent with the attention blocks introduced above. Let us go straight to the forward of TemporalTransformer3DModel (the module that wraps the transformer_blocks):
def forward(self, hidden_states, encoder_hidden_states=None):
assert (
hidden_states.dim() == 5
), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2] # 18
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") # [bf, c, h, w]
batch, _, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1] # channels
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
batch, height * weight, inner_dim
) # [bf, h*w, c]
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
video_length=video_length,
)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim)
.permute(0, 3, 1, 2)
.contiguous()
)
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
Here hidden_states has shape [B, channel, L, H, W], where $L=18$ is the sum of $f=16$ and the 2 motion frames. The encoder_hidden_states is the face embedding from face_animate.py L294-L298; it is only passed through to keep the interface uniform and is not actually used here. Next we go straight to the forward of VersatileAttention:
def forward(
self,
hidden_states, # [bf, h*w, c]
encoder_hidden_states=None, # None
attention_mask=None, # None
video_length=None, # 18
**cross_attention_kwargs, # empty
):
if self.attention_mode == "Temporal":
d = hidden_states.shape[1] # d means HxW
hidden_states = rearrange(
hidden_states, "(b f) d c -> (b d) f c", f=video_length
) # [b*h*w, f, c]
# add sinusoidal PE
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = (
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
if encoder_hidden_states is not None
else encoder_hidden_states
)
else:
raise NotImplementedError
hidden_states = self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.attention_mode == "Temporal":
hidden_states = rearrange(
hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
This first reshapes $\boldsymbol x\in \mathbb R^{b\times c\times f\times h\times w}$ to $(b\times h\times w)\times f\times c$, then adds a sinusoidal positional encoding along the frame dimension (a standard sketch of such an encoding is shown below), and finally calls the processor attribute of the parent class diffusers.models.attention_processor.Attention to compute the attention.
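The PositionalEncoding in motion_module.py follows the classic sinusoidal recipe; a minimal sketch (the max_len and dropout defaults here are assumptions) looks like this:
import math
import torch
from torch import nn

class SinusoidalPE(nn.Module):
    def __init__(self, d_model, dropout=0.0, max_len=32):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)                    # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):                      # x: (b*h*w, f, c)
        x = x + self.pe[:, : x.shape[1]]
        return self.dropout(x)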
The processor of the Attention class can be customized at initialization; otherwise it defaults to the AttnProcessor() defined at L729-L798 of diffusers/models/attention_processor.py. It has a single __call__ method:
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
Some norm and shape-handling details are omitted here. As we can see, encoder_hidden_states provides the $K, V$ of the attention computation: if a conditioning signal is passed in, this is cross attention; otherwise, as in our VersatileAttention, it is self attention.
Reference Attention
Finally, let us look at the self attention and reference attention of the denoising UNet. Since the reference UNet's features must be cross-attended at the corresponding layers, and the features serving as motion frames must also be cached, things quickly get messy as the number of layers grows, with lots of verbose plumbing code. Hallo therefore defines an extra controller that writes the communication and caching directly into the forward method of the relevant attention classes. Let us study it in detail.
The communication model ReferenceAttentionControl is defined in mutual_self_attention.py and initialized at L300-L313 of face_animate.py:
reference_control_writer = ReferenceAttentionControl(
self.reference_unet,
do_classifier_free_guidance=do_classifier_free_guidance,
mode="write",
batch_size=batch_size,
fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
self.denoising_unet,
do_classifier_free_guidance=do_classifier_free_guidance,
mode="read",
batch_size=batch_size,
fusion_blocks="full",
)
At L390-L399 we find that, during initialization, ReferenceAttentionControl uses __get__ to replace the original forward methods of the two attention layer types with hacked_basic_transformer_inner_forward:
for i, module in enumerate(attn_modules):
module._original_inner_forward = module.forward
if isinstance(module, BasicTransformerBlock):
module.forward = hacked_basic_transformer_inner_forward.__get__(
module,
BasicTransformerBlock)
if isinstance(module, TemporalBasicTransformerBlock):
module.forward = hacked_basic_transformer_inner_forward.__get__(
module,
TemporalBasicTransformerBlock)
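As a toy illustration of this __get__ trick (a hypothetical module, not from the repo): binding a plain function to an instance produces a bound method that replaces forward on that instance only:
import torch
from torch import nn

def hacked_forward(self, x):
    print("intercepted forward on", type(self).__name__)
    return self._original_forward(x)

layer = nn.Linear(4, 4)
layer._original_forward = layer.forward
layer.forward = hacked_forward.__get__(layer, nn.Linear)  # bound to this instance only

_ = layer(torch.randn(2, 4))  # prints "intercepted forward on Linear"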
Let us first look at what happens in the reference net, i.e. when MODE == "write":
def hacked_basic_transformer_inner_forward(
...
):
norm_hidden_states = self.norm1(hidden_states)
# 1. Self-Attention
# self.only_cross_attention = False
cross_attention_kwargs = (
cross_attention_kwargs if cross_attention_kwargs is not None else {}
)
if MODE == "write":
# Reference UNet
self.bank.append(norm_hidden_states.clone())
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=(
encoder_hidden_states if self.only_cross_attention else None
),
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
# 2. Cross-Attention
tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0]
attn_output = self.attn2(
norm_hidden_states,
# TODO: this repeat deserves a closer look
encoder_hidden_states=encoder_hidden_states.repeat(
tmp, 1, 1),
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
This is just a regular attention block, except that the second (cross) attention uses the face embedding. At L224 the normalized hidden states are appended to the bank. After the reference UNet's forward finishes, L395 of face_animate.py hands the bank over to the corresponding layers of the denoising UNet:
reference_control_reader.update(reference_control_writer)
def update(self, writer, dtype=torch.float16):
if self.reference_attn:
if self.fusion_blocks == "midup":
pass
elif self.fusion_blocks == "full":
reader_attn_modules = [
module
for module in torch_dfs(self.unet)
if isinstance(module, TemporalBasicTransformerBlock)
]
writer_attn_modules = [
module
for module in torch_dfs(writer.unet)
if isinstance(module, BasicTransformerBlock)
]
assert len(reader_attn_modules) == len(writer_attn_modules)
reader_attn_modules = sorted(
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
writer_attn_modules = sorted(
writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
for r, w in zip(reader_attn_modules, writer_attn_modules):
r.bank = [v.clone().to(dtype) for v in w.bank]
Now let us look at the denoising UNet side:
def hacked_basic_transformer_inner_forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
video_length=None,
):
norm_hidden_states = self.norm1(hidden_states)
# 1. Self-Attention
# self.only_cross_attention = False
cross_attention_kwargs = (
cross_attention_kwargs if cross_attention_kwargs is not None else {}
)
if MODE == "read":
# denoising net
bank_fea = [
rearrange(
rearrange(
d,
"(b s) l c -> b s l c",
b=norm_hidden_states.shape[0] // video_length,
)[:, 0, :, :]
# .unsqueeze(1)
.repeat(1, video_length, 1, 1),
"b t l c -> (b t) l c",
)
for d in self.bank
]
motion_frames_fea = [rearrange(
d,
"(b s) l c -> b s l c",
b=norm_hidden_states.shape[0] // video_length,
)[:, 1:, :, :] for d in self.bank]
modify_norm_hidden_states = torch.cat(
[norm_hidden_states] + bank_fea, dim=1
)
hidden_states_uc = (
self.attn1(
norm_hidden_states,
encoder_hidden_states=modify_norm_hidden_states,
attention_mask=attention_mask,
)
+ hidden_states
)
if do_classifier_free_guidance:
hidden_states_c = hidden_states_uc.clone()
_uc_mask = uc_mask.clone()
if hidden_states.shape[0] != _uc_mask.shape[0]:
_uc_mask = (
torch.Tensor(
[1] * (hidden_states.shape[0] // 2)
+ [0] * (hidden_states.shape[0] // 2)
)
.to(device)
.bool()
)
hidden_states_c[_uc_mask] = (
self.attn1(
norm_hidden_states[_uc_mask],
encoder_hidden_states=norm_hidden_states[_uc_mask],
attention_mask=attention_mask,
)
+ hidden_states[_uc_mask]
)
hidden_states = hidden_states_c.clone()
else:
hidden_states = hidden_states_uc
# self.bank.clear()
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(
hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(
hidden_states, "(b f) d c -> (b d) f c", f=video_length
)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm_temp(hidden_states)
)
hidden_states = (
self.attn_temp(norm_hidden_states) + hidden_states
)
hidden_states = rearrange(
hidden_states, "(b d) f c -> (b f) d c", d=d
)
return hidden_states, motion_frames_fea
Hallo's reference attention here is the same spatial attention as in Animate Anyone: the features stored in the bank are first reshaped, then at L253-L255 the hidden states are concatenated with bank_fea and used as $K, V$ for cross attention with the hidden states. Afterwards the hidden states go through cross attention with the face embedding, exactly as in the reference net.
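Compactly, with $\boldsymbol h$ the denoising UNet's spatial tokens and $\boldsymbol h_{\mathrm{ref}}$ the corresponding bank features, the first attention becomes
$$\mathrm{Attn}(\boldsymbol Q, \boldsymbol K, \boldsymbol V),\quad \boldsymbol Q=\boldsymbol h\boldsymbol W_Q,\;\; \boldsymbol K=[\boldsymbol h;\boldsymbol h_{\mathrm{ref}}]\boldsymbol W_K,\;\; \boldsymbol V=[\boldsymbol h;\boldsymbol h_{\mathrm{ref}}]\boldsymbol W_V,$$
where $[\cdot\,;\cdot]$ denotes concatenation along the token (spatial) dimension.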