Understanding Neural Radiance Fields (NeRFs)

Imagine being able to generate photorealistic 3D models of objects and scenes that can be viewed from any angle, with details so realistic that they are nearly indistinguishable from reality. That is what Neural Radiance Fields (NeRFs) can do, and much more. With more than 50 NeRF-related papers at CVPR 2022 alone, the original NeRF paper has become one of the most influential works in recent computer vision.


Neural fields

A neural field is a neural network that parametrizes a signal, i.e., it maps coordinates to signal values. In our case, this signal is a single 3D scene or object. Note that one network is trained per scene: it encodes (captures) exactly that scene and nothing else. Unlike standard machine learning, where we want to generalize to unseen data, the objective here is to overfit the network to one particular scene. Essentially, neural fields embed the scene into the weights of the network.

In computer graphics, 3D scenes are usually stored as voxel grids or polygon meshes. However, voxel grids are costly to store, and polygon meshes are limited to hard surfaces. Neural fields are gaining popularity because they are compact representations of objects or scenes that are differentiable and continuous, can represent signals of arbitrary dimension, and can be queried at arbitrary resolution. Neural radiance fields are a special case of neural fields designed to solve the novel view synthesis problem.
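To make the idea concrete, here is a minimal NumPy sketch of a neural field: a tiny coordinate MLP. The names `make_field` and `grid` and the layer sizes are illustrative assumptions, and the weights are random rather than fit to a real scene. The point is that the same weights can be queried at any continuous coordinate, so sampling the field on a coarse or a fine grid needs no extra storage.

```python
import numpy as np

def make_field(hidden=64, seed=0):
    """A tiny coordinate MLP f: R^3 -> R with random, untrained weights."""
    rng = np.random.default_rng(seed)
    W1 = rng.normal(0.0, 1.0, size=(3, hidden))
    b1 = rng.normal(0.0, 1.0, size=hidden)
    W2 = rng.normal(0.0, 1.0 / np.sqrt(hidden), size=(hidden, 1))
    b2 = np.zeros(1)

    def field(xyz):
        """Evaluate the field at points of shape (N, 3); returns shape (N,)."""
        h = np.tanh(xyz @ W1 + b1)          # smooth activation: the field varies smoothly in space
        return (h @ W2 + b2)[..., 0]

    n_params = W1.size + b1.size + W2.size + b2.size
    return field, n_params

def grid(res):
    """All points of a res x res x res grid covering [-1, 1]^3, shape (res**3, 3)."""
    axis = np.linspace(-1.0, 1.0, res)
    return np.stack(np.meshgrid(axis, axis, axis, indexing="ij"), axis=-1).reshape(-1, 3)

field, n_params = make_field()
coarse = field(grid(8))     # 512 values
fine = field(grid(32))      # 32,768 values from the very same weights
```

A dense $32^3$ voxel grid stores 32,768 values, while this field has 321 parameters no matter how finely it is queried. A real scene needs a much larger network that has actually been trained on it.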



Neural Radiance Fields (NeRFs)

NeRFs, as proposed by Mildenhall et al., take a single continuous 5D coordinate as input, consisting of a spatial location $(x, y, z)$ and a viewing direction $(\theta, \phi)$. This 5D coordinate is fed into an MLP, which outputs the emitted color $c = (r, g, b)$ and the volume density $\sigma$ at that point.
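In practice, the paper expresses the viewing direction as a 3D Cartesian unit vector $\mathbf{d}$ rather than as two angles. The sketch below shows that conversion. It assumes $\theta$ is measured from the $+z$ axis and $\phi$ in the $xy$-plane from the $+x$ axis (a convention chosen here for illustration), and `direction_from_angles` is a hypothetical helper name.

```python
import numpy as np

def direction_from_angles(theta, phi):
    """Unit viewing direction for inclination theta (from +z) and azimuth phi (from +x)."""
    return np.array([
        np.sin(theta) * np.cos(phi),
        np.sin(theta) * np.sin(phi),
        np.cos(theta),
    ])

# One 5D query: a 3D position plus a direction given as two angles
x = np.array([0.1, -0.2, 0.3])
d = direction_from_angles(theta=np.pi / 2, phi=0.0)   # horizontal direction along +x
query = np.concatenate([x, d])                         # 6 numbers, but still 5 degrees of freedom
```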



The network's weights are optimized to encode the scene so that the model can render novel views of it from arbitrary viewpoints.



Ray Marching

To get a better grasp of the different stages of NeRF training, let's walk through them using a single example 3D scene.

The training dataset consists of images of size $H \times W$ captured from $n$ different viewpoints of the same scene. For each image, the camera position is stored as $(x_c, y_c, z_c)$ in world space. Since the camera is always "aimed" at the object, only two more rotation parameters are needed to fully describe its pose (assuming the camera does not roll): the inclination and azimuth angles $(\theta, \phi)$.
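When the camera is aimed at the object, its full orientation follows from its position once we fix an "up" direction. The sketch below builds a 4×4 camera-to-world matrix this way. It assumes the object sits at the world origin, that world "up" is $+z$, and that the camera looks along its own $-z$ axis (the OpenGL-style convention commonly used with the NeRF synthetic data). `look_at` is a hypothetical helper, not part of any NeRF codebase.

```python
import numpy as np

def look_at(cam_pos, target=(0.0, 0.0, 0.0), world_up=(0.0, 0.0, 1.0)):
    """Camera-to-world matrix for a camera at cam_pos looking at target (no roll).

    Columns of the 3x3 block are the camera's right, up and backward axes in world
    coordinates, so the camera looks along its local -z axis. The result is undefined
    (NaNs) if the viewing direction is parallel to world_up.
    """
    cam_pos = np.asarray(cam_pos, dtype=float)
    forward = np.asarray(target, dtype=float) - cam_pos
    forward /= np.linalg.norm(forward)
    right = np.cross(forward, world_up)
    right /= np.linalg.norm(right)
    up = np.cross(right, forward)

    c2w = np.eye(4)
    c2w[:3, 0] = right
    c2w[:3, 1] = up
    c2w[:3, 2] = -forward        # local -z points at the target
    c2w[:3, 3] = cam_pos         # translation = camera position
    return c2w

c2w = look_at([4.0, 0.0, 1.0])
```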

Now, for each camera pose, we "shoot" a ray from the camera (or viewer's eye), through every pixel of the image, resulting in $H \times W$ rays per pose. Each ray is described by two vectors,

  • $\mathbf{o}$ : a vector denoting the origin of the ray.
  • $\mathbf{d}$ : a normalized vector denoting the direction of the ray.

An arbitrary point on the ray can then be defined as $r(t) = \mathbf{o} + t * \mathbf{d}$.

This process of ray tracing is known as backward tracing, as it involves tracing the path of light rays from the camera to the object, as opposed to tracing from the light source to the object.

Input : A set of camera poses $(x_c, \, y_c, \, z_c, \, \theta, \, \phi)_n$
Output : A bundle of rays for every pose $(r_{\mathbf{o}, \mathbf{d}})_{H \times W \times n}$
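The ray-generation step can be sketched with a pinhole camera model. This version assumes a focal length `focal` given in pixels, a 4×4 camera-to-world matrix `c2w` built from the camera pose described above, and a camera that looks along its local $-z$ axis with $+y$ up. It shoots one ray through each pixel center. `get_rays` is a hypothetical name used for illustration.

```python
import numpy as np

def get_rays(H, W, focal, c2w):
    """Origins and unit directions of the H*W rays of one camera, each of shape (H, W, 3)."""
    # Pixel-center coordinates: i runs along columns (x), j along rows (y)
    i, j = np.meshgrid(np.arange(W) + 0.5, np.arange(H) + 0.5, indexing="xy")
    # Directions in camera space: image rows grow downward, camera +y points up
    dirs = np.stack([(i - 0.5 * W) / focal,
                     -(j - 0.5 * H) / focal,
                     -np.ones_like(i)], axis=-1)
    # Rotate into world space and normalize so that d is a unit vector
    rays_d = dirs @ c2w[:3, :3].T
    rays_d /= np.linalg.norm(rays_d, axis=-1, keepdims=True)
    # All rays of one camera share the same origin: the camera center
    rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)
    return rays_o, rays_d

rays_o, rays_d = get_rays(H=3, W=4, focal=2.0, c2w=np.eye(4))
```

Stacking the output for all $n$ poses gives the $H \times W \times n$ bundle of rays, and a point on a ray is `rays_o + t * rays_d`.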



Sampling Query points

You may wonder what we do with these rays. We march along each ray from the camera into the scene by increasing the parameter $t$, looking for the locations where it passes through something interesting (the object). To find these locations, we step along the ray and sample points at regular intervals.
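Once the rays are stored as arrays, sampling $m$ evenly spaced query points along every ray is a single broadcasted expression. The sketch assumes ray origins and unit directions of shape `(..., 3)` and scalar bounds `near` and `far` (hypothetical parameters, since the text above does not fix them). The output puts the sample axis just before the coordinates, i.e., `(n, H, W, m, 3)` for a full bundle, rather than the $m \times H \times W \times n$ order used in this post.

```python
import numpy as np

def sample_along_rays(rays_o, rays_d, near, far, m):
    """Evenly spaced points r(t) = o + t * d for t in [near, far]."""
    t = np.linspace(near, far, m)                                      # (m,)
    points = rays_o[..., None, :] + t[:, None] * rays_d[..., None, :]  # (..., m, 3)
    return t, points

rays_o = np.zeros((2, 3))
rays_d = np.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
t, points = sample_along_rays(rays_o, rays_d, near=2.0, far=6.0, m=5)
```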

By querying a trained neural network at these 3D points along the viewing ray, we can determine if they belong to the object volume and obtain their visual properties to render an image. However, sampling points along a ray is challenging, as too many non-object points won't provide useful information, and focusing only on high-density regions may miss interesting areas.

In our toy example, we uniformly sample along the ray by taking $m$ samples. However, to improve performance, the authors use "hierarchical volume sampling" to allocate samples proportionally to their expected impact on the final rendering.
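Hierarchical volume sampling, as described in the paper, uses two networks: a "coarse" one evaluated at $N_c$ stratified samples and a "fine" one. The coarse pass gives each sample a weight (its contribution $T_i \alpha_i$ to the rendered color, as defined in the volume rendering section). Normalizing these weights yields a piecewise-constant PDF along the ray, and $N_f$ additional points are drawn from it by inverse transform sampling. The fine network is then evaluated at all $N_c + N_f$ points (the paper uses $N_c = 64$ and $N_f = 128$). Below is a minimal single-ray sketch of the inverse-transform step. `sample_pdf` and its argument layout (bin edges plus one weight per bin) are assumptions for illustration.

```python
import numpy as np

def sample_pdf(bin_edges, weights, n_samples, rng):
    """Draw n_samples t-values along one ray from a piecewise-constant PDF.

    bin_edges: (N + 1,) increasing t-values delimiting N bins
    weights:   (N,) non-negative coarse-pass weights, one per bin
    """
    pdf = weights + 1e-5                  # small floor so an empty ray still has a valid PDF
    pdf = pdf / pdf.sum()
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])            # (N + 1,), from 0 to ~1
    u = rng.uniform(size=n_samples)
    # Index of the bin whose CDF interval contains each u
    idx = np.clip(np.searchsorted(cdf, u, side="right") - 1, 0, len(weights) - 1)
    # Invert the CDF linearly inside that bin
    frac = np.clip((u - cdf[idx]) / (cdf[idx + 1] - cdf[idx]), 0.0, 1.0)
    return np.sort(bin_edges[idx] + frac * (bin_edges[idx + 1] - bin_edges[idx]))

rng = np.random.default_rng(0)
edges = np.linspace(2.0, 6.0, 5)                  # 4 bins: [2,3], [3,4], [4,5], [5,6]
coarse_weights = np.array([0.0, 0.0, 1.0, 0.0])   # the coarse pass "saw" the object in [4, 5]
fine_t = sample_pdf(edges, coarse_weights, n_samples=1000, rng=rng)
```

The resulting `fine_t` would then be merged with the coarse samples (for example `np.sort(np.concatenate([coarse_t, fine_t]))`) before querying the fine network.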

Input : A bundle of rays for every pose $(r_{\mathbf{o}, \mathbf{d}})_{H \times W \times n}$
Output : A set of 3D query points for every ray $(x_p, \, y_p, \, z_p)_{m \times H \times W \times n}$



Positional Encoding

Once we have collected the query points for every ray, we could in principle feed them directly into the neural network. However, the authors found that when the network operates directly on the raw coordinates, the resulting renderings poorly represent the high-frequency variations in color and geometry that make images look sharp and vivid to the human eye.

This observation is consistent with the work of Rahaman et al., who showed that deep networks are biased towards learning lower-frequency functions. They also showed that mapping the inputs to a higher-dimensional space using high-frequency functions before passing them to the network enables a better fit to data that contains high-frequency variation.

The authors use a positional encoding made of sine and cosine functions of increasing frequency: \begin{equation*} \gamma(p) = \left( \sin(2^0 \pi p), \, \cos(2^0 \pi p), \, \sin(2^1 \pi p), \, \cos(2^1 \pi p), \, \ldots, \, \sin(2^{L-1} \pi p), \, \cos(2^{L-1} \pi p) \right) \end{equation*} where $L$ is the number of frequency bands, so each scalar input $p$ is mapped to a $2L$-dimensional vector.

The function $\gamma(\cdot)$ is applied separately to each of the three coordinate values in $\mathbf{x}$ (which are normalized to lie in $[-1, 1]$). Since the viewing direction is also an input to the network, we encode it as well, applying $\gamma(\cdot)$ to the three components of the unit direction vector $\mathbf{d}$ (which by construction lie in $[-1, 1]$). The authors use $L = 10$ for embedding the 3D query points and $L = 4$ for embedding the viewing direction.
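The encoding $\gamma$ takes only a few vectorized lines. This sketch applies it to the last axis of an array and orders the output as in the formula above: sin and cos for each frequency, one coordinate after another. `positional_encoding` is a hypothetical name. With $L = 10$ a 3D point becomes a 60-dimensional vector, and with $L = 4$ a direction becomes 24-dimensional. The authors' released code additionally concatenates the raw input, giving 63 and 27 dimensions instead.

```python
import numpy as np

def positional_encoding(p, L):
    """gamma(p) for every coordinate in the last axis of p.

    p: array of shape (..., D), e.g. D = 3 for points or unit directions
    returns: array of shape (..., D * 2L)
    """
    p = np.asarray(p, dtype=float)
    freqs = (2.0 ** np.arange(L)) * np.pi                       # 2^0 pi, ..., 2^(L-1) pi
    angles = p[..., :, None] * freqs                            # (..., D, L)
    enc = np.stack([np.sin(angles), np.cos(angles)], axis=-1)   # (..., D, L, 2)
    return enc.reshape(*p.shape[:-1], -1)                       # (..., D * 2L)

points = np.random.default_rng(0).uniform(-1, 1, size=(5, 3))
dirs = points / np.linalg.norm(points, axis=-1, keepdims=True)
gamma_x = positional_encoding(points, L=10)   # (5, 60)
gamma_d = positional_encoding(dirs, L=4)      # (5, 24)
```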

Input : A set of 3D query points for every ray $(\mathbf{x} = (x_p, \, y_p, \, z_p))_{m \times H \times W \times n}$, and the unit viewing direction of every ray $(\mathbf{d})_{H \times W \times n}$
Output : Embeddings of query points $\gamma(\mathbf{x})_{m \times H \times W \times n}$ and viewing directions $\gamma(\mathbf{d})_{H \times W \times n}$



Neural Network inference

To achieve multiview consistency, the network is restricted to predict the volume density $\sigma$ as a function of the 3D location only, while the RGB color $c$ is predicted as a function of both location and viewing direction.

The MLP architecture consists of 8 fully-connected layers, each with 256 channels and ReLU activations. A skip connection is included in the network that concatenates the input to the fifth layer's activation. It takes the encoded query points $\gamma(\mathbf{x})$ as input and produces two outputs: the volume density $\sigma$ and a 256-dimensional feature vector.

This feature vector is then concatenated with the embedded viewing direction $\gamma(\mathbf{d})$ and passed through an additional fully-connected layer with 128 channels and a ReLU activation. A final layer with a sigmoid activation then outputs the view-dependent RGB color $c$, with each channel in the range 0 to 1. The authors report that a model trained without view dependence has difficulty representing specularities.
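The architecture described above can be sketched as a NumPy forward pass. The weights are random (untrained), and the initialization scheme and the names `NeRFMLP` and `init_linear` are assumptions for illustration. The input sizes 60 and 24 assume $L = 10$ and $L = 4$ applied to three coordinates each ($3 \times 2 \times 10$ and $3 \times 2 \times 4$). The sketch only shows how the tensors flow: the density depends on $\gamma(\mathbf{x})$ alone, while the color also sees $\gamma(\mathbf{d})$.

```python
import numpy as np

def init_linear(rng, n_in, n_out):
    """Random weight matrix and zero bias for one fully-connected layer."""
    return rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_in, n_out)), np.zeros(n_out)

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class NeRFMLP:
    def __init__(self, pos_dim=60, dir_dim=24, width=256, seed=0):
        rng = np.random.default_rng(seed)
        # 8 fully-connected layers; the sixth also receives the input (skip connection)
        self.trunk = []
        for k in range(8):
            n_in = pos_dim if k == 0 else width
            if k == 5:
                n_in = width + pos_dim
            self.trunk.append(init_linear(rng, n_in, width))
        self.sigma_out = init_linear(rng, width, 1)          # volume density
        self.feature_out = init_linear(rng, width, width)    # 256-d feature vector
        self.view_layer = init_linear(rng, width + dir_dim, 128)
        self.rgb_out = init_linear(rng, 128, 3)

    def __call__(self, gamma_x, gamma_d):
        h = gamma_x
        for k, (W, b) in enumerate(self.trunk):
            h = relu(h @ W + b)
            if k == 4:   # concatenate the input to the fifth layer's activation
                h = np.concatenate([gamma_x, h], axis=-1)
        W, b = self.sigma_out
        sigma = relu(h @ W + b)[..., 0]            # ReLU keeps the density non-negative
        W, b = self.feature_out
        feature = h @ W + b
        W, b = self.view_layer
        h = relu(np.concatenate([feature, gamma_d], axis=-1) @ W + b)
        W, b = self.rgb_out
        rgb = sigmoid(h @ W + b)                   # each channel in (0, 1)
        return rgb, sigma

rng = np.random.default_rng(1)
model = NeRFMLP()
rgb, sigma = model(rng.normal(size=(4, 60)), rng.normal(size=(4, 24)))
```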

Evaluating the network at every query point along a ray gives us the volume density profile of that ray, as shown in the figure below.

[Figure: Example volume density profile of a single ray, output from a trained network that learns to represent the yellow Lego bulldozer]

Input : Embeddings of query points $\gamma(\mathbf{x})_{m \times H \times W \times n}$ and viewing directions $\gamma(\mathbf{d})_{H \times W \times n}$
Output : RGB color and volume density for every query point $ (c, \, \sigma)_{m \times H \times W \times n}$



Volume Rendering

Now it's time to turn this volume density profile along a ray into an RGB pixel value. Once we have computed the pixel values for all the rays, we will have a full $H \times W$ image for each of the $n$ viewpoints.

The volume rendering process computes the radiance accumulated along the viewing ray as it passes through the neural radiance field. This is done by integrating the color at each point along the ray, weighted by the volume density at that point and by the transmittance, i.e., the fraction of light that travels from the camera to that point without being blocked.

The expected color of a camera ray $r(t) = \mathbf{o} + t * \mathbf{d}$ with near and far bounds $t_n$ and $t_f$ is given as: \begin{equation*} C(r) = \int_{t_n}^{t_f} T(t) \; \sigma(r(t)) \; c(r(t), \mathbf{d}) \; dt \end{equation*} where function $T(t)$ denotes the accumulated transmittance along the ray from $t_n$ to $t$, i.e., the probability that the ray travels from $t_n$ to $t$ without hitting any other particle. \begin{equation*} T(t) = \text{exp} \left( - \int_{t_n}^{t} \sigma(r(s)) \; ds \right) \end{equation*}


These complex-looking integrals can be approximated via numerical quadrature. We use a stratified sampling approach where we select a random set of quadrature points $\{t_i\}_{i = 1}^N$ by uniformly drawing samples from $N$ evenly-spaced ray bins between $t_n$ and $t_f$.
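Concretely, the paper draws each quadrature point uniformly within its own bin, $t_i \sim \mathcal{U}\left[t_n + \frac{i-1}{N}(t_f - t_n), \; t_n + \frac{i}{N}(t_f - t_n)\right]$. Below is a vectorized NumPy sketch for a batch of rays. `stratified_samples` and the explicit random-generator argument are illustrative choices.

```python
import numpy as np

def stratified_samples(t_near, t_far, N, n_rays, rng):
    """One uniform sample inside each of N evenly spaced bins, for every ray."""
    edges = np.linspace(t_near, t_far, N + 1)          # N bins between t_near and t_far
    lower, upper = edges[:-1], edges[1:]
    u = rng.uniform(size=(n_rays, N))                  # fresh randomness on every call
    return lower + u * (upper - lower)                 # (n_rays, N), sorted along each ray

rng = np.random.default_rng(0)
t = stratified_samples(2.0, 6.0, N=64, n_rays=8, rng=rng)
```

Calling it again in the next training iteration gives different $t_i$, so over time the MLP is queried at many different positions along each ray.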

Although we use a discrete set of samples to estimate the integral, stratified sampling still lets us learn a continuous scene representation: because the sample positions are redrawn at random in every iteration, the MLP is evaluated at continuous positions over the course of optimization rather than at a fixed set of points.

The approximated color of each ray is computed as \begin{equation*} C(r) \approx \sum_{i = 1}^N T_i \; \alpha_i \; c_i \end{equation*} where $\sigma_i = \sigma(r(t_i))$ and $c_i = c(r(t_i), \mathbf{d})$ are the network outputs at the $i$-th sample, $\alpha_i = 1 - \text{exp}(-\sigma_i \, \delta_i)$ can be viewed as the opacity of the $i$-th segment, and $\delta_i = t_{i+1} - t_i$ is the distance between adjacent quadrature points. The accumulated transmittance $T_i$ is the cumulative product of the transmittances $(1 - \alpha_j)$ of all the samples in front of sample $i$, i.e., between the camera and that sample: \begin{equation*} T_i = \prod_{j=1}^{i-1} (1 - \alpha_j) = \text{exp} \left( - \sum_{j=1}^{i-1} \sigma_j \, \delta_j \right) \end{equation*}
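The quadrature formula translates almost line by line into code. This NumPy sketch assumes sorted sample positions `t` and unit-length ray directions, so $\delta_i$ can be measured directly in $t$. Since the last sample has no successor, it follows a common implementation choice and sets the last interval to a very large value (here `1e10`). It also returns the per-sample weights $T_i \alpha_i$, which are what hierarchical volume sampling reuses. `composite` is a hypothetical name.

```python
import numpy as np

def composite(rgb, sigma, t, last_delta=1e10):
    """Render ray colors from per-sample colors and densities.

    rgb: (..., N, 3), sigma: (..., N), t: (..., N) sorted sample positions along each ray
    """
    delta = np.diff(t, axis=-1)                                     # t_{i+1} - t_i
    delta = np.concatenate([delta, np.full(delta.shape[:-1] + (1,), last_delta)], axis=-1)
    alpha = 1.0 - np.exp(-sigma * delta)                            # opacity of each segment
    # T_i = prod_{j < i} (1 - alpha_j): an exclusive cumulative product with T_1 = 1
    T = np.cumprod(1.0 - alpha, axis=-1)
    T = np.concatenate([np.ones(T.shape[:-1] + (1,)), T[..., :-1]], axis=-1)
    weights = T * alpha                                             # contribution of each sample
    color = np.sum(weights[..., None] * rgb, axis=-2)               # (..., 3)
    return color, weights

t = np.array([1.0, 2.0, 3.0])
sigma = np.array([0.0, 1e3, 0.0])     # only the middle sample is (very) dense
rgb = np.eye(3)                       # red, green and blue samples
color, weights = composite(rgb, sigma, t)
```

With a single opaque sample in the middle, the ray takes that sample's color (green) and everything behind it is hidden.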


Input : RGB color and volume density for every query point $ (c, \, \sigma)_{m \times H \times W \times n}$
Output : Rendered Images $(H \times W)_{n}$



Computing loss

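The final step is to compute the loss between the rendered pixels and the ground truth. The loss is a simple L2 (squared error) loss between each rendered pixel color and the corresponding ground-truth pixel color. In practice, the authors do not render full images at every step: each optimization iteration uses a random batch of rays sampled from all pixels of all training images, and the squared error is summed over the renderings of both the coarse and the fine network that hierarchical volume sampling relies on.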
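Here is a minimal sketch of this step, assuming the training images are stored as one `(n, H, W, 3)` array with values in [0, 1]. It picks a random batch of pixels (equivalently, rays) and compares their rendered colors with the ground truth. Following the paper, the squared error is summed over the batch for both the coarse and the fine rendering. Many implementations use the mean instead, which only rescales the loss. The function names and batch handling are illustrative assumptions.

```python
import numpy as np

def sample_pixel_batch(images, batch_size, rng):
    """Random pixels from all training images; returns (view, row, col) indices and colors."""
    n, H, W, _ = images.shape
    flat_idx = rng.choice(n * H * W, size=batch_size, replace=False)
    view, row, col = np.unravel_index(flat_idx, (n, H, W))
    return (view, row, col), images[view, row, col]        # colors: (batch_size, 3)

def nerf_loss(coarse_rgb, fine_rgb, target_rgb):
    """Total squared error of the coarse and fine renderings over a batch of rays."""
    return np.sum((coarse_rgb - target_rgb) ** 2) + np.sum((fine_rgb - target_rgb) ** 2)

rng = np.random.default_rng(0)
images = rng.uniform(size=(2, 4, 5, 3))
(view, row, col), target = sample_pixel_batch(images, batch_size=6, rng=rng)
loss = nerf_loss(np.zeros_like(target), target, target)   # the fine rendering is perfect here
```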





Credits: While writing this blog, I referred to dtransposed's blog post (the images are also from this post), and AI Summer's post.
