Preface

  In the previous article we introduced NeRF, but that article covered only the theory; a great many technical details remain worth discussing. In particular, since the author knows rather little about computer graphics, reading the code was far from easy. This article therefore walks through the PyTorch implementation of NeRF line by line, following mainly [2].

Coordinate Transformations

This part draws on [3]-[5].

  Before diving into the NeRF code, we first need to be clear about which coordinate system each quantity lives in. Let us start with NeRF's input. As stated in the paper, the input is five-dimensional:

$$ input=(x,y,z,\theta,\varphi),\;x,y,z\in\mathbb R,\theta,\varphi\in[-\pi,\pi]\tag{1} $$

These five dimensions are all expressed in the world coordinate system, namely the translation of the camera relative to the world frame (the coordinates $(x,y,z)$ of the camera (Cartesian) coordinate origin in the world frame) and the rotation of the camera frame relative to the world frame (the rotation $(\theta,\varphi)$ about the world frame's $x$ and $z$ axes). For a point on a ray $\boldsymbol r(t)=\boldsymbol o+t\boldsymbol d$, however, working directly in world coordinates would be cumbersome, so we need the rigid transformation from the world coordinate system $(X,Y,Z)$ to the camera coordinate system $(X_c,Y_c,Z_c)$. The matrix that performs this transformation is usually called the world2camera matrix, or w2c for short, and is commonly written as:

$$ \boldsymbol T=\left[ \begin{array}{ccc|c} & & & \\ &\boldsymbol R& &\boldsymbol t \\ & & & \\\hline 0 &0&0&1 \end{array} \right]_{4\times 4} $$

That is, $[X_c,Y_c,Z_c,1]^\top=\boldsymbol T\,[X,Y,Z,1]^\top$ in homogeneous coordinates. The w2c matrix is also called the camera extrinsic matrix. Correspondingly, there is a camera2world matrix, c2w for short, which is simply the inverse of w2c; in the c2w matrix, columns 1 to 3 of $\boldsymbol R$ give the directions of the camera frame's $x,y,z$ axes expressed in world coordinates, while $\boldsymbol t$ gives the position of the camera origin in the world frame. This c2w form is exactly what the dataset below stores.
  Having a position in the camera coordinate system, we still need to map it to the (two-dimensional) pixel coordinate system. Let the camera's focal lengths along $x$ and $y$ be $f_x,f_y$ and the principal point be $(p_x,p_y)$; then:

$$ Z_c\begin{bmatrix} u\\ v \\ 1 \end{bmatrix} =\boldsymbol K \begin{bmatrix} X_c\\ Y_c\\ Z_c \end{bmatrix} =\begin{bmatrix} f_x & 0 & p_x\\ 0 &f_y & p_y\\ 0 & 0 & 1 \end{bmatrix}\begin{bmatrix} X_c\\ Y_c\\ Z_c \end{bmatrix} \tag{2} $$

The matrix $\boldsymbol K$ that performs this projective transformation is called the camera intrinsic matrix (the factor $Z_c$ on the left makes the pixel coordinates homogeneous).
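  To make the two transformations concrete, here is a small self-contained example with made-up numbers (none of it appears in the NeRF code, and it uses the plain pinhole convention of Eq. (2) with the camera looking along $+z$; the OpenGL sign flips used by the Blender data are discussed later): it builds a c2w matrix, inverts it to w2c, and projects a world point to pixel coordinates.

import numpy as np

# Toy pose: camera axes aligned with the world axes, camera centre at (1, 0, 0).
# c2w maps camera coordinates to world coordinates.
R = np.eye(3)                       # columns: camera x, y, z axes in world coords
t = np.array([1.0, 0.0, 0.0])       # camera origin in world coords
c2w = np.eye(4)
c2w[:3, :3], c2w[:3, 3] = R, t

# w2c is the inverse rigid transform: rotation R^T, translation -R^T t.
w2c = np.eye(4)
w2c[:3, :3] = R.T
w2c[:3, 3] = -R.T @ t

# A hypothetical intrinsic matrix (focal lengths and principal point in pixels).
K = np.array([[400.0,   0.0, 200.0],
              [  0.0, 400.0, 150.0],
              [  0.0,   0.0,   1.0]])

X_w = np.array([1.5, 0.2, 3.0, 1.0])   # homogeneous world point
X_c = (w2c @ X_w)[:3]                  # world -> camera: (0.5, 0.2, 3.0)
u, v, _ = (K @ X_c) / X_c[2]           # Eq. (2): divide by Z_c to get pixels
print(u, v)                            # ~266.7, ~176.7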

Data

  Now we can start digging into the NeRF implementation, beginning with how the data is read and processed. The Lego scene from the Blender synthetic dataset is used as the running example below.
  In the data/nerf_synthetic/lego folder there are not only the images but also the corresponding camera parameters, transforms_xxx.json. Taking the test set as an example, it has the form:

{
    "camera_angle_x": 0.6911112070083618,
    "frames": [
        {
            "file_path": "./test/r_0",
            "rotation": 0.031415926535897934,
            "transform_matrix": [
                [-0.9999999403953552, 0.0, 0.0, 0.0],
                [0.0, -0.7341099977493286, 0.6790305972099304, 2.737260103225708],
                [0.0, 0.6790306568145752, 0.7341098785400391, 2.959291696548462],
                [0.0, 0.0, 0.0, 1.0]
            ]
        },
        ...
        ]
}

  Specifically, the file contains the horizontal field of view camera_angle_x, the image paths, and the c2w matrices transform_matrix. The data is then loaded by load_blender.py; the images and pose matrices are stored in imgs: [N, H, W, 4] and poses: [N, 4, 4] respectively.
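  The focal length is not stored explicitly; it follows from camera_angle_x and the image width. Roughly, load_blender.py derives it as below (a condensed sketch rather than the exact file contents):

import numpy as np

# camera_angle_x is the full horizontal field of view, so half the image width
# subtends half the angle: tan(camera_angle_x / 2) = (W / 2) / focal
camera_angle_x = 0.6911112070083618   # value from the JSON above
W = 800                               # full-resolution width of the synthetic renders
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
print(focal)                          # ~1111.1 pixels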
  Execution then continues at line 572 of run_nerf.py.

i_train, i_val, i_test = i_split

near = 2.
far = 6.

if args.white_bkgd:
    images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
    images = images[...,:3]

These lines do three things: extract the index lists of the train/val/test splits, set the near and far bounds of the scene, and convert RGBA to RGB (when white_bkgd is set, the alpha channel is composited over a white background). We then move on to line 611, where hwf holds the height, width and focal length, and finally the intrinsic matrix is built (cf. Eq. (2)):

    if K is None:
        K = np.array([
            [focal, 0, 0.5*W],
            [0, focal, 0.5*H],
            [0, 0, 1]
        ])

Note that the principal point is simply taken to be the image centre, $(W/2, H/2)$. With this, data loading and parameter setup are complete.

Model

  Then, at line 640, the NeRF model is created.

render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)

create_nerf first builds the positional encoders:

def get_embedder(multires, i=0):
    if i == -1:
        return nn.Identity(), 3
  
    embed_kwargs = {
                'include_input' : True,
                'input_dims' : 3,
                'max_freq_log2' : multires-1,
                'num_freqs' : multires,
                'log_sampling' : True,
                'periodic_fns' : [torch.sin, torch.cos],
    }
  
    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj : eo.embed(x)
    return embed, embedder_obj.out_dim

embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
input_ch_views = 0
embeddirs_fn = None
if args.use_viewdirs:
    embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)

Here multires and multires_views are the $L$ of the paper's positional encoding: multires is the $L$ applied to $(x,y,z)$, set to 10, while multires_views is the $L$ applied to $(\theta,\varphi)$, set to 4. Note that the raw input $p$ is concatenated with $\gamma(p)$, i.e. every scalar component $p$ becomes $[p,\gamma(p)]$ of dimension $2L+1$ (hence $3\times 21=63$ channels for positions and $3\times 9=27$ for view directions).
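  Functionally, the embedder returned by get_embedder computes something like the following (a minimal sketch; the repo's Embedder class assembles the same sin/cos functions through a list of lambdas):

import torch

def positional_encoding(x, L):
    # x: [..., C] tensor; returns [..., C * (2L + 1)]:
    # the raw input followed by sin/cos at frequencies 2^0 ... 2^(L-1)
    out = [x]
    for k in range(L):
        out.append(torch.sin((2.0 ** k) * x))
        out.append(torch.cos((2.0 ** k) * x))
    return torch.cat(out, dim=-1)

pts = torch.rand(1024, 3)
print(positional_encoding(pts, 10).shape)   # torch.Size([1024, 63])
print(positional_encoding(pts, 4).shape)    # torch.Size([1024, 27])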
  Lines 187-199 then define the coarse and fine models. Let us first look at the NeRF class itself and come back to the parameters afterwards.

class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        """ 
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in
                                        range(D - 1)])

        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])

        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

First the constructor arguments: W=256 is the width of the intermediate (hidden) layers, input_ch=63 is the positional-encoding dimension of $(x,y,z)$, and input_ch_views=27 is the positional-encoding dimension of $(\theta,\varphi)$.
  Then come the linear layers:

[Figure: NeRF model architecture]

  • self.pts_linears: the first 8 layers; skips=[4] means the positional encoding of the input is concatenated back in at the fifth layer.
  • self.views_linears: the tenth layer, whose input dimension is W + input_ch_views = 256 + 27 = 283 and whose output dimension is W // 2 = 128.
  • self.alpha_linear and self.rgb_linear: the output heads that decode the density and the RGB colour, respectively.

  The forward pass that follows is straightforward; a condensed sketch is given below.
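  Roughly, for use_viewdirs=True the forward pass looks like this (a sketch of the repo's NeRF.forward, with F standing for torch.nn.functional): the density depends only on the position branch, while the colour additionally sees the encoded view direction.

def forward(self, x):
    # split the 63-dim position encoding from the 27-dim direction encoding
    input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
    h = input_pts
    for i, layer in enumerate(self.pts_linears):
        h = F.relu(layer(h))
        if i in self.skips:                        # re-inject the encoded position
            h = torch.cat([input_pts, h], -1)

    alpha = self.alpha_linear(h)                   # density: position only
    feature = self.feature_linear(h)
    h = torch.cat([feature, input_views], -1)      # colour also sees the view direction
    for layer in self.views_linears:
        h = F.relu(layer(h))
    rgb = self.rgb_linear(h)
    return torch.cat([rgb, alpha], -1)             # (r, g, b, sigma) per sample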

Rays

  For the lego dataset we do not use ray batching across images, so we jump to line 729 of the main function.

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            target = torch.Tensor(target).to(device)
            pose = poses[img_i, :3,:4]

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 
                            torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
                        ), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")              
                else:
                    coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1,2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

  Note that at this point we are already inside the training loop, yet the rays have not been prepared. A training image is first chosen at random, and then get_rays generates one ray per pixel:

def get_rays(H, W, K, c2w):
    i, j = torch.meshgrid(torch.linspace(0, W - 1, W),
                          torch.linspace(0, H - 1, H))  # pytorch's meshgrid has indexing='ij'
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3],
                       -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d

  First, meshgrid generates the coordinates of every pixel in the pixel coordinate system, and then Eq. (2) is inverted to recover $(X_c,Y_c,Z_c)$ from $(u,v)$ (taking $Z_c=1$):

$$ \begin{bmatrix} X_c \\ Y_c \\ Z_c \end{bmatrix}= \begin{bmatrix} \frac{u-p_x}{f_x} \\ \frac{v-p_y}{f_y} \\ 1 \end{bmatrix}\tag{3} $$

  Note that Blender uses the OpenGL-style camera convention: the $x$ axis points right, $y$ points up, and $z$ points out of the image plane (towards the viewer), so the scene in front of the camera lies along $-z$; moreover, the $j$ coordinate produced by meshgrid increases downwards. Both of these axes therefore have to be negated, which is exactly what the minus signs on the $j$ term and on the $z$ component of dirs do.
  Finally, the c2w matrix rotates these directions into the world frame, and the ray origins are set to the camera centre in world coordinates, i.e. the last column of c2w. Both $\boldsymbol o$ and $\boldsymbol d$ have shape $(H,W,3)$.
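  As a small sanity check (not part of the repo), the broadcast-and-sum used for rays_d is just a per-pixel rotation of dirs by c2w[:3, :3]:

import torch

c2w = torch.rand(3, 4)                                 # any 3x4 pose will do here
dirs = torch.rand(4, 5, 3)                             # a fake (H, W, 3) direction grid
a = torch.sum(dirs[..., None, :] * c2w[:3, :3], -1)    # the expression in get_rays
b = dirs @ c2w[:3, :3].T                               # explicit rotation of each direction
print(torch.allclose(a, b))                            # True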

Testing

  After training we need to run the tests. During data loading we already generated render_poses, the camera poses used to render the video that shows off the results:

def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * np.pi) @ c2w
    c2w = rot_theta(theta / 180. * np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
    return c2w

render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)

  To make the rendered video smooth, the poses are sampled in spherical coordinates $(r,\theta,\varphi)$ around the scene. Here $\varphi\equiv -30^\circ$ and $r\equiv 4$ are fixed, and 40 positions are generated around the horizontal circle: $[-180^\circ,-171^{\circ},\cdots,171^\circ]$.
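  pose_spherical relies on three small helpers defined in load_blender.py. Roughly (a sketch written from memory of the repo rather than copied verbatim), they build a translation along $z$ and rotations about the $x$ and $y$ axes:

import numpy as np
import torch

# translate the camera a distance t along its z axis
trans_t = lambda t: torch.Tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, t],
    [0, 0, 0, 1]]).float()

# rotate about the x axis (elevation phi)
rot_phi = lambda phi: torch.Tensor([
    [1, 0, 0, 0],
    [0, np.cos(phi), -np.sin(phi), 0],
    [0, np.sin(phi), np.cos(phi), 0],
    [0, 0, 0, 1]]).float()

# rotate about the y axis (azimuth theta)
rot_theta = lambda th: torch.Tensor([
    [np.cos(th), 0, -np.sin(th), 0],
    [0, 1, 0, 0],
    [np.sin(th), 0, np.cos(th), 0],
    [0, 0, 0, 1]]).float()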
  The core function for testing and validation is render, defined at line 69:

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: array of shape [3, 3]. Camera intrinsic matrix (contains the focal length).
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 
       camera while using other c2w argument for viewing directions.
    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, K, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()

    sh = rays_d.shape # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()

    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
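  Inside render, the actual volume rendering is delegated to batchify_rays, which simply splits the flattened ray batch into chunks of size chunk so that memory stays bounded, and concatenates the per-chunk outputs of render_rays. Roughly (a trimmed sketch of the repo's function):

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid running out of memory."""
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)   # render one chunk of rays
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])
    # stitch the chunks back together along the ray dimension
    all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
    return all_ret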

References

  1. NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, ECCV 2020.
    Ben Mildenhall, Pratul P. Srinivasan, Matthew Tancik, Jonathan T. Barron, Ravi Ramamoorthi, Ren Ng, 2020.3 | [[ECCV pdf]](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123460392.pdf) [[arXiv pdf]](https://arxiv.org/pdf/2003.08934.pdf) [[Project homepage]](https://www.matthewtancik.com/nerf) [[Official implementation (TF)]](https://github.com/bmild/nerf)
  2. https://github.com/yenchenlin/nerf-pytorch
  3. NeRF源码解读(pytorch实现) (NeRF source-code walkthrough, PyTorch implementation)
  4. 为什么NeRF里可以仅凭位置和角度信息经过MLP预测出某点的rgb颜色? (Why can NeRF predict a point's RGB colour from only position and viewing direction with an MLP?)