前言
我们在上一篇文章中介绍了NeRF:
事实上仅仅只是理论,其中还有非常多的技术细节值得讨论。特别是由于笔者对图形学相关知识知之甚少,在看代码的时候困难颇多。因此本文将逐行代码介绍NeRF的PyTorch实现,主要参考为[2]。
坐标变换
这部分参考了[3]-[5]。
在讨论NeRF的代码之前,我们首先需要明确各个坐标所属的坐标系。首先我们来看NeRF的输入部分。我们知道论文中的输入为五维:
$$ input=(x,y,z,\theta,\varphi),\;x,y,z\in\mathbb R,\theta,\varphi\in[-\pi,\pi]\tag{1} $$
这五维是统一在世界坐标系
下的,也即相机相对于世界坐标系的平移(相机(笛卡尔)坐标系
原点在世界坐标系下的坐标$(x,y,z)$),以及相机坐标系相对于世界坐标系的旋转(相机坐标系原点对于世界坐标系$x,\;z$轴的旋转$(\theta,\varphi)$)。但是对于某一条光线上的点$\boldsymbol r(t)=\boldsymbol o+t\boldsymbol d$,如果仍然选择用世界坐标系下表示则过于复杂,我们需要世界坐标系$(X,Y,Z)$到相机坐标系$(X_c,Y_c,Z_c)$的刚体变换。完成这一变换的矩阵通常称为world2camera
矩阵或简称为w2c
,通常表示为:
$$ \boldsymbol T=\left[ \begin{array}{ccc|c} &&&&\\ &\boldsymbol R&&\boldsymbol t \\ &&&&\\\hline 0 &0&0&1 \end{array} \right]_{4\times 4} $$
也即$[X_c,Y_c,Z_c]=\boldsymbol T[X,Y,Z]$。其中$\boldsymbol R$的1至3列分别代表相机坐标系的$x,y,z$轴的方向,而$t$则代表了原点在世界坐标系中的位置。w2c矩阵也被称为相机外参矩阵。相应地,也有实现camera2world
的变换矩阵,简称为c2w
。
有了在相机坐标系中的位置,我们还需要将其变换成像素坐标系(二维)。设相机在$x,y$方向上的焦距为$f_x,f_y$,主点为$p_x,p_y$,则:
$$ \begin{bmatrix} u\\ v \\ 1 \end{bmatrix} =\boldsymbol K \begin{bmatrix} X_c\\ Y_c\\ Z_c \end{bmatrix} =\begin{bmatrix} f_x & 0 & p_x\\ 0 &f_y & p_y\\ 0 & 0 & 1 \end{bmatrix}\begin{bmatrix} X_c\\ Y_c\\ Z_c \end{bmatrix} \tag{2} $$
称$\boldsymbol K$为投影变换矩阵,也称为相机内参矩阵。
数据
现在我们可以开始研究NeRF的实现了。首先我们来看数据的读取与处理。下文以blender的乐高挖掘机为例。
在data/nerf_synthetic/lego
文件夹下可以发现,不仅有图片,还有对应的内外参:transforms_xxx.json
。以测试集为例,形式为:
{
"camera_angle_x": 0.6911112070083618,
"frames": [
{
"file_path": "./test/r_0",
"rotation": 0.031415926535897934,
"transform_matrix": [
[-0.9999999403953552, 0.0, 0.0, 0.0],
[0.0, -0.7341099977493286, 0.6790305972099304, 2.737260103225708],
[0.0, 0.6790306568145752, 0.7341098785400391, 2.959291696548462],
[0.0, 0.0, 0.0, 1.0]
]
},
...
]
}
具体而言,其包括:相机水平场视角camera_angle_x
、图片位置、以及c2w矩阵transform_matrix
。随后使用load_blender.py
读取数据。图片与pose数据分别保存至imgs: [N*H*W*4]
与poses: [N*4*4]
中。
接下来到run_nerf.py
572行继续执行。
i_train, i_val, i_test = i_split
near = 2.
far = 6.
if args.white_bkgd:
images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
else:
images = images[...,:3]
这里几个操作分别为:取出3类数据集对应的下标列表、设置边界框的远近值以及RGBA转为RGB。而后转到611行,hwf分别为高、宽与焦距。最后计算内参矩阵(参见2式):
if K is None:
K = np.array([
[focal, 0, 0.5*W],
[0, focal, 0.5*H],
[0, 0, 1]
])
显然地,这里是以图片左下角作为像素坐标系的原点了(主点为$(W/2, H/2)$)。至此,数据的加载与参数的设置完成。
模型
而后再在640行创建NeRF模型。
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
create_nerf
首先是进行Positional encoding:
def get_embedder(multires, i=0):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : True,
'input_dims' : 3,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
input_ch_views = 0
embeddirs_fn = None
if args.use_viewdirs:
embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
其中multires, multires_views
便是文章Positional encoding中的$L$。代码中multires
指对$(x,y,z)$的$L$,这里取了10;而multires_views
是指对$(\theta,\varphi)$的$L$,这里取了4。注意这里是将输入$p$与$\gamma(p)$拼接之后作为Positional encoding的结果,也即将$p$变为$[\gamma(p),p]_{2L+1}$。
接下来187-199便是定义了fine与coarse两个模型,我们首先看NeRF类的定义,再结合参数一起看。
class NeRF(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
"""
"""
super(NeRF, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirs = use_viewdirs
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in
range(D - 1)])
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
### Implementation according to the paper
# self.views_linears = nn.ModuleList(
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
if use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W // 2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
首先来看参数与初始化,其中W=256
是intermediate dim;input_ch=63
为$(x,y,z)$的Positional encoding维度;input_ch_views=27
是$(\theta,\varphi)$的Positional encoding维度。
而后是一些线性层:
- self.pts_linears:定义了前8层,其中
skip=[4]
也即第五层进行了Positional encoding的连接。 - self.views_linears:定义了第10层,输入为
W+input_ch_views=256_27
,输出是W/2=128
。 - self.alpha_linear与self.rgb_linear的解码层。
之后的前向过程比较简单,这里就略过了。
光线
对于lego数据集我们并不使用batching,因此转到主函数的729行。
else:
# Random from one image
img_i = np.random.choice(i_train)
target = images[img_i]
target = torch.Tensor(target).to(device)
pose = poses[img_i, :3,:4]
if N_rand is not None:
rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
if i < args.precrop_iters:
dH = int(H//2 * args.precrop_frac)
dW = int(W//2 * args.precrop_frac)
coords = torch.stack(
torch.meshgrid(
torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
), -1)
if i == start:
print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
else:
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
注意此时我们已经进训练的loop了,而我们的数据与光线还没开始准备。首先随机选择一张图片,而后通过get_rays得出光线:
def get_rays(H, W, K, c2w):
i, j = torch.meshgrid(torch.linspace(0, W - 1, W),
torch.linspace(0, H - 1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3],
-1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3, -1].expand(rays_d.shape)
return rays_o, rays_d
首先通过meshgrid
生成像素坐标系中个像素点的坐标,而后通过(2)式通过$(u,v)$求解$(X_c,Y_c,Z_c)$:
$$ \begin{bmatrix} X_c \\ Y_c \\ Z_c \end{bmatrix}= \begin{bmatrix} \frac{u-p_x}{f_x} \\ \frac{v-p_y}{f_y} \\ 1 \end{bmatrix}\tag{3} $$
值得注意的是Blender使用了OpenGL的坐标系风格:$x$轴向右,$y$轴向上,$z$轴(纸面)向外。而这里图片相对相机显然是向里的;而meshgrid生成的$j$坐标也是向下的,因此这两个维度需要反向。
最后再使用c2w
矩阵求解像素在世界坐标系中的位置$(X,Y,Z)$。容易知道此处$\boldsymbol o,\; \boldsymbol d$形状均为$(W,H,3)$。
测试
在训练结束后我们需要进行测试。我们在数据处理时以及得出了供后续测试效果制作视频的相机位姿render_poses
:
def pose_spherical(theta, phi, radius):
c2w = trans_t(radius)
c2w = rot_phi(phi / 180. * np.pi) @ c2w
c2w = rot_theta(theta / 180. * np.pi) @ c2w
c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
return c2w
render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)
为了保证制作的视频的连贯性,我们需要从相机坐标系的极坐标系中取位姿$(r,\theta,\varphi)$。此处固定$\varphi\equiv -30^\circ,\;r\equiv 4$,在水平上生成40个位置:$[-180^\circ,-171^{\circ},\cdots,171^\circ]$。
测试与验证的核心函数是主函数69行定义的render
:
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
"""Render rays
Args:
H: int. Height of image in pixels.
W: int. Width of image in pixels.
focal: float. Focal length of pinhole camera.
chunk: int. Maximum number of rays to process simultaneously. Used to
control maximum memory usage. Does not affect final results.
rays: array of shape [2, batch_size, 3]. Ray origin and direction for
each example in batch.
c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
ndc: bool. If True, represent ray origin, direction in NDC coordinates.
near: float or array of shape [batch_size]. Nearest distance for a ray.
far: float or array of shape [batch_size]. Farthest distance for a ray.
use_viewdirs: bool. If True, use viewing direction of a point in space in model.
c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
camera while using other c2w argument for viewing directions.
Returns:
rgb_map: [batch_size, 3]. Predicted RGB values for rays.
disp_map: [batch_size]. Disparity map. Inverse of depth.
acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
extras: dict with everything returned by render_rays().
"""
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, K, c2w)
else:
# use provided ray batch
rays_o, rays_d = rays
if use_viewdirs:
# provide ray directions as input
viewdirs = rays_d
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs
rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1,3]).float()
sh = rays_d.shape # [..., 3]
if ndc:
# for forward facing scenes
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
# Create ray batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()
near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
# Render and reshape
all_ret = batchify_rays(rays, chunk, **kwargs)
for k in all_ret:
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
k_extract = ['rgb_map', 'disp_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
return ret_list + [ret_dict]
References
- NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, ECCV 2020 -NeRF
Ben Mildenhall, Pratul P. Srinivasan, Matthew Tancik, Jonathan T. Barron, Ravi Ramamoorthi, Ren Ng, 2020.3 | [[ECCV pdf]](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123460392.pdf) [[arXiv pdf]](https://arxiv.org/pdf/2003.08934.pdf) [[Project homepage]](https://www.matthewtancik.com/nerf) [[Official implementation (TF)]](https://github.com/bmild/nerf) - https://github.com/yenchenlin/nerf-pytorch
- NeRF源码解读(pytorch实现)
- 为什么NeRF里可以仅凭位置和角度信息经过MLP预测出某点的rgb颜色?
1 条评论
怎么收藏这篇文章?