diff --git a/README.md b/README.md index ac1e954..88550fe 100644 --- a/README.md +++ b/README.md @@ -107,14 +107,14 @@ dataset = ffn.PixelDataset.create(path_to_image_file, color_space="RGB", ## 3D Datasets This is where the library becomes a bit picky about input data. The -[`RayDataset`](nerf/ray_dataset.py) supports a set format for data, +[`ImageDataset`](nerf/image_dataset.py) supports a set format for data, and we provide several datasets in this format to play with. These datasets are not stored in the repo, but the library will automatically download them to the `data` folder when you first requests them which you can do like so: ```python -dataset = ffn.RayDataset.load("antinous_400.npz", split="train", num_samples=64) +dataset = ffn.ImageDataset.load("antinous_400.npz", split="train", num_samples=64) ``` We recommend you use one of the following (all datasets are provided in 400 and 800 versions): @@ -231,7 +231,7 @@ It will produce the frames of the following video: https://user-images.githubusercontent.com/6894931/142744837-382e13b1-d1cf-4305-870a-b64763c73e54.mp4 -Another way to visualize what the model has learned is toproduce a +Another way to visualize what the model has learned is to produce a voxelization of the model. This is different from the voxel-based volume rendering, in which multiple voxels contribute to a single sample. Rather, it is a sparse octree containing voxels at the places the model has determined are diff --git a/azureml/aml_env.yml b/azureml/aml_env.yml index ee8f378..e698f2c 100644 --- a/azureml/aml_env.yml +++ b/azureml/aml_env.yml @@ -2,21 +2,21 @@ channels: - pytorch - nvidia dependencies: - - python=3.7 - - cudatoolkit=11.1 - - pytorch=1.9.0 - - torchvision=0.10 - - torchaudio=0.9 - - pip=20.2 + - python=3.9 + - cudatoolkit=11.3 + - pytorch=1.12.0 + - torchvision=0.13.0 + - torchaudio=0.12.0 + - pip=21.2.4 - cudnn - pip: - azureml-defaults - azureml-train-core - ffmpeg-python - matplotlib - - numba==0.54.1 - - numpy==1.20.3 - - opencv-python-headless>=4.5.3 + - numba==0.55.2 + - numpy==1.22.4 + - opencv-python-headless>=4.5.5 - progress - requests - scenepic diff --git a/fourier_feature_nets/__init__.py b/fourier_feature_nets/__init__.py index 12271a0..be6bb7a 100644 --- a/fourier_feature_nets/__init__.py +++ b/fourier_feature_nets/__init__.py @@ -8,15 +8,30 @@ MLP, PositionalFourierMLP ) +from .image_dataset import ImageDataset from .nerf_model import NeRF from .octree import OcTree from .pixel_dataset import PixelDataset from .ray_caster import Raycaster -from .ray_dataset import RayData, RayDataset +from .ray_dataset import RayDataset from .ray_sampler import RaySampler, RaySamples from .signal_dataset import SignalDataset -from .utils import ETABar, interpolate_bilinear, load_model, orbit +from .utils import ( + calculate_blend_weights, + ETABar, + exponential_lr_decay, + hemisphere, + interpolate_bilinear, + load_model, + orbit +) from .version import __version__ +from .visualizers import ( + ActivationVisualizer, + ComparisonVisualizer, + EvaluationVisualizer, + OrbitVideoVisualizer +) from .voxels_model import Voxels __all__ = ["__version__", @@ -29,8 +44,12 @@ "FourierFeatureMLP", "PositionalFourierMLP", "GaussianFourierMLP", + "ImageDataset", "Voxels", + "calculate_blend_weights", + "exponential_lr_decay", "interpolate_bilinear", + "hemisphere", "load_model", "orbit", "OcTree", @@ -38,8 +57,12 @@ "Raycaster", "RaySampler", "RaySamples", - "RayData", "RayDataset", "Resolution", "SignalDataset", - "Triangulation"] + "Triangulation", + 
"ActivationVisualizer", + "ComparisonVisualizer", + "EvaluationVisualizer", + "OrbitVideoVisualizer", + "PatchVisualizer"] diff --git a/fourier_feature_nets/camera_info.py b/fourier_feature_nets/camera_info.py index be2823b..790fce1 100644 --- a/fourier_feature_nets/camera_info.py +++ b/fourier_feature_nets/camera_info.py @@ -34,6 +34,11 @@ def square(self) -> "Resolution": size = min(self.width, self.height) return Resolution(size, size) + @property + def ratio(self) -> float: + """Aspect ratio.""" + return self.width / self.height + class CameraInfo(NamedTuple("CameraInfo", [("name", str), ("resolution", Resolution), @@ -79,6 +84,13 @@ def project(self, positions: np.ndarray) -> np.ndarray: points = points[:, :2] / points[:, 2:3] return points + @property + def fov_y_degrees(self) -> float: + """Y-axis field of view (in degrees) for the camera.""" + fov_y = (0.5 * self.resolution.width) / self.intrinsics[1, 1] + fov_y = 2 * np.arctan(fov_y) + return fov_y * 180 / np.pi + @property def position(self) -> np.ndarray: """Returns the position of the camera in world coordinates.""" diff --git a/fourier_feature_nets/image_dataset.py b/fourier_feature_nets/image_dataset.py new file mode 100644 index 0000000..5e0a9ce --- /dev/null +++ b/fourier_feature_nets/image_dataset.py @@ -0,0 +1,597 @@ +"""Module providing an image dataset for training NeRF models.""" + +import os +from typing import List, Set, Union + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import scenepic as sp +import torch +import torch.nn as nn +from torch.utils.data import Dataset + +from .camera_info import CameraInfo, Resolution +from .ray_dataset import RayDataset +from .ray_sampler import RaySampler, RaySamples +from .utils import download_asset, ETABar, RenderResult + + +class ImageDataset(Dataset, RayDataset): + """Dataset built from images for sampling from rays cast into a volume.""" + + def __init__(self, label: str, images: np.ndarray, bounds: np.ndarray, + cameras: List[CameraInfo], num_samples: int, + include_alpha=True, stratified=False, + opacity_model: nn.Module = None, + batch_size=4096, color_space="RGB", + sparse_size=50, anneal_start=0.2, + num_anneal_steps=0, alpha_weight=0.1): + """Constructor. + + Args: + label (str): Label used to identify this dataset. + images (np.ndarray): Images of the object from each camera + bounds (np.ndarray): Bounds of the render volume defined as a + transform matrix on the unit cube. + cameras (List[CameraInfo]): List of all cameras in the scene + num_samples (int): The number of samples to take per ray + include_alpha (bool): Whether to include alpha if present + stratified (bool, optional): Whether to use stratified random + sampling + opacity_model (nn.Module, optional): Optional model which predicts + opacity in the volume, used + for performing targeted + sampling if provided. Defaults + to None. + batch_size (int, optional): Batch size to use with the opacity + model. Defaults to 4096. + color_space (str, optional): The color space to use. Defaults to + "RGB". + sparse_size (int, optional): The vertical resolution of + the sparse sampling version. + anneal_start (float, optiona): Starting value for the sample space + annealing. Defaults to 0.2. + num_anneal_steps (int, optional): Steps over which to anneal + sampling to the full range of + volume intersection. Defaults + to 0. 
+ alpha_weight (float, optional): weight for the alpha term of the + loss + """ + assert len(images.shape) == 4 + assert len(images) == len(cameras) + assert images.dtype == np.uint8 + + self._color_space = color_space + self._mode = RayDataset.Mode.Full + self.image_height, self.image_width = images.shape[1:3] + self._images = images + self._label = label + self.include_alpha = include_alpha + self._subsample_index = None + self.sampler = RaySampler(bounds, cameras, num_samples, stratified, + opacity_model, batch_size, anneal_start, + num_anneal_steps) + + source_resolution = np.array([self.image_width, self.image_height], + np.float32) + crop_start = source_resolution // 4 + crop_end = source_resolution - crop_start + x_vals = np.arange(self.image_width) + y_vals = np.arange(self.image_height) + points = np.stack(np.meshgrid(x_vals, y_vals), -1) + points = points.reshape(-1, 2) + + inside_crop = (points >= crop_start) & (points < crop_end) + inside_crop = inside_crop.all(-1) + crop_points = np.nonzero(inside_crop)[0] + crop_points = torch.from_numpy(crop_points) + self.crop_rays_per_camera = len(crop_points) + + sparse_points = torch.LongTensor(self._subsample_rays(sparse_size)) + sparse_height = sparse_size + sparse_width = sparse_size * self.image_width // self.image_height + self.sparse_size = sparse_size + self.sparse_resolution = sparse_width, sparse_height + self.sparse_rays_per_camera = len(sparse_points) + + stencil_radius = 8 * min(self.image_width, self.image_height) // 100 + size = 2 * stencil_radius + 1 + element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) + self.dilate_ranges = [] + num_dilate = 0 + + colors = [] + alphas = [] + crop_index = [] + sparse_index = [] + dilate_index = [] + bar = ETABar("Indexing", max=len(images)) + for image in images: + bar.next() + color = image[..., :3] + if color_space == "YCrCb": + color = cv2.cvtColor(color, cv2.COLOR_RGB2YCrCb) + + color = color.astype(np.float32) / 255 + color = color[self.sampler.points[:, 1], + self.sampler.points[:, 0]] + colors.append(torch.from_numpy(color)) + + offset = len(crop_index) * self.sampler.rays_per_camera + if image.shape[-1] == 4: + alpha = image[..., 3].astype(np.float32) / 255 + mask = (alpha > 0).astype(np.uint8) + + alpha = alpha[self.sampler.points[:, 1], + self.sampler.points[:, 0]] + alphas.append(torch.from_numpy(alpha)) + + mask = cv2.dilate(mask, element) + mask = mask[self.sampler.points[:, 1], + self.sampler.points[:, 0]] + dilate_points, = np.nonzero(mask) + dilate_index.append(torch.from_numpy(dilate_points) + offset) + start = num_dilate + end = start + len(dilate_points) + num_dilate = end + self.dilate_ranges.append((start, end)) + + crop_index.append(crop_points + offset) + sparse_index.append(sparse_points + offset) + + bar.finish() + self.crop_index = torch.cat(crop_index) + self.sparse_index = torch.cat(sparse_index) + self.dilate_index = torch.cat(dilate_index) + + if len(alphas) > 0 and include_alpha: + self.alphas = torch.cat(alphas) + self.alpha_weight = alpha_weight + else: + self.alphas = None + self.alpha_weight = 0 + + self.colors = torch.cat(colors) + + @property + def color_space(self) -> str: + """Color space used by the dataset.""" + return self._color_space + + @property + def mode(self) -> RayDataset.Mode: + """Sampling mode of the dataset.""" + return self._mode + + @mode.setter + def mode(self, value: "RayDataset.Mode"): + if value == RayDataset.Mode.Dilate and len(self.dilate_index) == 0: + raise ValueError("Unable to use dilate mode: missing alpha 
channel") + + self._mode = value + + @property + def subsample_index(self) -> Set[int]: + """Set of pixel indices in an image to sample.""" + return self._subsample_index + + @subsample_index.setter + def subsample_index(self, index: Set[int]): + self._subsample_index = index + + @property + def images(self) -> List[np.ndarray]: + """Dataset images.""" + return self._images + + @property + def label(self) -> str: + """A label for the dataset.""" + return self._label + + @property + def num_cameras(self) -> bool: + """Number of cameras in the dataset.""" + return self.sampler.num_cameras + + @property + def num_samples(self) -> int: + """Number of samples per ray in the dataset.""" + return self.sampler.num_samples + + @property + def cameras(self) -> List[CameraInfo]: + """Camera information.""" + return self.sampler.cameras + + def to_valid(self, idx: List[int]) -> List[int]: + """Filters the list of ray indices to include only valid rays. + + Description: + In this context, a "valid" ray is one which intersects the bounding + volume. + + Args: + idx (List[int]): An index of rays in the dataset. + + Returns: + List[int]: a filtered list of valid rays + """ + return self.sampler.to_valid(idx) + + def loss(self, _: int, rays: RaySamples, render: RenderResult) -> torch.Tensor: + """Compute the dataset loss for the prediction. + + Args: + actual (RaySamples): The rays to render + predicted (RenderResult): The ray rendering result + + Returns: + torch.Tensor: a scalar loss tensor + """ + actual = self.render(rays) + actual = actual.to(render.device) + + color_loss = (actual.color - render.color).square().mean() + if self.alpha_weight > 0 and actual.alpha is not None: + alpha_loss = (actual.alpha - render.alpha).square().mean() + return color_loss + self.alpha_weight * alpha_loss + + return color_loss + + def render(self, samples: RaySamples) -> RenderResult: + """Returns a (ground truth) render of the rays. + + Args: + rays (RaySamples): the rays to render + + Returns: + RenderResult: the ground truth render + """ + color = self.colors[samples.rays] + if self.alphas is None or self.mode == RayDataset.Mode.Dilate: + alpha = None + else: + alpha = self.alphas[samples.rays] + color = torch.where(alpha.unsqueeze(1) > 0, color, + torch.zeros_like(color)) + + return RenderResult(color, alpha, None) + + def index_for_camera(self, camera: int) -> List[int]: + """Returns a pixel index for the camera. + + Description: + This method will take into account special patterns from sampling, + such as sparsity, center cropping, or dilation. 
+ + Args: + camera (int): the camera index + + Returns: + List[int]: index into the rays for this camera + """ + camera_start = camera * self.sampler.rays_per_camera + if self.mode == RayDataset.Mode.Center: + start = camera * self.crop_rays_per_camera + end = start + self.crop_rays_per_camera + idx = self.crop_index[start:end].tolist() + elif self.mode == RayDataset.Mode.Sparse: + start = camera * self.sparse_rays_per_camera + end = start + self.sparse_rays_per_camera + idx = self.sparse_index[start:end].tolist() + elif self.mode == RayDataset.Mode.Dilate: + start, end = self.dilate_ranges[camera] + idx = self.dilate_index[start:end].tolist() + elif self.mode == RayDataset.Mode.Full: + camera_end = camera_start + self.sampler.rays_per_camera + idx = list(range(camera_start, camera_end)) + else: + raise NotImplementedError("Unsupported sampling mode") + + idx = self.sampler.to_valid(idx) + idx = [i - camera_start for i in idx] + return idx + + def rays_for_camera(self, camera: int) -> RaySamples: + """Returns ray samples for the specified camera.""" + if self.mode == RayDataset.Mode.Center: + start = camera * self.crop_rays_per_camera + end = start + self.crop_rays_per_camera + elif self.mode == RayDataset.Mode.Sparse: + start = camera * self.sparse_rays_per_camera + end = start + self.sparse_rays_per_camera + elif self.mode == RayDataset.Mode.Dilate: + start, end = self.dilate_ranges[camera] + elif self.mode == RayDataset.Mode.Full: + start = camera * self.sampler.rays_per_camera + end = start + self.sampler.rays_per_camera + else: + raise NotImplementedError("Unsupported sampling mode") + + return self.get_rays(list(range(start, end)), None) + + def __len__(self) -> int: + """The number of rays in the dataset.""" + if self.mode == RayDataset.Mode.Center: + return len(self.crop_index) + + if self.mode == RayDataset.Mode.Sparse: + return len(self.sparse_index) + + if self.mode == RayDataset.Mode.Dilate: + return len(self.dilate_index) + + if self.mode == RayDataset.Mode.Full: + return len(self.sampler) + + raise NotImplementedError("Unsupported sampling mode") + + def subset(self, cameras: List[int], + num_samples: int, + stratified: bool, + label: str) -> "ImageDataset": + """Returns a subset of this dataset (by camera). + + Args: + cameras (List[int]): List of camera indices + num_samples (int): Number of samples per ray. + resolution (int): Ray sampling resolution + stratified (bool): Whether to use stratified sampling. + Defaults to False. 
+ + Returns: + RayDataset: New dataset with the subset of cameras + """ + return ImageDataset(label, + self.images[cameras], + self.sampler.bounds, + [self.sampler.cameras[i] for i in cameras], + num_samples, + self.include_alpha, + stratified, + self.sampler.opacity_model, + self.sampler.batch_size, + self.color_space, + self.sparse_size, + self.sampler.anneal_start, + self.sampler.num_anneal_steps, + self.alpha_weight) + + def get_rays(self, + idx: Union[List[int], torch.Tensor], + step: int = None) -> RaySamples: + """Returns samples from the selected rays.""" + if torch.is_tensor(idx): + idx = idx.tolist() + + if self.mode == RayDataset.Mode.Center: + idx = self.crop_index[idx].tolist() + elif self.mode == RayDataset.Mode.Sparse: + idx = self.sparse_index[idx].tolist() + elif self.mode == RayDataset.Mode.Dilate: + idx = self.dilate_index[idx].tolist() + + if not isinstance(idx, list): + idx = [idx] + + if self.subsample_index: + idx = [i for i in idx + if i % self.sampler.rays_per_camera in self.subsample_index] + + idx = self.sampler.to_valid(idx) + return self.sampler.sample(idx, step) + + @staticmethod + def load(path: str, split: str, num_samples: int, + include_alpha: bool, stratified: bool, + opacity_model: nn.Module = None, + batch_size=4096, color_space="RGB", + sparse_size=50, anneal_start=0.2, + num_anneal_steps=0) -> "ImageDataset": + """Loads a dataset from an NPZ file. + + Description: + The NPZ file should contain the following elements: + + images: a (NxRxRx[3,4]) tensor of images in RGB(A) format. + bounds: a (4x4) transform from the unit cube to a render volume + intrinsics: a (Nx3x3) tensor of camera intrinsics (projection) + extrinsics: a (Nx4x4) tensor of camera extrinsics (camera to world) + split_counts: a (3) tensor of counts per split in train, val, test + order + + Args: + path (str): path to an NPZ file containing the dataset + split (str): the split to load [train, val, test] + num_samples (int): the number of samples per ray + include_alpha (bool): Whether to include alpha if present + stratified (bool): whether to use stratified sampling. + opacity_model (nn.Module, optional): model that predicts opacity + from 3D position. If the model + predicts more than one value, + the last channel is used. + Defaults to None. + batch_size (int, optional): Batch size to use when sampling the + opacity model. + sparse_size (int, optional): Resolution for sparse sampling. + anneal_start (float, optiona): Starting value for the sample space + annealing. Defaults to 0.2. + num_anneal_steps (int, optional): Steps over which to anneal + sampling to the full range of + volume intersection. Defaults + to 0. 
+ + Returns: + RayDataset: A dataset made from the camera and image data + """ + if not os.path.exists(path): + path = os.path.join(os.path.dirname(__file__), "..", "data", path) + path = os.path.abspath(path) + if not os.path.exists(path): + print("Downloading dataset...") + dataset_name = os.path.basename(path) + success = download_asset(dataset_name, path) + if not success: + print("Unable to download dataset", dataset_name) + return None + + data = np.load(path) + test_end, height, width = data["images"].shape[:3] + split_counts = data["split_counts"] + train_end = split_counts[0] + val_end = train_end + split_counts[1] + + if split == "train": + idx = list(range(train_end)) + elif split == "val": + idx = list(range(train_end, val_end)) + elif split == "test": + idx = list(range(val_end, test_end)) + else: + print("Unrecognized split:", split) + return None + + bounds = data["bounds"] + images = data["images"][idx] + intrinsics = data["intrinsics"][idx] + extrinsics = data["extrinsics"][idx] + + cameras = [CameraInfo.create("{}{:03}".format(split, i), + Resolution(width, height), + intr, extr) + for i, (intr, extr) in enumerate(zip(intrinsics, + extrinsics))] + return ImageDataset(split, images, bounds, cameras, num_samples, + include_alpha, stratified, opacity_model, + batch_size, color_space, sparse_size, + anneal_start, num_anneal_steps) + + def _subsample_rays(self, resolution: int) -> List[int]: + num_x_samples = resolution * self.image_width // self.image_height + num_y_samples = resolution + x_vals = np.linspace(0, self.image_width - 1, num_x_samples) + 0.5 + y_vals = np.linspace(0, self.image_height - 1, num_y_samples) + 0.5 + x_vals, y_vals = np.meshgrid(x_vals.astype(np.int32), + y_vals.astype(np.int32)) + index = y_vals.reshape(-1) * self.image_width + x_vals.reshape(-1) + index = index.tolist() + return index + + def to_scenepic(self) -> sp.Scene: + """Creates a ray sampling visualization ScenePic for the dataset.""" + scene = sp.Scene() + frustums = scene.create_mesh("frustums", layer_id="frustums") + height = 800 + width = height * self.image_width // self.image_height + canvas = scene.create_canvas_3d(width=width, + height=height) + canvas.shading = sp.Shading(sp.Colors.Gray) + + idx = np.arange(len(self.sampler.cameras)) + images = self.images + cameras = self.sampler.cameras + + cmap = plt.get_cmap("jet") + camera_colors = cmap(np.linspace(0, 1, len(cameras)))[:, :3] + image_meshes = [] + bar = ETABar("Plotting cameras", max=self.num_cameras) + thumb_height = 200 + thumb_width = thumb_height * self.image_width // self.image_height + for i, pixels, camera, color in zip(idx, images, + cameras, camera_colors): + bar.next() + camera = camera.to_scenepic() + + image = scene.create_image() + cam_index = self.index_for_camera(i) + pixels = (pixels / 255).astype(np.float32) + pixels = pixels[..., :3].reshape(-1, 3)[cam_index] + pixels = self.to_image(i, pixels) + pixels = cv2.resize(pixels, (thumb_width, thumb_height), + cv2.INTER_AREA) + image.from_numpy(pixels) + mesh = scene.create_mesh(layer_id="images", texture_id=image.image_id, + double_sided=True) + mesh.add_camera_image(camera, depth=0.5) + image_meshes.append(mesh) + + frustums.add_camera_frustum(camera, color, depth=0.5, thickness=0.01) + + bar.finish() + + bar = ETABar("Sampling Rays", max=self.num_cameras) + + bounds = scene.create_mesh("bounds", layer_id="bounds") + bounds.add_cube(sp.Colors.Blue, transform=self.sampler.bounds) + + frame = canvas.create_frame() + frame.add_mesh(frustums) + frame.add_mesh(bounds) 
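+        # Use a fixed viewpoint for the overview frame before adding the per-camera image meshes.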
+ frame.camera = sp.Camera([0, 0, 10], aspect_ratio=width/height) + for mesh in image_meshes: + frame.add_mesh(mesh) + + sampling_mode = self.mode + for cam in idx: + bar.next() + camera = self.sampler.cameras[cam] + + self.mode = RayDataset.Mode.Sparse + index = set(self.index_for_camera(cam)) + self.mode = sampling_mode + index.intersection_update(self.index_for_camera(cam)) + self.mode = RayDataset.Mode.Full + cam_start = cam * self.sampler.rays_per_camera + index = [cam_start + i for i in index] + samples = self.get_rays(index) + render = self.render(samples) + + colors = render.color.unsqueeze(1).expand(-1, self.num_samples, -1) + positions = samples.positions.numpy().reshape(-1, 3) + colors = colors.numpy().copy().reshape(-1, 3) + + if render.alpha is not None: + alphas = render.alpha.unsqueeze(1) + alphas = alphas.expand(-1, self.num_samples) + alphas = alphas.reshape(-1) + empty = (alphas < 0.1).numpy() + else: + empty = np.zeros_like(colors[..., 0]) + + not_empty = np.logical_not(empty) + + samples = scene.create_mesh(layer_id="samples") + samples.add_sphere(sp.Colors.White, transform=sp.Transforms.scale(0.01)) + samples.enable_instancing(positions=positions[not_empty], + colors=colors[not_empty]) + + frame = canvas.create_frame() + + if empty.any(): + empty_samples = scene.create_mesh(layer_id="empty samples") + empty_samples.add_sphere(sp.Colors.Black, + transform=sp.Transforms.scale(0.01)) + empty_samples.enable_instancing(positions=positions[empty], + colors=colors[empty]) + frame.add_mesh(empty_samples) + + frame.camera = camera.to_scenepic() + frame.add_mesh(bounds) + frame.add_mesh(samples) + frame.add_mesh(frustums) + for mesh in image_meshes: + frame.add_mesh(mesh) + + self.mode = sampling_mode + + canvas.set_layer_settings({ + "bounds": {"opacity": 0.25}, + "images": {"opacity": 0.5} + }) + bar.finish() + + scene.framerate = 10 + return scene diff --git a/fourier_feature_nets/octree.py b/fourier_feature_nets/octree.py index c5fdb37..2bd33ac 100644 --- a/fourier_feature_nets/octree.py +++ b/fourier_feature_nets/octree.py @@ -9,7 +9,7 @@ from progress.bar import ChargingBar import trimesh -from .utils import ETABar, download_asset, interpolate_bilinear +from .utils import download_asset, ETABar, interpolate_bilinear Vector = NamedTuple("Vector", [("x", float), ("y", float), ("z", float)]) @@ -468,7 +468,7 @@ def _trace_ray_path(scale: float, node_index: np.ndarray, while _node_contains(current, point): # ...very paranoid about this failure case. # we NEED to leave the current leaf or the algorithm - # will never return. + # will never return. # TOD This would be safer/better/faster with integers. 
t += 1e-5 point = _cast_ray(ray, t) diff --git a/fourier_feature_nets/ray_caster.py b/fourier_feature_nets/ray_caster.py index 836ed6e..bfb3e92 100644 --- a/fourier_feature_nets/ray_caster.py +++ b/fourier_feature_nets/ray_caster.py @@ -1,9 +1,8 @@ """Module implementing a differentiable volumetric raycaster.""" import copy -import os import time -from typing import NamedTuple, OrderedDict +from typing import List, NamedTuple, OrderedDict import cv2 from matplotlib.pyplot import get_cmap @@ -19,35 +18,32 @@ Run = None print("Unable to import AzureML, running as local experiment") -from .ray_dataset import RayData, RayDataset +from .ray_dataset import RayDataset from .ray_sampler import RaySampler, RaySamples -from .utils import calculate_blend_weights, ETABar, exponential_lr_decay +from .utils import ( + calculate_blend_weights, + ETABar, + exponential_lr_decay, + RenderResult +) +from .visualizers import Visualizer - -RenderResult = NamedTuple("RenderResult", [("color", torch.Tensor), - ("alpha", torch.Tensor), - ("depth", torch.Tensor)]) LogEntry = NamedTuple("LogEntry", [("step", int), ("timestamp", float), ("state", OrderedDict[str, torch.Tensor]), ("train_psnr", float), ("val_psnr", float)]) -class Raycaster(nn.Module): +class Raycaster: """Implementation of a volumetric raycaster.""" - def __init__(self, model: nn.Module, alpha_weight=0.1): + def __init__(self, model: nn.Module): """Constructor. Args: model (nn.Module): The model used to predict color and opacity. - use_view (bool, optional): Whether to pass view information to - the model. Defaults to False. - alpha_weight (float, optional): weight for the alpha term of the - loss """ nn.Module.__init__(self) self.model = model - self._alpha_weight = alpha_weight def render(self, ray_samples: RaySamples, include_depth=False) -> RenderResult: @@ -74,6 +70,9 @@ def render(self, ray_samples: RaySamples, color = torch.sigmoid(color) opacity = F.softplus(opacity) + assert not color.isnan().any() + assert not opacity.isnan().any() + opacity = opacity.squeeze(-1) weights = calculate_blend_weights(ray_samples.t_values, opacity) @@ -93,19 +92,50 @@ def render(self, ray_samples: RaySamples, return RenderResult(output_color, output_alpha, output_depth) - def _loss(self, rays: RayData) -> torch.Tensor: + def _loss(self, step: int, dataset: RayDataset, batch: List[int]) -> torch.Tensor: device = next(self.model.parameters()).device + rays = dataset.get_rays(batch, step) rays = rays.to(device) - colors, alphas, _ = self.render(rays.samples) - color_loss = (colors - rays.colors).square().mean() - if rays.alphas is not None: - alpha_loss = (alphas - rays.alphas).square().mean() - else: - alpha_loss = 0 + render = self.render(rays, True) + return dataset.loss(step, rays, render) + + def batched_render(self, samples: RaySamples, + batch_size: int, include_depth: bool) -> RenderResult: + """Render a set of rays in batches. 
+ + Args: + samples (RaySamples): The ray samples to render + batch_size (int): Number of rays per batch + include_depth (bool): whether to include depth in the render + + Returns: + RenderResult: result of rendering all rays + """ + self.model.eval() + colors = [] + alphas = [] + depths = [] + with torch.no_grad(): + device = next(self.model.parameters()).device + num_rays = len(samples.positions) + for start in range(0, num_rays, batch_size): + end = min(start + batch_size, num_rays) + idx = list(range(start, end)) + batch = samples.subset(idx) + batch = batch.to(device) + pred = self.render(batch, include_depth).numpy() + colors.append(pred.color) + alphas.append(pred.alpha) + if include_depth: + depths.append(pred.depth) - loss = color_loss + self._alpha_weight * alpha_loss - return loss + self.model.train() + return RenderResult( + np.concatenate(colors), + np.concatenate(alphas), + np.concatenate(depths) if include_depth else None + ) def render_image(self, sampler: RaySampler, index: int, @@ -124,24 +154,9 @@ def render_image(self, sampler: RaySampler, np.ndarray: a (H,W,3) RGB image. """ camera = index % sampler.num_cameras - self.model.eval() - with torch.no_grad(): - device = next(self.model.parameters()).device - samples = sampler.rays_for_camera(camera) - num_rays = len(samples.positions) - predicted = [] - for start in range(0, num_rays, batch_size): - end = min(start + batch_size, num_rays) - idx = list(range(start, end)) - batch = samples.subset(idx) - batch = batch.to(device) - pred = self.render(batch, False) - pred_colors = pred.color.cpu().numpy() - predicted.append(pred_colors) - - self.model.train() - predicted = np.concatenate(predicted) - return sampler.to_image(camera, predicted, color_space) + samples = sampler.rays_for_camera(camera) + pred = self.batched_render(samples, batch_size, False) + return sampler.to_image(camera, pred.color, color_space) def render_activations(self, sampler: RaySampler, index: int, @@ -202,110 +217,6 @@ def render_activations(self, sampler: RaySampler, return act_pixels - def _render_act(self, sampler: RaySampler, - index: int, - color_space: str, - batch_size: int, - results_dir: str): - image = self.render_activations(sampler, index, batch_size, color_space) - act_dir = os.path.join(results_dir, "activations") - if not os.path.exists(act_dir): - os.makedirs(act_dir) - - path = os.path.join(act_dir, "frame_{:05d}.png".format(index)) - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - cv2.imwrite(path, image) - - def _render_video(self, sampler: RaySampler, - index: int, - color_space: str, - batch_size: int, - results_dir: str): - image = self.render_image(sampler, index, batch_size, color_space) - video_dir = os.path.join(results_dir, "video") - if not os.path.exists(video_dir): - os.makedirs(video_dir) - - path = os.path.join(video_dir, "frame_{:05d}.png".format(index)) - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - cv2.imwrite(path, image) - - def _render_eval_image(self, dataset: RayDataset, step: int, - batch_size: int, results_dir: str, - index: int): - camera = index % dataset.num_cameras - self.model.eval() - with torch.no_grad(): - device = next(self.model.parameters()).device - image_rays = dataset.rays_for_camera(camera) - num_rays = len(image_rays.samples.positions) - predicted = [] - actual = [] - depth = [] - error = [] - max_depth = 10 - for start in range(0, num_rays, batch_size): - end = min(start + batch_size, num_rays) - idx = list(range(start, end)) - batch_rays = image_rays.subset(idx) - batch_rays = 
batch_rays.to(device) - pred = self.render(batch_rays.samples, True) - pred_colors = pred.color.cpu().numpy() - act_colors = batch_rays.colors.cpu().numpy() - pred_error = np.square(act_colors - pred_colors).sum(-1) / 3 - if batch_rays.alphas is not None: - pred_alphas = pred.alpha.cpu().numpy() - act_alphas = batch_rays.alphas.cpu().numpy() - pred_error = 3 * pred_error - pred_error += np.square(act_alphas - pred_alphas) - pred_error /= 4 - - predicted.append(pred_colors) - actual.append(act_colors) - depth.append(pred.depth.clamp(0, max_depth).cpu().numpy()) - error.append(pred_error) - - self.model.train() - - cam_index = dataset.index_for_camera(camera) - - width, height = dataset.image_width, dataset.image_height - predicted = np.concatenate(predicted) - predicted_image = dataset.to_image(camera, np.clip(predicted, 0, 1)) - - actual_image = np.zeros((height*width, 3), np.float32) - actual_image[cam_index] = np.concatenate(actual) - actual_image = actual_image.reshape(height, width, 3) - actual_image = (actual_image * 255).astype(np.uint8) - - depth_image = np.zeros(height*width, np.float32) - depth_image[cam_index] = np.concatenate(depth) - depth_image = np.clip(depth_image, 0, max_depth) - depth_image = (depth_image / max_depth).reshape(height, width, 1) - depth_image = (depth_image * 255).astype(np.uint8) - - error_image = np.zeros(height*width, np.float32) - error_image[cam_index] = np.concatenate(error) - error_image = np.sqrt(error_image) - error_image = error_image / error_image.max() - error_image = error_image.reshape(height, width, 1) - error_image = (error_image * 255).astype(np.uint8) - - name = "s{:07}_c{:03}.png".format(step, camera) - image_dir = os.path.join(results_dir, dataset.label) - if not os.path.exists(image_dir): - os.makedirs(image_dir) - - image_path = os.path.join(image_dir, name) - - compare = np.zeros((height*2, width*2, 3), np.uint8) - compare[:height, :width] = predicted_image - compare[height:, :width] = actual_image - compare[:height, width:] = depth_image - compare[height:, width:] = error_image - compare = cv2.cvtColor(compare, cv2.COLOR_RGB2BGR) - cv2.imwrite(image_path, compare) - def _validate(self, dataset: RayDataset, batch_size: int, @@ -316,7 +227,7 @@ def _validate(self, if num_validate_rays < num_rays: val_index = np.linspace(0, num_rays, num_validate_rays, endpoint=False) val_index = val_index.astype(np.int32) - val_index = dataset.sampler.to_valid(val_index.tolist()) + val_index = dataset.to_valid(val_index.tolist()) else: val_index = np.arange(num_rays) @@ -327,8 +238,7 @@ def _validate(self, break batch = val_index[start:start + batch_size] - batch_rays = dataset.get_rays(batch, step) - loss.append(self._loss(batch_rays).item()) + loss.append(self._loss(step, dataset, batch).item()) self.model.train() loss = np.mean(loss) @@ -337,41 +247,31 @@ def _validate(self, def fit(self, train_dataset: RayDataset, val_dataset: RayDataset, - results_dir: str, batch_size: int, learning_rate: float, num_steps: int, - image_interval: int, crop_steps: int, report_interval: int, decay_rate: float, decay_steps: int, weight_decay: float, - video_sampler: RaySampler = None, - act_sampler: RaySampler = None, - disable_aml=False): + visualizers: List[Visualizer], + disable_aml=False) -> List[LogEntry]: """Fits the volumetric model using the raycaster. Args: train_dataset (RayDataset): The train dataset. val_dataset (RayDataset): The validation dataset. - results_dir (str): The output directory for results images. batch_size (int): The ray batch size. 
learning_rate (float): Initial learning rate for the model. num_steps (int): Number of steps (i.e. batches) to use for training. - image_interval (int): Number of steps between logging and images crop_steps (int): Number of steps to use center-cropped data at the beginning of training. report_interval (int): Frequency for progress reports decay_rate (float): Exponential decay term for the learning rate decay_steps (int): Number of steps over which the exponential decay is compounded. - video_sampler (RaySampler, optional): sampler used to create frames - for a training video. - Defaults to None. - act_sampler (RaySampler, optional): sampler used to create - activation images. - Defaults to None. + visualizers (List[Visualizer]): List of visualizer objects Returns: List[LogEntry]: logging information from the training run @@ -381,9 +281,6 @@ def fit(self, train_dataset: RayDataset, else: run = None - if results_dir and not os.path.exists(results_dir): - os.makedirs(results_dir) - trainval_dataset = train_dataset.sample_cameras(val_dataset.num_cameras, val_dataset.num_samples, False) @@ -394,7 +291,6 @@ def fit(self, train_dataset: RayDataset, start_time = time.time() log = [] epoch = 0 - render_index = 0 dataset_mode = train_dataset.mode if crop_steps: train_dataset.mode = RayDataset.Mode.Center @@ -404,6 +300,13 @@ def fit(self, train_dataset: RayDataset, val_dataset.mode = dataset_mode trainval_dataset.mode = dataset_mode + def render_image(samples: RaySamples, include_depth: bool): + return self.batched_render(samples, batch_size, include_depth) + + def render_act(sampler: RaySampler, camera: int): + return self.render_activations(sampler, camera, batch_size, + train_dataset.color_space) + while step <= num_steps: num_rays = len(train_dataset) index = np.arange(num_rays) @@ -417,10 +320,12 @@ def fit(self, train_dataset: RayDataset, decay_rate, decay_steps) end = min(start + batch_size, num_rays) batch = index[start:end].tolist() - batch_rays = train_dataset.get_rays(batch, step) + optim.zero_grad() - loss = self._loss(batch_rays) + loss = self._loss(step, train_dataset, batch) loss.backward() + torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1) optim.step() if step < 10 or step % report_interval == 0: @@ -464,30 +369,8 @@ def fit(self, train_dataset: RayDataset, step += 1 break - if results_dir and step % image_interval == 0: - if video_sampler or act_sampler: - if video_sampler: - self._render_video(video_sampler, - render_index, - train_dataset.color_space, - batch_size, - results_dir) - - if act_sampler: - self._render_act(act_sampler, - render_index, - train_dataset.color_space, - batch_size, - results_dir) - else: - self._render_eval_image(val_dataset, step, - batch_size, - results_dir, render_index) - self._render_eval_image(trainval_dataset, step, - batch_size, - results_dir, render_index) - - render_index += 1 + for visualizer in visualizers: + visualizer.visualize(step, render_image, render_act) step += 1 @@ -517,9 +400,9 @@ def to_scenepic(self, dataset: RayDataset, num_cameras=10, scene = sp.Scene() frustums = scene.create_mesh("frustums", layer_id="frustums") - height = 800 - width = dataset.image_width * height / dataset.image_height - canvas = scene.create_canvas_3d(width=width, height=height) + canvas_res = dataset.cameras[0].resolution.scale_to_height(800) + canvas = scene.create_canvas_3d(width=canvas_res.width, + height=canvas_res.height) canvas.shading = sp.Shading(sp.Colors.Gray) cmap = 
get_cmap("jet") @@ -543,21 +426,20 @@ def to_scenepic(self, dataset: RayDataset, num_cameras=10, bar.finish() - num_x_samples = resolution * dataset.image_width // dataset.image_height - num_y_samples = resolution - x_vals = np.linspace(0, dataset.image_width - 1, num_x_samples) + 0.5 - y_vals = np.linspace(0, dataset.image_height - 1, num_y_samples) + 0.5 + image_res = dataset.cameras[0].resolution + sample_res = image_res.scale_to_height(resolution) + x_vals = np.linspace(0, image_res.width - 1, sample_res.width) + 0.5 + y_vals = np.linspace(0, image_res.height - 1, sample_res.height) + 0.5 x_vals, y_vals = np.meshgrid(x_vals.astype(np.int32), y_vals.astype(np.int32)) - index = y_vals.reshape(-1) * dataset.image_width + x_vals.reshape(-1) + index = y_vals.reshape(-1) * image_res.width + x_vals.reshape(-1) dataset.subsample_index = set(index.tolist()) bar = ETABar("Sampling rays", max=dataset.num_cameras) device = next(self.model.parameters()).device for i, camera in enumerate(dataset.cameras): bar.next() - entry = dataset.rays_for_camera(i) - ray_samples = entry.samples + ray_samples = dataset.rays_for_camera(i) ray_samples = ray_samples.to(device) with torch.no_grad(): diff --git a/fourier_feature_nets/ray_dataset.py b/fourier_feature_nets/ray_dataset.py index a89d310..dc9be11 100644 --- a/fourier_feature_nets/ray_dataset.py +++ b/fourier_feature_nets/ray_dataset.py @@ -1,58 +1,21 @@ -"""Module providing dataset classes for use in training NeRF models.""" +"""Module providing a dataset prototype for use in training NeRF models.""" +from abc import ABC, abstractmethod from enum import Enum -import os -from typing import List, NamedTuple, Union +from typing import List, Set, Union import cv2 -import matplotlib.pyplot as plt import numpy as np import scenepic as sp import torch -import torch.nn as nn -from torch.utils.data import Dataset - -from .camera_info import CameraInfo, Resolution -from .ray_sampler import RaySampler, RaySamples -from .utils import download_asset, ETABar - - -class RayData(NamedTuple("RayData", [("samples", RaySamples), - ("colors", torch.Tensor), - ("alphas", torch.Tensor)])): - """Class representing ray data, with samples, colors and alpha values.""" - - def to(self, *args) -> "RayData": - """Calls torch.to on each tensor in the sample.""" - alphas = None if self.alphas is None else self.alphas.to(*args) - return RayData(self.samples.to(*args), - self.colors.to(*args), - alphas) - - def pin_memory(self) -> "RayData": - """Pins all tensors in preparation for movement to the GPU.""" - alphas = None if self.alphas is None else self.alphas.pin_memory() - return RayData(self.samples.pin_memory(), - self.colors.pin_memory(), - alphas) - - def subset(self, index: List[int]) -> "RayData": - """Selects a subset of the samples.""" - alphas = None if self.alphas is None else self.alphas[index] - return RayData(self.samples.subset(index), - self.colors[index], - alphas) - - def numpy(self) -> "RayData": - """Moves the tensors from pytorch to numpy.""" - alphas = None if self.alphas is None else self.alphas.cpu().numpy() - return RayData(self.samples.numpy(), - self.colors.cpu().numpy(), - alphas) - - -class RayDataset(Dataset): - """Dataset for sampling from rays cast into a volume.""" + +from .camera_info import CameraInfo +from .ray_sampler import RaySamples +from .utils import RenderResult + + +class RayDataset(ABC): + """Prototype for a dataset containing rays.""" class Mode(Enum): """The sampling mode of the dataset.""" @@ -68,190 +31,102 @@ class Mode(Enum): Dilate = 3 
"""Returns rays from a dilated region around the alpha mask.""" - def __init__(self, label: str, images: np.ndarray, bounds: np.ndarray, - cameras: List[CameraInfo], num_samples: int, - include_alpha=True, stratified=False, - opacity_model: nn.Module = None, - batch_size=4096, color_space="RGB", - sparse_size=50, anneal_start=0.2, - num_anneal_steps=0): - """Constructor. + Patch = 4 + """Returns rays forming distinct patches within the image.""" - Args: - label (str): Label used to identify this dataset. - images (np.ndarray): Images of the object from each camera - bounds (np.ndarray): Bounds of the render volume defined as a - transform matrix on the unit cube. - cameras (List[CameraInfo]): List of all cameras in the scene - num_samples (int): The number of samples to take per ray - include_alpha (bool): Whether to include alpha if present - stratified (bool, optional): Whether to use stratified random - sampling - opacity_model (nn.Module, optional): Optional model which predicts - opacity in the volume, used - for performing targeted - sampling if provided. Defaults - to None. - batch_size (int, optional): Batch size to use with the opacity - model. Defaults to 4096. - color_space (str, optional): The color space to use. Defaults to - "RGB". - sparse_size (int, optional): The vertical resolution of - the sparse sampling version. - anneal_start (float, optiona): Starting value for the sample space - annealing. Defaults to 0.2. - num_anneal_steps (int, optional): Steps over which to anneal - sampling to the full range of - volume intersection. Defaults - to 0. - """ - assert len(images.shape) == 4 - assert len(images) == len(cameras) - assert images.dtype == np.uint8 - - self.color_space = color_space - self._mode = RayDataset.Mode.Full - self.image_height, self.image_width = images.shape[1:3] - self.images = images - self.label = label - self.include_alpha = include_alpha - self.subsample_index = None - self.sampler = RaySampler(bounds, cameras, num_samples, stratified, - opacity_model, batch_size, anneal_start, - num_anneal_steps) - - source_resolution = np.array([self.image_width, self.image_height], - np.float32) - crop_start = source_resolution // 4 - crop_end = source_resolution - crop_start - x_vals = np.arange(self.image_width) - y_vals = np.arange(self.image_height) - points = np.stack(np.meshgrid(x_vals, y_vals), -1) - points = points.reshape(-1, 2) - - inside_crop = (points >= crop_start) & (points < crop_end) - inside_crop = inside_crop.all(-1) - crop_points = np.nonzero(inside_crop)[0] - crop_points = torch.from_numpy(crop_points) - self.crop_rays_per_camera = len(crop_points) - - sparse_points = torch.LongTensor(self._subsample_rays(sparse_size)) - sparse_height = sparse_size - sparse_width = sparse_size * self.image_width // self.image_height - self.sparse_resolution = sparse_width, sparse_height - self.sparse_rays_per_camera = len(sparse_points) - - stencil_radius = 8 * min(self.image_width, self.image_height) // 100 - size = 2 * stencil_radius + 1 - element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) - self.dilate_ranges = [] - num_dilate = 0 - - colors = [] - alphas = [] - crop_index = [] - sparse_index = [] - dilate_index = [] - bar = ETABar("Indexing", max=len(images)) - for image in images: - bar.next() - color = image[..., :3] - if color_space == "YCrCb": - color = cv2.cvtColor(color, cv2.COLOR_RGB2YCrCb) - - color = color.astype(np.float32) / 255 - color = color[self.sampler.points[:, 1], - self.sampler.points[:, 0]] - 
colors.append(torch.from_numpy(color)) - - offset = len(crop_index) * self.sampler.rays_per_camera - if image.shape[-1] == 4: - alpha = image[..., 3].astype(np.float32) / 255 - mask = (alpha > 0).astype(np.uint8) - - alpha = alpha[self.sampler.points[:, 1], - self.sampler.points[:, 0]] - alphas.append(torch.from_numpy(alpha)) - - mask = cv2.dilate(mask, element) - mask = mask[self.sampler.points[:, 1], - self.sampler.points[:, 0]] - dilate_points, = np.nonzero(mask) - dilate_index.append(torch.from_numpy(dilate_points) + offset) - start = num_dilate - end = start + len(dilate_points) - num_dilate = end - self.dilate_ranges.append((start, end)) - - crop_index.append(crop_points + offset) - sparse_index.append(sparse_points + offset) - - bar.finish() - self.crop_index = torch.cat(crop_index) - self.sparse_index = torch.cat(sparse_index) - self.dilate_index = torch.cat(dilate_index) - - if len(alphas) > 0 and include_alpha: - self.alphas = torch.cat(alphas) - else: - self.alphas = None + @property + @abstractmethod + def num_cameras(self) -> int: + """Number of cameras in the dataset.""" + + @property + @abstractmethod + def num_samples(self) -> int: + """Number of samples per ray in the dataset.""" + + @property + @abstractmethod + def color_space(self) -> str: + """Color space used by the dataset.""" - self.colors = torch.cat(colors) + @property + @abstractmethod + def label(self) -> str: + """A label for the dataset.""" @property - def resolution(self) -> Resolution: - """The resolution of the images.""" - return Resolution(self.image_width, self.image_height) + @abstractmethod + def cameras(self) -> List[CameraInfo]: + """Camera information.""" @property + @abstractmethod + def images(self) -> List[np.ndarray]: + """Dataset images.""" + + @property + @abstractmethod def mode(self) -> "RayDataset.Mode": - """The current sampling mode of the dataset.""" - return self._mode + """Sampling mode of the dataset.""" @mode.setter + @abstractmethod def mode(self, value: "RayDataset.Mode"): - if value == RayDataset.Mode.Dilate and len(self.dilate_index) == 0: - raise ValueError("Unable to use dilate mode: missing alpha channel") + """Sampling mode of the dataset.""" - self._mode = value + @abstractmethod + def rays_for_camera(self, camera: int) -> RaySamples: + """Returns ray samples for the specified camera.""" - def to_image(self, camera: int, colors: np.ndarray) -> np.ndarray: - """Creates an image given the camera and the compute pixel colors. + @property + @abstractmethod + def subsample_index(self) -> Set[int]: + """Set of pixel indices in an image to sample.""" + + @subsample_index.setter + @abstractmethod + def subsample_index(self, index: Set[int]): + """Set of pixel indices in an image to sample.""" + + @abstractmethod + def loss(self, step: int, rays: RaySamples, + render: RenderResult) -> torch.Tensor: + """Compute the dataset loss for the prediction. Args: - camera (int): The camera index. Needed to handle special patterns, - i.e. for Dilate mode. - colors (np.ndarray): The computed colors, one per ray, in the order - returned by the dataset. 
+ actual (RaySamples): The rays to render + predicted (RenderResult): The ray rendering result Returns: - np.ndarray: A (H,W,3) uint8 RGB tensor + torch.Tensor: a scalar loss tensor """ - pixels = np.zeros((self.image_height*self.image_width, 3), np.float32) - index = self.index_for_camera(camera) - pixels[index] = colors - pixels = pixels.reshape(self.image_height, self.image_width, 3) - pixels = (pixels * 255).astype(np.uint8) - if self.color_space == "YCrCb": - pixels = cv2.cvtColor(pixels, cv2.COLOR_YCrCB2RGB) - return pixels + @abstractmethod + def get_rays(self, + idx: Union[List[int], torch.Tensor], + step: int = None) -> RaySamples: + """Returns samples from the selected rays. - @property - def num_cameras(self) -> bool: - """Number of cameras in the dataset.""" - return self.sampler.num_cameras + Args: + idx (Union[List[int], torch.Tensor]): index into the dataset + step (int, optional): Step in optimization. Defaults to None. - @property - def num_samples(self) -> int: - """Number of samples per ray.""" - return self.sampler.num_samples + Returns: + (RaySamples): Returns ray data + """ - @property - def cameras(self) -> List[CameraInfo]: - """Camera information.""" - return self.sampler.cameras + @abstractmethod + def render(self, rays: RaySamples) -> RenderResult: + """Returns a (ground truth) render of the rays. + Args: + rays (RaySamples): the rays to render + + Returns: + RenderResult: the ground truth render + """ + + @abstractmethod def index_for_camera(self, camera: int) -> List[int]: """Returns a pixel index for the camera. @@ -265,87 +140,47 @@ def index_for_camera(self, camera: int) -> List[int]: Returns: List[int]: index into the rays for this camera """ - camera_start = camera * self.sampler.rays_per_camera - if self.mode == RayDataset.Mode.Center: - start = camera * self.crop_rays_per_camera - end = start + self.crop_rays_per_camera - idx = self.crop_index[start:end].tolist() - elif self.mode == RayDataset.Mode.Sparse: - start = camera * self.sparse_rays_per_camera - end = start + self.sparse_rays_per_camera - idx = self.sparse_index[start:end].tolist() - elif self.mode == RayDataset.Mode.Dilate: - start, end = self.dilate_ranges[camera] - idx = self.dilate_index[start:end].tolist() - elif self.mode == RayDataset.Mode.Full: - camera_end = camera_start + self.sampler.rays_per_camera - idx = list(range(camera_start, camera_end)) - else: - raise NotImplementedError("Unsupported sampling mode") - idx = self.sampler.to_valid(idx) - idx = [i - camera_start for i in idx] - return idx + @abstractmethod + def to_valid(self, idx: List[int]) -> List[int]: + """Filters the list of ray indices to include only valid rays. 
- def rays_for_camera(self, camera: int) -> RayData: - """Returns ray samples for the specified camera.""" - if self.mode == RayDataset.Mode.Center: - start = camera * self.crop_rays_per_camera - end = start + self.crop_rays_per_camera - elif self.mode == RayDataset.Mode.Sparse: - start = camera * self.sparse_rays_per_camera - end = start + self.sparse_rays_per_camera - elif self.mode == RayDataset.Mode.Dilate: - start, end = self.dilate_ranges[camera] - elif self.mode == RayDataset.Mode.Full: - start = camera * self.sampler.rays_per_camera - end = start + self.sampler.rays_per_camera - else: - raise NotImplementedError("Unsupported sampling mode") - - return self.get_rays(list(range(start, end)), None) - - def __len__(self) -> int: - """The number of rays in the dataset.""" - if self.mode == RayDataset.Mode.Center: - return len(self.crop_index) - - if self.mode == RayDataset.Mode.Sparse: - return len(self.sparse_index) - - if self.mode == RayDataset.Mode.Dilate: - return len(self.dilate_index) + Description: + In this context, a "valid" ray is one which intersects the bounding + volume. - if self.mode == RayDataset.Mode.Full: - return len(self.sampler) + Args: + idx (List[int]): An index of rays in the dataset. - raise NotImplementedError("Unsupported sampling mode") + Returns: + List[int]: a filtered list of valid rays + """ - def subset(self, cameras: List[int], - num_samples: int, - stratified: bool) -> "RayDataset": - """Returns a subset of this dataset (by camera). + def to_image(self, camera: int, colors: np.ndarray) -> np.ndarray: + """Creates an image given the camera and the compute pixel colors. Args: - cameras (List[int]): List of camera indices - num_samples (int): Number of samples per ray. - resolution (int): Ray sampling resolution - stratified (bool): Whether to use stratified sampling. - Defaults to False. + camera (int): The camera index. Needed to handle special patterns, + i.e. for Dilate mode. + colors (np.ndarray): The computed colors, one per ray, in the order + returned by the dataset. 
Returns: - RayDataset: New dataset with the subset of cameras + np.ndarray: A (H,W,3) uint8 RGB tensor """ - return RayDataset(self.label, - self.images[cameras], - self.sampler.bounds, - [self.sampler.cameras[i] for i in cameras], - num_samples, - self.include_alpha, - stratified, - self.sampler.opacity_model, - self.sampler.batch_size, - self.color_space) + if len(colors.shape) == 1: + colors = colors[..., np.newaxis] + + resolution = self.cameras[camera].resolution + pixels = np.zeros((resolution.width*resolution.height, 3), np.float32) + index = self.index_for_camera(camera) + pixels[index] = colors + pixels = pixels.reshape(resolution.height, resolution.width, 3) + pixels = (pixels * 255).astype(np.uint8) + if self._color_space == "YCrCb": + pixels = cv2.cvtColor(pixels, cv2.COLOR_YCrCB2RGB) + + return pixels def sample_cameras(self, num_cameras: int, num_samples: int, @@ -378,251 +213,30 @@ def sample_cameras(self, num_cameras: int, choice = unchosen[distances.argmax()] samples.add(choice) - return self.subset(list(samples), num_samples, stratified) - - def get_rays(self, - idx: Union[List[int], torch.Tensor], - step: int = None) -> RayData: - """Returns samples from the selected rays.""" - if torch.is_tensor(idx): - idx = idx.tolist() - - if self.mode == RayDataset.Mode.Center: - idx = self.crop_index[idx].tolist() - elif self.mode == RayDataset.Mode.Sparse: - idx = self.sparse_index[idx].tolist() - elif self.mode == RayDataset.Mode.Dilate: - idx = self.dilate_index[idx].tolist() - - if not isinstance(idx, list): - idx = [idx] - - if self.subsample_index: - idx = [i for i in idx - if i % self.sampler.rays_per_camera in self.subsample_index] - - idx = self.sampler.to_valid(idx) - samples = self.sampler.sample(idx, step) - colors = self.colors[idx] - if self.alphas is None or self.mode == RayDataset.Mode.Dilate: - alphas = None - else: - alphas = self.alphas[idx] - colors = torch.where(alphas.unsqueeze(1) > 0, colors, - torch.zeros_like(colors)) - - entry = RayData(samples, colors, alphas) - entry = entry.pin_memory() - return entry - - @staticmethod - def load(path: str, split: str, num_samples: int, - include_alpha: bool, stratified: bool, - opacity_model: nn.Module = None, - batch_size=4096, color_space="RGB", - sparse_size=50, anneal_start=0.2, - num_anneal_steps=0) -> "RayDataset": - """Loads a dataset from an NPZ file. + return self.subset(list(samples), num_samples, stratified, self.label) - Description: - The NPZ file should contain the following elements: + @abstractmethod + def __len__(self) -> int: + """The number of rays in the dataset.""" - images: a (NxRxRx[3,4]) tensor of images in RGB(A) format. - bounds: a (4x4) transform from the unit cube to a render volume - intrinsics: a (Nx3x3) tensor of camera intrinsics (projection) - extrinsics: a (Nx4x4) tensor of camera extrinsics (camera to world) - split_counts: a (3) tensor of counts per split in train, val, test - order + @abstractmethod + def subset(self, cameras: List[int], + num_samples: int, + stratified: bool, + label: str) -> "RayDataset": + """Returns a subset of this dataset (by camera). Args: - path (str): path to an NPZ file containing the dataset - split (str): the split to load [train, val, test] - num_samples (int): the number of samples per ray - include_alpha (bool): Whether to include alpha if present - stratified (bool): whether to use stratified sampling. - opacity_model (nn.Module, optional): model that predicts opacity - from 3D position. 
If the model - predicts more than one value, - the last channel is used. - Defaults to None. - batch_size (int, optional): Batch size to use when sampling the - opacity model. - sparse_size (int, optional): Resolution for sparse sampling. - anneal_start (float, optiona): Starting value for the sample space - annealing. Defaults to 0.2. - num_anneal_steps (int, optional): Steps over which to anneal - sampling to the full range of - volume intersection. Defaults - to 0. + cameras (List[int]): List of camera indices + num_samples (int): Number of samples per ray. + resolution (int): Ray sampling resolution + stratified (bool): Whether to use stratified sampling. + Defaults to False. Returns: - RayDataset: A dataset made from the camera and image data + RayDataset: New dataset with the subset of cameras """ - if not os.path.exists(path): - path = os.path.join(os.path.dirname(__file__), "..", "data", path) - path = os.path.abspath(path) - if not os.path.exists(path): - print("Downloading dataset...") - dataset_name = os.path.basename(path) - success = download_asset(dataset_name, path) - if not success: - print("Unable to download dataset", dataset_name) - return None - - data = np.load(path) - test_end, height, width = data["images"].shape[:3] - split_counts = data["split_counts"] - train_end = split_counts[0] - val_end = train_end + split_counts[1] - - if split == "train": - idx = list(range(train_end)) - elif split == "val": - idx = list(range(train_end, val_end)) - elif split == "test": - idx = list(range(val_end, test_end)) - else: - print("Unrecognized split:", split) - return None - - bounds = data["bounds"] - images = data["images"][idx] - intrinsics = data["intrinsics"][idx] - extrinsics = data["extrinsics"][idx] - - cameras = [CameraInfo.create("{}{:03}".format(split, i), - Resolution(width, height), - intr, extr) - for i, (intr, extr) in enumerate(zip(intrinsics, - extrinsics))] - return RayDataset(split, images, bounds, cameras, num_samples, - include_alpha, stratified, opacity_model, - batch_size, color_space, sparse_size, - anneal_start, num_anneal_steps) - - def _subsample_rays(self, resolution: int) -> List[int]: - num_x_samples = resolution * self.image_width // self.image_height - num_y_samples = resolution - x_vals = np.linspace(0, self.image_width - 1, num_x_samples) + 0.5 - y_vals = np.linspace(0, self.image_height - 1, num_y_samples) + 0.5 - x_vals, y_vals = np.meshgrid(x_vals.astype(np.int32), - y_vals.astype(np.int32)) - index = y_vals.reshape(-1) * self.image_width + x_vals.reshape(-1) - index = index.tolist() - return index + @abstractmethod def to_scenepic(self) -> sp.Scene: """Creates a ray sampling visualization ScenePic for the dataset.""" - scene = sp.Scene() - frustums = scene.create_mesh("frustums", layer_id="frustums") - height = 800 - width = height * self.image_width // self.image_height - canvas = scene.create_canvas_3d(width=width, - height=height) - canvas.shading = sp.Shading(sp.Colors.Gray) - - idx = np.arange(len(self.sampler.cameras)) - images = self.images - cameras = self.sampler.cameras - - cmap = plt.get_cmap("jet") - camera_colors = cmap(np.linspace(0, 1, len(cameras)))[:, :3] - image_meshes = [] - bar = ETABar("Plotting cameras", max=self.num_cameras) - thumb_height = 200 - thumb_width = thumb_height * self.image_width // self.image_height - for i, pixels, camera, color in zip(idx, images, - cameras, camera_colors): - bar.next() - camera = camera.to_scenepic() - - image = scene.create_image() - cam_index = self.index_for_camera(i) - pixels = 
(pixels / 255).astype(np.float32) - pixels = pixels[..., :3].reshape(-1, 3)[cam_index] - pixels = self.to_image(i, pixels) - pixels = cv2.resize(pixels, (thumb_width, thumb_height), - cv2.INTER_AREA) - image.from_numpy(pixels) - mesh = scene.create_mesh(layer_id="images", texture_id=image.image_id, - double_sided=True) - mesh.add_camera_image(camera, depth=0.5) - image_meshes.append(mesh) - - frustums.add_camera_frustum(camera, color, depth=0.5, thickness=0.01) - - bar.finish() - - bar = ETABar("Sampling Rays", max=self.num_cameras) - - bounds = scene.create_mesh("bounds", layer_id="bounds") - bounds.add_cube(sp.Colors.Blue, transform=self.sampler.bounds) - - frame = canvas.create_frame() - frame.add_mesh(frustums) - frame.add_mesh(bounds) - frame.camera = sp.Camera([0, 0, 10], aspect_ratio=width/height) - for mesh in image_meshes: - frame.add_mesh(mesh) - - sampling_mode = self.mode - for cam in idx: - bar.next() - camera = self.sampler.cameras[cam] - - self.mode = RayDataset.Mode.Sparse - index = set(self.index_for_camera(cam)) - self.mode = sampling_mode - index.intersection_update(self.index_for_camera(cam)) - self.mode = RayDataset.Mode.Full - cam_start = cam * self.sampler.rays_per_camera - index = [cam_start + i for i in index] - entry = self.get_rays(index) - - colors = entry.colors.unsqueeze(1) - colors = colors.expand(-1, self.num_samples, -1) - positions = entry.samples.positions.numpy().reshape(-1, 3) - colors = colors.numpy().copy().reshape(-1, 3) - - if entry.alphas is not None: - alphas = entry.alphas.unsqueeze(1) - alphas = alphas.expand(-1, self.num_samples) - alphas = alphas.reshape(-1) - empty = (alphas < 0.1).numpy() - else: - empty = np.zeros_like(colors[..., 0]) - - not_empty = np.logical_not(empty) - - samples = scene.create_mesh(layer_id="samples") - samples.add_sphere(sp.Colors.White, transform=sp.Transforms.scale(0.01)) - samples.enable_instancing(positions=positions[not_empty], - colors=colors[not_empty]) - - frame = canvas.create_frame() - - if empty.any(): - empty_samples = scene.create_mesh(layer_id="empty samples") - empty_samples.add_sphere(sp.Colors.Black, - transform=sp.Transforms.scale(0.01)) - empty_samples.enable_instancing(positions=positions[empty], - colors=colors[empty]) - frame.add_mesh(empty_samples) - - frame.camera = camera.to_scenepic() - frame.add_mesh(bounds) - frame.add_mesh(samples) - frame.add_mesh(frustums) - for mesh in image_meshes: - frame.add_mesh(mesh) - - self.mode = sampling_mode - - canvas.set_layer_settings({ - "bounds": {"opacity": 0.25}, - "images": {"opacity": 0.5} - }) - bar.finish() - - scene.framerate = 10 - return scene diff --git a/fourier_feature_nets/ray_sampler.py b/fourier_feature_nets/ray_sampler.py index 78110b5..b6c17c0 100644 --- a/fourier_feature_nets/ray_sampler.py +++ b/fourier_feature_nets/ray_sampler.py @@ -14,7 +14,8 @@ class RaySamples(NamedTuple("RaySamples", [("positions", torch.Tensor), ("view_directions", torch.Tensor), - ("t_values", torch.Tensor)])): + ("t_values", torch.Tensor), + ("rays", torch.Tensor)])): """Points samples from rays. Description: @@ -29,6 +30,7 @@ class RaySamples(NamedTuple("RaySamples", [("positions", torch.Tensor), positions: the 3D positions view_directions: the direction from each position back to the camera t_values: the t_values corresponding to the positions + rays: the ray indices Each tensor is grouped by ray, so the first two dimensions will be (num_rays,num_samples). 
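As a quick sketch of the layout described above (illustrative only; the tensor sizes here are arbitrary, and the only thing assumed is the `RaySamples` tuple from `fourier_feature_nets/ray_sampler.py` as extended in this patch, whose new `rays` field records which ray each group of samples came from):

```python
import torch

from fourier_feature_nets.ray_sampler import RaySamples

# Four rays with eight samples each: every tensor is grouped by ray,
# so the leading dimensions are (num_rays, num_samples).
num_rays, num_samples = 4, 8
positions = torch.zeros(num_rays, num_samples, 3)        # 3D sample positions
view_directions = torch.zeros(num_rays, num_samples, 3)  # directions back to the camera
t_values = torch.linspace(0.0, 1.0, num_samples).repeat(num_rays, 1)
rays = torch.arange(num_rays)                            # ray index per group of samples

samples = RaySamples(positions, view_directions, t_values, rays)
print(samples.positions.shape)  # torch.Size([4, 8, 3])
print(samples.rays.tolist())    # [0, 1, 2, 3]
```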
@@ -394,5 +396,8 @@ def sample(self, idx: Union[List[int], torch.Tensor], directions = directions.repeat(1, self.num_samples, 1) positions = starts + t_values.unsqueeze(-1) * directions - ray_samples = RaySamples(positions, directions, t_values) + if not isinstance(idx, torch.Tensor): + idx = torch.LongTensor(idx) + + ray_samples = RaySamples(positions, directions, t_values, idx) return ray_samples diff --git a/fourier_feature_nets/utils.py b/fourier_feature_nets/utils.py index b895159..2175fb4 100644 --- a/fourier_feature_nets/utils.py +++ b/fourier_feature_nets/utils.py @@ -1,8 +1,9 @@ """Utility module.""" import base64 +import math import os -from typing import List +from typing import List, NamedTuple import numpy as np from progress.bar import Bar @@ -302,6 +303,122 @@ def orbit(up_dir: np.ndarray, forward_dir: np.ndarray, num_frames: int, return camera_info +def shuffle_positions(positions: np.ndarray, random=True) -> List[int]: + """Shuffles a list of positions. + + Description: + Shuffles the positions in order such that each subsequent position + is likely to be far from the preceding positions. + + Args: + positions (np.ndarray): the positions in space + random (bool, optional): whether the positions should be chosen at + random (otherwise the farthest away is + always chosen). Defaults to True. + + Returns: + List[int]: a shuffling of the positions + """ + samples = [0] + all_positions = set(range(len(positions))) + while len(samples) < len(all_positions): + sample_positions = positions[samples] + distances = positions[:, None, :] - sample_positions[None, :, :] + distances = np.square(distances).sum(-1).min(-1) + unchosen = np.array(list(all_positions - set(samples))) + if random: + weights = np.array(distances[unchosen], np.float32) + weights = weights / weights.sum() + choice = np.random.choice(unchosen, p=weights) + else: + distances = distances[unchosen] + choice = unchosen[distances.argmax()] + + samples.append(choice) + + return list(samples) + + +def fibonacci_hemisphere(num_samples: int) -> np.ndarray: + """Computes points on a unit hemisphere using the Fibonacci method. + + Args: + num_samples (int): Number of samples from the hemisphere + + Returns: + np.ndarray: a (N,3) tensor of positions on the unit hemisphere + """ + points = [] + phi = math.pi * (3. - math.sqrt(5.)) # golden angle in radians + + for i in range(num_samples): + y = 1 - (i / float(num_samples - 1)) # y goes from 1 to 0 + radius = math.sqrt(1 - y * y) # radius at y + + theta = phi * i # golden angle increment + + x = math.cos(theta) * radius + z = math.sin(theta) * radius + + points.append((x, y, z)) + + points = np.stack(points) + index = shuffle_positions(points) + return points[index] + + +def hemisphere(up_dir: np.ndarray, forward_dir: np.ndarray, num_cameras: int, + fov_y_degrees: float, resolution: Resolution, + distance: float, pos_noise=0.1) -> List[CameraInfo]: + """Generates a random set of evenly placed cameras in a hemisphere. + + Args: + up_dir (np.ndarray): unit vector indicating the "up" direction + forward_dir (np.ndarray): unit vector indicating the "forward" direction + num_cameras (int): number of cameras to sample + fov_y_degrees (float): the y-axis field of view (in degrees) + resolution (Resolution): the resolution of the cameras + distance (float): the mean distance of cameras to the origin + pos_noise (float, optional): the positional noise for cameras. + Defaults to 0.1. 
+ + Returns: + List[CameraInfo]: the sampled cameras + """ + directions = fibonacci_hemisphere(num_cameras) + right_dir = np.cross(up_dir, forward_dir) + + fov_y = fov_y_degrees * np.pi / 180 + focal_length = .5 * resolution.width / np.tan(.5 * fov_y) + + intrinsics = np.array([ + focal_length, 0, resolution.width / 2, + 0, focal_length, resolution.height / 2, + 0, 0, 1 + ], np.float32).reshape(3, 3) + + camera_info = [] + for direction in directions: + position = direction * distance + position += np.random.normal(0, pos_noise, size=3) + distance = np.linalg.norm(position) + azimuth = math.atan2(direction[0], direction[2]) + altitude = math.asin(direction[1]) + pos = sp.Transforms.translate([0, 0, -distance]) + elevate = sp.Transforms.rotation_matrix_from_axis_angle(right_dir, + altitude) + rotate = sp.Transforms.rotation_matrix_from_axis_angle(up_dir, + azimuth) + + extrinsics = rotate @ elevate @ pos + camera = CameraInfo.create("cam{}".format(len(camera_info)), + resolution, + intrinsics, extrinsics) + camera_info.append(camera) + + return camera_info + + def exponential_lr_decay(optim: torch.optim.Adam, initial_learning_rate: float, step: int, decay_rate: float, @@ -384,3 +501,28 @@ def load_model(path: str) -> torch.nn.Module: model.load_state_dict(state_dict) model.eval() return model + + +class RenderResult(NamedTuple("RenderResult", [("color", torch.Tensor), + ("alpha", torch.Tensor), + ("depth", torch.Tensor)])): + """The result of a rendering operation. + + Description: + Contains color, alpha, and (optionally) depth values which have been + rendered on a per-ray basis. + """ + @property + def device(self) -> torch.device: + """The device on which the tensors are stored.""" + return self.color.device + + def to(self, *args) -> "RenderResult": + """Calls torch.to on each tensor in the sample.""" + return RenderResult(*[None if tensor is None else tensor.to(*args) + for tensor in self]) + + def numpy(self) -> "RenderResult": + """Moves all of the tensors from pytorch to numpy.""" + return RenderResult(*[None if tensor is None else tensor.cpu().numpy() + for tensor in self]) diff --git a/fourier_feature_nets/visualizers.py b/fourier_feature_nets/visualizers.py new file mode 100644 index 0000000..582b3d3 --- /dev/null +++ b/fourier_feature_nets/visualizers.py @@ -0,0 +1,263 @@ +"""Module defining various visualizers that can be used during training.""" + +from abc import ABC, abstractmethod +import os +from typing import Callable + +import cv2 +import numpy as np + +from .camera_info import Resolution +from .image_dataset import ImageDataset +from .ray_sampler import RaySampler, RaySamples +from .utils import orbit, RenderResult + +ImageRender = Callable[[RaySamples, bool], RenderResult] +ActivationRender = Callable[[RaySampler, int], np.ndarray] + + +class Visualizer(ABC): + """A visualizer can hook into the training process to produce artifacts.""" + @abstractmethod + def visualize(self, step: int, render: ImageRender, + act_render: ActivationRender): + """Create a visualization using the provided render functions. + + Args: + step (int): Step in the optimization + render (ImageRender): Render function in image space + act_render (ActivationRender): Render function for the activations + """ + + +class EvaluationVisualizer(Visualizer): + """Produces image grids showing GT, prediction, depth, and error.""" + + def __init__(self, results_dir: str, dataset: ImageDataset, interval: int, + max_depth=10): + """Constructor. + + Args: + results_dir (str): the base results directory. 
+ dataset (ImageDataset): the dataset to use as reference. + interval (int): the number of steps between images. + max_depth (int, optional): Value used to clip the depth. + Defaults to 10. + """ + path = os.path.join(results_dir, dataset.label) + os.makedirs(path, exist_ok=True) + self._output_dir = path + self._dataset = dataset + self._interval = interval + self._index = 0 + self._max_depth = max_depth + + def visualize(self, step: int, render: ImageRender, + _: ActivationRender): + """Create a visualization using the provided render functions. + + Args: + step (int): Step in the optimization + render (ImageRender): Render function in image space + act_render (ActivationRender): Render function for the activations + """ + if step % self._interval != 0: + return + + camera = self._index % self._dataset.num_cameras + samples = self._dataset.rays_for_camera(camera) + act = self._dataset.render(samples).numpy() + pred = render(samples, True) + + error = np.square(act.color - pred.color).sum(-1) + if act.alpha is not None: + error = 3 * error + error += np.square(act.alpha - pred.alpha) + error = error / 4 + + width, height = self._dataset.cameras[camera].resolution + predicted = np.clip(pred.color, 0, 1) + predicted_image = self._dataset.to_image(camera, predicted) + + color = act.color * act.alpha[..., np.newaxis] + actual_image = self._dataset.to_image(camera, color) + + depth = np.clip(pred.depth, 0, self._max_depth) / self._max_depth + depth_image = self._dataset.to_image(camera, depth) + + error = np.sqrt(error) + error = error / error.max() + error_image = self._dataset.to_image(camera, error) + + name = "s{:07}_c{:03}.png".format(step, camera) + image_path = os.path.join(self._output_dir, name) + + compare = np.zeros((height*2, width*2, 3), np.uint8) + compare[:height, :width] = predicted_image + compare[height:, :width] = actual_image + compare[:height, width:] = depth_image + compare[height:, width:] = error_image + compare = cv2.cvtColor(compare, cv2.COLOR_RGB2BGR) + cv2.imwrite(image_path, compare) + self._index += 1 + + +class OrbitVideoVisualizer(Visualizer): + """Produces a video where the camera orbits the render volume.""" + + def __init__(self, results_dir: str, num_steps: int, + resolution: Resolution, num_frames: int, + num_samples: int, color_space: str): + """Constructor. + + Args: + results_dir (str): the base results directory + num_steps (int): the number of steps in the training sequence + resolution (Resolution): the resolution of the video + num_frames (int): number of frames in the video + num_samples (int): number of samples per ray + color_space (str): the color space (RGB or YCrCb) + """ + video_dir = os.path.join(results_dir, "video") + os.makedirs(video_dir, exist_ok=True) + self._output_dir = video_dir + cameras = orbit(np.array([0, 1, 0]), np.array([0, 0, -1]), + num_frames, 40, resolution.square(), 4) + bounds = np.eye(4, dtype=np.float32) * 2 + self._sampler = RaySampler(bounds, cameras, num_samples) + self._interval = num_steps // num_frames + self._index = 0 + self._color_space = color_space + + def visualize(self, step: int, render: ImageRender, + _: ActivationRender): + """Create a visualization using the provided render functions. 
+ + Args: + step (int): Step in the optimization + render (ImageRender): Render function in image space + act_render (ActivationRender): Render function for the activations + """ + if step % self._interval != 0: + return + + camera = self._index % self._sampler.num_cameras + samples = self._sampler.rays_for_camera(camera) + pred = render(samples, False) + image = self._sampler.to_image(camera, pred.color, self._color_space) + name = "frame_{:05d}.png".format(self._index) + path = os.path.join(self._output_dir, name) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imwrite(path, image) + self._index += 1 + + +class ActivationVisualizer(Visualizer): + """Creates a video of the layer activations during training.""" + + def __init__(self, results_dir: str, num_steps: int, + resolution: Resolution, num_frames: int, + num_samples: int, color_space: str): + """Constructor. + + Args: + results_dir (str): the base results directory + num_steps (int): the number of steps in the training sequence + resolution (Resolution): the resolution of the video + num_frames (int): number of frames in the video + num_samples (int): number of samples per ray + color_space (str): the color space (RGB or YCrCb) + """ + act_dir = os.path.join(results_dir, "activations") + os.makedirs(act_dir, exist_ok=True) + self._output_dir = act_dir + cameras = orbit(np.array([0, 1, 0]), np.array([0, 0, -1]), + num_frames, 40, resolution.square(), 4) + bounds = np.eye(4, dtype=np.float32) * 2 + self._sampler = RaySampler(bounds, cameras, num_samples) + self._interval = num_steps // num_frames + self._index = 0 + self._color_space = color_space + + def visualize(self, step: int, _: ImageRender, + act_render: ActivationRender): + """Create a visualization using the provided render functions. + + Args: + step (int): Step in the optimization + render (ImageRender): Render function in image space + act_render (ActivationRender): Render function for the activations + """ + if step % self._interval != 0: + return + + image = act_render(self._sampler, self._index) + name = "frame_{:05d}.png".format(self._index) + path = os.path.join(self._output_dir, name) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imwrite(path, image) + self._index += 1 + + +class ComparisonVisualizer(Visualizer): + """This visualizer compares training and validation renders.""" + + def __init__(self, results_dir: str, num_steps: int, + num_frames: int, + train: ImageDataset, val: ImageDataset): + """Constructor. + + Args: + results_dir (str): the base results directory + num_steps (int): the number of steps in the training sequence + num_frames (int): number of frames in the video + train (ImageDataset): training data + val (ImageDataset): validation data + """ + compare_dir = os.path.join(results_dir, "compare") + os.makedirs(compare_dir, exist_ok=True) + assert train.num_cameras == val.num_cameras + self._output_dir = compare_dir + self._train = train + self._val = val + self._interval = num_steps // num_frames + self._index = 0 + + def visualize(self, step: int, render: ImageRender, + _: ActivationRender): + """Create a visualization using the provided render functions. 
+ + Args: + step (int): Step in the optimization + render (ImageRender): Render function in image space + act_render (ActivationRender): Render function for the activations + """ + if step % self._interval != 0: + return + + num_cameras = self._train.num_cameras + resolution = self._train.cameras[0].resolution + width = resolution.width * 4 + height = resolution.height * num_cameras + frame = np.zeros((height, width, 3), np.uint8) + c = [i * resolution.width for i in range(5)] + for camera in range(num_cameras): + r0 = camera * resolution.height + r1 = r0 + resolution.height + samples = self._train.rays_for_camera(camera) + act = self._train.render(samples) + pred = render(samples, False) + frame[r0:r1, c[0]:c[1]] = self._train.to_image(camera, act.color) + frame[r0:r1, c[1]:c[2]] = self._train.to_image(camera, pred.color) + + samples = self._val.rays_for_camera(camera) + act = self._val.render(samples) + pred = render(samples, False) + frame[r0:r1, c[2]:c[3]] = self._val.to_image(camera, act.color) + frame[r0:r1, c[3]:c[4]] = self._val.to_image(camera, pred.color) + + name = "frame_{:05d}.png".format(self._index) + path = os.path.join(self._output_dir, name) + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + cv2.imwrite(path, frame) + self._index += 1 diff --git a/fourier_feature_nets/voxels_model.py b/fourier_feature_nets/voxels_model.py index cc3e800..87ddf9a 100644 --- a/fourier_feature_nets/voxels_model.py +++ b/fourier_feature_nets/voxels_model.py @@ -41,6 +41,7 @@ def forward(self, positions: torch.Tensor) -> torch.Tensor: output = output.transpose(1, 2) output = output.reshape(-1, 4) output = output + self.bias + assert not output.isnan().any() return output def save(self, path: str): diff --git a/lecture_notes.ipynb b/lecture_notes.ipynb index 2c59571..d1e6526 100644 --- a/lecture_notes.ipynb +++ b/lecture_notes.ipynb @@ -12,6 +12,7 @@ "import csv\n", "import os\n", "\n", + "from IPython.display import Video\n", "import cv2\n", "import fourier_feature_nets as ffn\n", "import numpy as np\n", @@ -32,9 +33,7 @@ "metadata": {}, "source": [ "# Fourier Feature Networks and Neural Volume Rendering\n", - "## Lecture Notes\n", - "### Engineering Tripos Part IIB, 4F12: Computer Vision\n", - "### Lecturer: Matthew Johnson\n", + "### Matthew Johnson\n", "\n", "Welcome to this notebook, which is intended to function as interactive lecture notes for my lecture. In it you should be able to run the same experiments, and produce the same visualizations (provided you have access to the necessary compute requirements, i.e. a relatively recent GPU or equivalent cloud resource)." ] @@ -472,6 +471,8 @@ "train_image = image[::2, ::2]\n", "train_ax.imshow(image[::2, ::2], interpolation=None)\n", "train_ax.set_title(\"Training\")\n", + "train_ax.set_xlim(0, 512)\n", + "train_ax.set_ylim(512, 0)\n", "val_ax.imshow(image)\n", "val_ax.set_title(\"Validation\")\n", "fig.tight_layout()\n", @@ -483,7 +484,7 @@ "id": "44f6d369", "metadata": {}, "source": [ - "If it has learned a mapping from $uv$ space to RGB, then it should be able to fill in the missing pixels. Let's see how we do with a standard MLP:" + "If it has learned a mapping from $uv$ space to RGB, then it should be able to upscale and fill in the missing pixels. 
Let's see how we do with a standard MLP:" ] }, { @@ -708,8 +709,8 @@ "num_samples = 64\n", "paths = [\"antinous_400.npz\", \"benin_400.npz\", \"lego_400.npz\", \"trex_400.npz\", \"matthew_400.npz\", \"rubik_400.npz\"]\n", "choice = 0\n", - "train_dataset = ffn.RayDataset.load(paths[choice], \"train\", num_samples, include_alpha=True, stratified=True)\n", - "val_dataset = ffn.RayDataset.load(paths[choice], \"val\", num_samples, include_alpha=True, stratified=False)" + "train_dataset = ffn.ImageDataset.load(paths[choice], \"train\", num_samples, include_alpha=True, stratified=True)\n", + "val_dataset = ffn.ImageDataset.load(paths[choice], \"val\", num_samples, include_alpha=True, stratified=False)\n" ] }, { @@ -727,7 +728,8 @@ "metadata": {}, "outputs": [], "source": [ - "val_dataset.to_scenepic()" + "demo_dataset = ffn.ImageDataset.load(paths[choice], \"val\", 16, include_alpha=True, stratified=False)\n", + "demo_dataset.to_scenepic()" ] }, { @@ -737,11 +739,6 @@ "source": [ "You can move the camera around with the mouse button. Press R to reset the view. If you press or you can move from camera to camera to see what it sees. For the rest of the controls, see the help message below. You can use the menu in the top right to turn things on and off, for example the blue cube that indicates the bounds of the render volume, the images associated with each camera, or the \"empty\" samples (i.e. those associated with an alpha value of 0).\n", "\n", - "
\n", - "Due to a bug in Jupyter, these animations can sometimes cause issues with GPU usage. If the notebook is getting laggy\n", - "simply choose \"Restart Kernel and Clear Output\" from the \"Kernel\" menu and then close your browser.\n", - "
\n", - "\n", "How do we position these samples within the space like this? The answer comes from a technique called *ray casting*. Let's begin by reviewing some concepts from the course so far." ] }, @@ -752,12 +749,7 @@ "metadata": {}, "outputs": [], "source": [ - "voxels = ffn.OcTree.load(\"antinous_octree_10.npz\")\n", - "\n", - "scene = sp.Scene()\n", - "world_to_camera(scene, voxels, train_dataset.cameras[6], train_dataset.images[6])\n", - "camera_to_world(scene, voxels, train_dataset.cameras[6], train_dataset.images[6])\n", - "scene" + "display(Video(\"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3YvcyFBbld2SzJiNTFuR3FsdkYydkhOSmtTcjJBdWk5eXc_ZT02MEZUSHM/root/content\", width=600, height=300))" ] }, { @@ -882,6 +874,19 @@ "where each $i$ value corresponds to a discrete $t_i$ sample. Here is a graph of what that looks like:" ] }, + { + "cell_type": "markdown", + "id": "0e0b5786", + "metadata": {}, + "source": [ + "Break this down a LOT more:\n", + "\n", + "1. Show how $t_n$ and $t_f$ are computed (i.e. voxel intersection)\n", + "2. Graph showing $T(t)$. Get a patch and show the pixel and the ray's path through the model with a scenepic, and then show the graph below. Color the samples. \n", + "3. \n", + "\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -921,7 +926,7 @@ "source": [ "The two peaks in the blue line (the $\\sigma$ value) indicate that this ray passes through two objects. However, we can see that the orange line (which tracks the $T$ value) starts at 1 as it passes through empty space, but then decreases to 0 as we reach the middle of the first peak. As a result, the colors from the second object will not contribute much of anything to the final pixel color. If the first peak were lower (indicating that the object was less opaque) then $T$ would decrease more slowly, allowing color from the object behind to contribute more.\n", "\n", - "Let's look at this process a different way. The follow animation will probably take a bit of time to build, but once it is done feel free to watch the full animation and then pause it and explore the space at any point. Remember that you can always press the R key to return to the preset camera position." + "Let's look at this process a different way." 
] }, { @@ -931,9 +936,7 @@ "metadata": {}, "outputs": [], "source": [ - "antinous_dataset = ffn.RayDataset.load(\"antinous_400.npz\", \"train\", num_samples, include_alpha=True, stratified=True)\n", - "anim = VolumeRaycastingAnimation(antinous_dataset, voxels, 1000)\n", - "anim.scene" + "display(Video(\"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3YvcyFBbld2SzJiNTFuR3FsdUo2RUVaTmpibU5KTFl0WXc_ZT02N2M1Wjk/root/content\", width=640, height=360))" ] }, { @@ -957,7 +960,7 @@ "metadata": {}, "outputs": [], "source": [ - "voxels_animation(voxels)" + "display(Video(\"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3YvcyFBbld2SzJiNTFuR3FsdkYzWnFGYUtNX2hqWngwWmc_ZT00eHhUeGM/root/content\", width=400, height=400))" ] }, { @@ -978,7 +981,7 @@ "weight_decay = 0.0\n", "num_layers = 3\n", "image_interval = 0\n", - "report_interval = 200\n", + "report_interval = 1000\n", "num_channels = 64\n", "crop_steps = 0\n", "\n", @@ -988,7 +991,6 @@ "model = model.to(device)\n", "\n", "raycaster = ffn.Raycaster(model)\n", - "raycaster = raycaster.to(device)\n", "\n", "def _render_log(log, resolution=ffn.Resolution(200, 200), fov=40, distance=4):\n", " cameras = ffn.orbit(np.array([0, 1, 0], np.float32),\n", @@ -1065,8 +1067,8 @@ "metadata": {}, "outputs": [], "source": [ - "raycaster.fit(train_dataset, val_dataset, None, batch_size, learning_rate, num_steps, image_interval,\n", - " crop_steps, report_interval, decay_rate, decay_steps, weight_decay, disable_aml=True)\n", + "raycaster.fit(train_dataset, val_dataset, batch_size, learning_rate, num_steps,\n", + " crop_steps, report_interval, decay_rate, decay_steps, weight_decay, [], disable_aml=True)\n", "\n", "voxel_renders = _render_orbit()" ] @@ -1122,8 +1124,8 @@ "num_samples = 64\n", "paths = [\"antinous_400.npz\", \"benin_400.npz\", \"lego_400.npz\", \"trex_400.npz\", \"matthew_400.npz\", \"rubik_400.npz\"]\n", "choice = 0\n", - "train_dataset = ffn.RayDataset.load(paths[choice], \"train\", num_samples, include_alpha=True, stratified=True)\n", - "val_dataset = ffn.RayDataset.load(paths[choice], \"val\", num_samples, include_alpha=True, stratified=False)" + "train_dataset = ffn.ImageDataset.load(paths[choice], \"train\", num_samples, include_alpha=True, stratified=True)\n", + "val_dataset = ffn.ImageDataset.load(paths[choice], \"val\", num_samples, include_alpha=True, stratified=False)" ] }, { @@ -1137,8 +1139,8 @@ "model = ffn.MLP(3, 4, num_channels=64)\n", "raycaster.model = model.to(device)\n", "\n", - "log = raycaster.fit(train_dataset, val_dataset, None, batch_size, learning_rate, num_steps, image_interval,\n", - " crop_steps, report_interval, decay_rate, decay_steps, weight_decay, disable_aml=True)\n", + "log = raycaster.fit(train_dataset, val_dataset, batch_size, learning_rate, num_steps,\n", + " crop_steps, report_interval, decay_rate, decay_steps, weight_decay, [], disable_aml=True)\n", "\n", "mlp_renders = _render_log(log)" ] @@ -1178,8 +1180,8 @@ "\n", "model = ffn.PositionalFourierMLP(3, 4, max_log_scale, num_channels=64)\n", "raycaster.model = model.to(device)\n", - "log = raycaster.fit(train_dataset, val_dataset, None, batch_size, learning_rate, num_steps, image_interval,\n", - " crop_steps, report_interval, decay_rate, decay_steps, weight_decay, disable_aml=True)\n", + "log = raycaster.fit(train_dataset, val_dataset, batch_size, learning_rate, num_steps,\n", + " crop_steps, report_interval, decay_rate, decay_steps, weight_decay, [], disable_aml=True)\n", "\n", "pos_renders = _render_log(log)" ] @@ -1249,7 +1251,7 @@ 
"outputs": [], "source": [ "opacity_model = ffn.load_model(\"antinous_800_vox128.pt\")\n", - "focus_dataset = ffn.RayDataset.load(\"antinous_400.npz\", \"val\", num_samples, include_alpha=True,\n", + "focus_dataset = ffn.ImageDataset.load(\"antinous_400.npz\", \"val\", num_samples, include_alpha=True,\n", " stratified=True, opacity_model=opacity_model)\n", "focus_dataset.to_scenepic()" ] diff --git a/orbit_video.py b/orbit_video.py index b293535..24007f8 100644 --- a/orbit_video.py +++ b/orbit_video.py @@ -72,7 +72,7 @@ def _main(): else: opacity_model = model - raycaster = ffn.Raycaster(model, isinstance(model, ffn.NeRF)) + raycaster = ffn.Raycaster(model) sampler = ffn.RaySampler(bounds_transform, orbit_cameras, args.num_samples, False, opacity_model, args.batch_size) @@ -80,15 +80,18 @@ def _main(): if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) + progress = ffn.ETABar("Rendering", max=args.num_frames) with torch.no_grad(): for frame in range(args.num_frames): - print(frame, "/", args.num_frames) + progress.next() image = raycaster.render_image(sampler, frame, args.batch_size) path = os.path.join(args.output_dir, - "frame_{:04d}.png".format(frame)) + "frame_{:05d}.png".format(frame)) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) cv2.imwrite(path, image) + progress.finish() + if __name__ == "__main__": _main() diff --git a/submit_aml_run.py b/submit_aml_run.py index 11c6e59..048c810 100644 --- a/submit_aml_run.py +++ b/submit_aml_run.py @@ -27,7 +27,6 @@ def _main(): experiment = Experiment(workspace=ws, name=args.name) env_path = os.path.join("azureml", "aml_env.yml") environment = Environment.from_conda_specification("training", env_path) - environment.environment_variables["AZUREML_COMPUTE_USE_COMMON_RUNTIME"] = "false" config = ScriptRunConfig(source_directory=".", script=args.script_path, arguments=args.script_args.split(), diff --git a/test_ray_sampling.py b/test_ray_sampling.py index 82ca55b..862e9c9 100644 --- a/test_ray_sampling.py +++ b/test_ray_sampling.py @@ -2,7 +2,7 @@ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser -from fourier_feature_nets import load_model, RayDataset +from fourier_feature_nets import load_model, ImageDataset, RayDataset def _parse_args(): @@ -44,10 +44,10 @@ def _main(): else: model = None - dataset = RayDataset.load(args.data_path, args.split, - args.num_samples, True, - args.stratified, model, - args.batch_size, sparse_size=args.resolution) + dataset = ImageDataset.load(args.data_path, args.split, + args.num_samples, True, + args.stratified, model, + args.batch_size, sparse_size=args.resolution) if dataset is None: return 1 diff --git a/train_image_regression.py b/train_image_regression.py index 85f362f..b85154d 100644 --- a/train_image_regression.py +++ b/train_image_regression.py @@ -131,7 +131,14 @@ def _main(): if step % args.report_interval == 0 or step == args.num_steps: with torch.no_grad(): model.eval() - output = torch.sigmoid(model(dataset.val_uv)) + batch_rows = args.image_size // 4 + output = [] + for i in range(4): + start = i * batch_rows + end = start + batch_rows + output.append(model(dataset.val_uv[start:end])) + + output = torch.sigmoid(torch.cat(output)) psnr_val = dataset.psnr(output) print("step", step, "val:", psnr_val, "lr:", optim.param_groups[0]["lr"]) diff --git a/train_nerf.py b/train_nerf.py index 27daed4..187ffb2 100644 --- a/train_nerf.py +++ b/train_nerf.py @@ -5,7 +5,6 @@ import os import fourier_feature_nets as ffn -import numpy as np import torch @@ -67,7 +66,7 @@ def 
_parse_args(): help="Pytorch compute device") parser.add_argument("--anneal-start", type=float, default=0.2, help="Starting value for the sample space annealing.") - parser.add_argument("--num-anneal-steps", type=int, default=0, + parser.add_argument("--num-anneal-steps", type=int, default=2000, help=("Steps over which to anneal sampling to the full" "range of volume intersection.")) @@ -93,43 +92,53 @@ def _main(): opacity_model = None include_alpha = args.mode == "rgba" - train_dataset = ffn.RayDataset.load(args.data_path, "train", + train_dataset = ffn.ImageDataset.load(args.data_path, "train", + args.num_samples, include_alpha, + True, opacity_model, + args.batch_size, args.color_space, + anneal_start=args.anneal_start, + num_anneal_steps=args.num_anneal_steps) + val_dataset = ffn.ImageDataset.load(args.data_path, "val", args.num_samples, include_alpha, - True, opacity_model, - args.batch_size, args.color_space, - anneal_start=args.anneal_start, - num_anneal_steps=args.num_anneal_steps) - val_dataset = ffn.RayDataset.load(args.data_path, "val", - args.num_samples, include_alpha, - False, opacity_model, - args.batch_size, args.color_space) + False, opacity_model, + args.batch_size, args.color_space) if train_dataset is None: return 1 + visualizers = [] if args.make_video: - cameras = ffn.orbit(np.array([0, 1, 0]), np.array([0, 0, -1]), - args.num_frames, 40, - train_dataset.resolution.square(), 4) - bounds = np.eye(4, dtype=np.float32) * 2 - video_sampler = ffn.RaySampler(bounds, cameras, args.num_samples) - image_interval = args.num_steps // args.num_frames + resolution = train_dataset.cameras[0].resolution + visualizers.append(ffn.OrbitVideoVisualizer( + args.results_dir, + args.num_steps, + resolution, + args.num_frames, + args.num_samples, + args.color_space + )) else: - video_sampler = None - image_interval = args.image_interval + visualizers.append(ffn.EvaluationVisualizer( + args.results_dir, + train_dataset, + args.image_interval + )) + visualizers.append(ffn.EvaluationVisualizer( + args.results_dir, + val_dataset, + args.image_interval + )) if args.mode == "dilate": train_dataset.mode = ffn.RayDataset.Mode.Dilate - raycaster = ffn.Raycaster(model) - raycaster.to(args.device) + raycaster = ffn.Raycaster(model.to(args.device)) - log = raycaster.fit(train_dataset, val_dataset, args.results_dir, + log = raycaster.fit(train_dataset, val_dataset, args.batch_size, args.learning_rate, - args.num_steps, image_interval, - args.crop_steps, args.report_interval, + args.num_steps, args.crop_steps, args.report_interval, args.decay_rate, args.decay_steps, - args.weight_decay, video_sampler) + args.weight_decay, visualizers) model.save(os.path.join(args.results_dir, "nerf.pt")) @@ -138,8 +147,10 @@ def _main(): file.write("\n\n") file.write("\t".join(["step", "timestamp", "psnr_train", "psnr_val"])) file.write("\t") - for line in log: - file.write("\t".join([str(val) for val in line]) + "\n") + for entry in log: + file.write("\t".join([str(val) for val in [ + entry.step, entry.timestamp, entry.train_psnr, entry.val_psnr + ]]) + "\n") sp_path = os.path.join(args.results_dir, "nerf.html") raycaster.to_scenepic(val_dataset).save_as_html(sp_path) diff --git a/train_tiny_nerf.py b/train_tiny_nerf.py index b1241b6..fd049f9 100644 --- a/train_tiny_nerf.py +++ b/train_tiny_nerf.py @@ -5,7 +5,6 @@ import os import fourier_feature_nets as ffn -import numpy as np import torch @@ -60,6 +59,12 @@ def _parse_args(): help="Number of frames in the training video orbit.") parser.add_argument("--device", 
default="cuda", help="Pytorch compute device") + parser.add_argument("--anneal-start", type=float, default=0.2, + help="Starting value for the sample space annealing.") + parser.add_argument("--num-anneal-steps", type=int, default=2000, + help=("Steps over which to anneal sampling to the full" + "range of volume intersection.")) + return parser.parse_args() @@ -92,53 +97,64 @@ def _main(): opacity_model = None include_alpha = args.mode == "rgba" - train_dataset = ffn.RayDataset.load(args.data_path, "train", + train_dataset = ffn.ImageDataset.load(args.data_path, "train", + args.num_samples, include_alpha, + True, opacity_model, + args.batch_size, args.color_space, + anneal_start=args.anneal_start, + num_anneal_steps=args.num_anneal_steps) + val_dataset = ffn.ImageDataset.load(args.data_path, "val", args.num_samples, include_alpha, - True, opacity_model, + False, opacity_model, args.batch_size, args.color_space) - val_dataset = ffn.RayDataset.load(args.data_path, "val", - args.num_samples, include_alpha, - False, opacity_model, - args.batch_size, args.color_space) if train_dataset is None: return 1 + visualizers = [] if args.make_video: - cameras = ffn.orbit(np.array([0, 1, 0]), np.array([0, 0, -1]), - args.num_frames, 40, - ffn.Resolution(512, 512), 4) - bounds = np.eye(4, dtype=np.float32) * 2 - video_sampler = ffn.RaySampler(bounds, cameras, args.num_samples) + resolution = train_dataset.cameras[0].resolution + visualizers.append(ffn.OrbitVideoVisualizer( + args.results_dir, + args.num_steps, + resolution, + args.num_frames, + args.num_samples, + args.color_space + )) else: - video_sampler = None - image_interval = args.image_interval + visualizers.append(ffn.EvaluationVisualizer( + args.results_dir, + train_dataset, + args.image_interval + )) + visualizers.append(ffn.EvaluationVisualizer( + args.results_dir, + val_dataset, + args.image_interval + )) if args.make_activations: - cameras = ffn.orbit(np.array([0, 1, 0]), np.array([0, 0, -1]), - args.num_frames, 40, ffn.Resolution(64, 64), 4) - bounds = np.eye(4, dtype=np.float32) * 2 - act_sampler = ffn.RaySampler(bounds, cameras, args.num_samples) - else: - act_sampler = None - - if args.make_video or args.make_activations: - image_interval = args.num_steps // args.num_frames - else: - image_interval = args.image_interval + resolution = train_dataset.cameras[0].resolution + visualizers.append(ffn.ActivationVisualizer( + args.results_dir, + args.num_steps, + resolution, + args.num_frames, + args.num_samples, + args.color_space + )) if args.mode == "dilate": train_dataset.mode = ffn.RayDataset.Mode.Dilate - raycaster = ffn.Raycaster(model) - raycaster.to(args.device) + raycaster = ffn.Raycaster(model.to(args.device)) - log = raycaster.fit(train_dataset, val_dataset, args.results_dir, + log = raycaster.fit(train_dataset, val_dataset, args.batch_size, args.learning_rate, - args.num_steps, image_interval, - args.crop_steps, args.report_interval, + args.num_steps, args.crop_steps, args.report_interval, args.decay_rate, args.decay_steps, - args.weight_decay, video_sampler, act_sampler) + args.weight_decay, visualizers) model.save(os.path.join(args.results_dir, "tiny_nerf.pt")) @@ -147,8 +163,10 @@ def _main(): file.write("\n\n") file.write("\t".join(["step", "timestamp", "psnr_train", "psnr_val"])) file.write("\t") - for line in log: - file.write("\t".join([str(val) for val in line]) + "\n") + for entry in log: + file.write("\t".join([str(val) for val in [ + entry.step, entry.timestamp, entry.train_psnr, entry.val_psnr + ]]) + "\n") sp_path = 
os.path.join(args.results_dir, "tiny_nerf.html") raycaster.to_scenepic(val_dataset).save_as_html(sp_path) diff --git a/train_voxels.py b/train_voxels.py index b4609aa..d76ea5a 100644 --- a/train_voxels.py +++ b/train_voxels.py @@ -5,7 +5,6 @@ import os import fourier_feature_nets as ffn -import numpy as np import torch @@ -47,7 +46,7 @@ def _parse_args(): help="Pytorch compute device") parser.add_argument("--anneal-start", type=float, default=0.2, help="Starting value for the sample space annealing.") - parser.add_argument("--num-anneal-steps", type=int, default=0, + parser.add_argument("--num-anneal-steps", type=int, default=2000, help=("Steps over which to anneal sampling to the full" "range of volume intersection.")) @@ -60,52 +59,64 @@ def _main(): torch.manual_seed(args.seed) include_alpha = args.mode == "rgba" - train_dataset = ffn.RayDataset.load(args.data_path, "train", + train_dataset = ffn.ImageDataset.load(args.data_path, "train", + args.num_samples, include_alpha, + True, color_space=args.color_space, + anneal_start=args.anneal_start, + num_anneal_steps=args.num_anneal_steps) + val_dataset = ffn.ImageDataset.load(args.data_path, "val", args.num_samples, include_alpha, - True, color_space=args.color_space, - anneal_start=args.anneal_start, - num_anneal_steps=args.num_anneal_steps) - val_dataset = ffn.RayDataset.load(args.data_path, "val", - args.num_samples, include_alpha, - False, color_space=args.color_space) - - if args.make_video: - cameras = ffn.orbit(np.array([0, 1, 0]), np.array([0, 0, -1]), - args.num_frames, 40, - train_dataset.resolution.square(), 4) - bounds = np.eye(4, dtype=np.float32) * 2 - video_sampler = ffn.RaySampler(bounds, cameras, args.num_samples) - image_interval = args.num_steps // args.num_frames - else: - video_sampler = None - image_interval = args.image_interval + False, color_space=args.color_space) if train_dataset is None: return 1 + visualizers = [] + if args.make_video: + resolution = train_dataset.cameras[0].resolution + visualizers.append(ffn.OrbitVideoVisualizer( + args.results_dir, + args.num_steps, + resolution, + args.num_frames, + args.num_samples, + args.color_space + )) + else: + visualizers.append(ffn.EvaluationVisualizer( + args.results_dir, + train_dataset, + args.image_interval + )) + visualizers.append(ffn.EvaluationVisualizer( + args.results_dir, + val_dataset, + args.image_interval + )) + if args.mode == "dilate": train_dataset.mode = ffn.RayDataset.Mode.Dilate scale = 2 / train_dataset.sampler.bounds[0, 0] model = ffn.Voxels(args.side, scale) - raycaster = ffn.Raycaster(model) - raycaster.to(args.device) + raycaster = ffn.Raycaster(model.to(args.device)) - log = raycaster.fit(train_dataset, val_dataset, args.results_dir, - args.batch_size, args.learning_rate, - args.num_steps, image_interval, 0, + log = raycaster.fit(train_dataset, val_dataset, args.batch_size, + args.learning_rate, args.num_steps, 0, args.report_interval, args.decay_rate, args.decay_steps, - 0.0, video_sampler) + 0.0, visualizers) model.save(os.path.join(args.results_dir, "voxels.pt")) with open(os.path.join(args.results_dir, "log.txt"), "w") as file: json.dump(vars(args), file) file.write("\n\n") file.write("\t".join(["step", "timestamp", "psnr_train", "psnr_val"])) - file.write("\t") - for line in log: - file.write("\t".join([str(val) for val in line]) + "\n") + file.write("\n") + for entry in log: + file.write("\t".join([str(val) for val in [ + entry.step, entry.timestamp, entry.train_psnr, entry.val_psnr + ]]) + "\n") sp_path = os.path.join(args.results_dir, 
"voxels.html") raycaster.to_scenepic(val_dataset).save_as_html(sp_path) diff --git a/visualizations/camera_to_world.py b/visualizations/camera_to_world.py index 84cec18..9760b1e 100644 --- a/visualizations/camera_to_world.py +++ b/visualizations/camera_to_world.py @@ -187,7 +187,7 @@ def _add_meshes(frame: sp.Frame3D, camera_transform: np.ndarray, if __name__ == "__main__": - dataset = ffn.RayDataset.load("antinous_400.npz", "train", 64, True, False) + dataset = ffn.ImageDataset.load("antinous_400.npz", "train", 64, True, False) voxels = ffn.OcTree.load("antinous_octree_10.npz") scene = sp.Scene() camera_to_world(scene, voxels, dataset.cameras[6], dataset.images[6], 800) diff --git a/visualizations/ray_cube_intersection.py b/visualizations/ray_cube_intersection.py new file mode 100644 index 0000000..e33138c --- /dev/null +++ b/visualizations/ray_cube_intersection.py @@ -0,0 +1,165 @@ +from typing import List, NamedTuple, Tuple + +import fourier_feature_nets as ffn +import numpy as np +import scenepic as sp + + +class Ray(NamedTuple("Ray", [("x", float), ("y", float), ("z", float), + ("dx", float), ("dy", float), ("dz", float)])): + def cast(self, t: float) -> np.ndarray: + x = self.x + t * self.dx + y = self.y + t * self.dy + z = self.z + t * self.dz + return np.array([x, y, z], np.float32) + + +Intersection = NamedTuple("Intersection", [("enter", float), ("exit", float)]) + + +def _in_order(a: float, b: float) -> Tuple[float, float]: + if b < a: + return b, a + + return a, b + + +def _near_far(coord_diff: float, ray_dir: float): + near = (coord_diff - 1) / ray_dir + far = (coord_diff + 1) / ray_dir + return _in_order(near, far) + + +def _min(x: float, y: float, z: float) -> Tuple[float, int]: + if x < y: + if x < z: + return x, 0 + else: + if y < z: + return y, 1 + + return z, 2 + + +def _max(x: float, y: float, z: float) -> Tuple[float, int]: + if x > y: + if x > z: + return x, 0 + else: + if y > z: + return y, 1 + + return z, 2 + + +def _intersect_cube_with_ray(ray: Ray) -> List[Intersection]: + x0, x1 = _near_far(-ray.x, ray.dx) + y0, y1 = _near_far(-ray.y, ray.dy) + z0, z1 = _near_far(-ray.z, ray.dz) + + return [Intersection(x0, x1), + Intersection(y0, y1), + Intersection(z0, z1)] + + +def _random_point() -> np.ndarray: + point = np.random.random_sample(size=3) + 1 + sign = np.sign(np.random.random_sample(size=3) - 0.5) + return (point * sign).astype(np.float32) + + +def _on_edge(x: float) -> bool: + if x > 0: + return abs(x - 1) < 1e-2 + + return abs(x + 1) < 1e-2 + + +def _build_animation(num_rays: int, num_samples, num_pause) -> sp.Scene: + scene = sp.Scene() + main = scene.create_canvas_3d("main", width=600, height=600) + main.shading = sp.Shading(bg_color=sp.Colors.White) + x_proj = scene.create_canvas_2d("x_proj", width=200, height=200, + background_color=sp.Colors.White) + y_proj = scene.create_canvas_2d("y_proj", width=200, height=200, + background_color=sp.Colors.White) + z_proj = scene.create_canvas_2d("z_proj", width=200, height=200, + background_color=sp.Colors.White) + + cube_mesh = scene.create_mesh("cube") + cube_mesh.add_cube(sp.Colors.Black, transform=sp.Transforms.scale(2), + add_wireframe=True, fill_triangles=False) + cube_mesh.add_coordinate_axes(transform=sp.Transforms.scale(0.5)) + + up_dir = np.array([0, 1, 0], np.float32) + forward_dir = np.array([0, 0, 1], np.float32) + orbit = ffn.orbit(up_dir, forward_dir, num_rays * (num_samples + 2*num_pause), + 65, ffn.Resolution(600, 600), 5) + orbit = iter(orbit) + + for _ in range(num_rays): + ray_start = 
_random_point() + ray_end = _random_point() + check = ray_start * ray_end + if (check > 0).any(): + index = np.nonzero(check > 0) + ray_end[index] *= -1 + + direction = (ray_end - ray_start) + length = np.linalg.norm(direction) + direction /= length + ray = Ray(*ray_start, *direction) + x_int, y_int, z_int = _intersect_cube_with_ray(ray) + samples = np.linspace(0, length, num_samples) + t_min, a_min = _max(x_int.enter, y_int.enter, z_int.enter) + t_max, a_max = _min(x_int.exit, y_int.exit, z_int.exit) + colors = [sp.Colors.Red, sp.Colors.Green, sp.Colors.Blue] + samples = np.sort(np.concatenate([samples, np.array([t_min, t_max])])) + + for sample in samples: + ray_mesh = scene.create_mesh() + point = ray.cast(sample) + ray_mesh.add_thickline(sp.Colors.Black, ray_start, point, 0.01, 0.01) + + transform = sp.Transforms.scale(0.15) + transform = sp.Transforms.translate(point) @ transform + num_frames = 1 + if sample == t_min: + num_frames = num_pause + ray_mesh.add_sphere(colors[a_min], transform=transform) + elif sample == t_max: + num_frames = num_pause + ray_mesh.add_sphere(colors[a_max], transform=transform) + + coords = np.stack([ray_start, point]) + + for _ in range(num_frames): + camera = next(orbit).to_scenepic() + main.create_frame(meshes=[cube_mesh, ray_mesh], camera=camera) + + for axis, proj in enumerate([x_proj, y_proj, z_proj]): + frame = proj.create_frame() + frame.add_rectangle(400/6, 400/6, 400/6, 400/6, colors[axis], 2) + coords2d = np.roll(coords, axis, axis=1)[:, 1:] + coords2d[:, 1] *= -1 + x, y = coords2d[-1] + coords2d = (coords2d + 3) * 200 / 6 + if sample == t_min and (_on_edge(x) or _on_edge(y)): + frame.add_circle(*coords2d[-1], 4, fill_color=colors[a_min]) + elif sample == t_max and (_on_edge(x) or _on_edge(y)): + frame.add_circle(*coords2d[-1], 4, fill_color=colors[a_max]) + + frame.add_line(coords2d, line_width=2) + + scene.grid("800px", "200px 200px 200px", "600px 200px") + scene.place(main.canvas_id, "1 / span 3", "1") + scene.place(x_proj.canvas_id, "1", "2") + scene.place(y_proj.canvas_id, "2", "2") + scene.place(z_proj.canvas_id, "3", "2") + scene.link_canvas_events(main, x_proj, y_proj, z_proj) + return scene + + +if __name__ == "__main__": + scene = _build_animation(5, 100, 20) + scene.save_as_html("ray_cube_int.html", title="Ray/Cube Intersection") diff --git a/visualizations/rendering_equation.py b/visualizations/rendering_equation.py new file mode 100644 index 0000000..5771d55 --- /dev/null +++ b/visualizations/rendering_equation.py @@ -0,0 +1,125 @@ +"""Animation of the rendering equation.""" + +import fourier_feature_nets as ffn +import numpy as np +import scenepic as sp +import torch +import torch.nn.functional as F + + +def rendering_equation(voxels: ffn.OcTree, ray_samples: ffn.RaySamples, + camera: ffn.CameraInfo, image: np.ndarray, + model: ffn.NeRF) -> sp.Scene: + scene = sp.Scene() + resolution = 600 + main = scene.create_canvas_3d(width=resolution, height=3 * resolution / 4) + main.shading = sp.Shading(bg_color=sp.Colors.White) + + graph = scene.create_graph(width=resolution, height=resolution / 4, + text_size=32) + + num_samples = len(ray_samples.positions[0]) + + leaf_centers = voxels.leaf_centers() + leaf_depths = voxels.leaf_depths() + leaf_colors = voxels.leaf_data() + depths = np.unique(leaf_depths) + cubes = [] + for depth in depths[-1:]: + mesh = scene.create_mesh(layer_id="model") + transform = sp.Transforms.scale(pow(2., 1-depth) * voxels.scale) + mesh.add_cube(sp.Colors.White, transform=transform, add_wireframe=True, 
fill_triangles=False) + depth_centers = leaf_centers[leaf_depths == depth] + depth_colors = leaf_colors[leaf_depths == depth] + mesh.enable_instancing(depth_centers, colors=depth_colors) + cubes.append(mesh) + + sp_image = scene.create_image() + image = image[..., :3] + sp_image.from_numpy(image) + camera_image = scene.create_mesh(texture_id=sp_image.image_id, + double_sided=True) + camera_image.add_camera_image(camera.to_scenepic()) + + frustum = scene.create_mesh() + frustum.add_camera_frustum(camera.to_scenepic(), sp.Colors.White) + + positions = ray_samples.positions.reshape(-1, 3) + views = ray_samples.view_directions.reshape(-1, 3) + model.eval() + with torch.no_grad(): + color_o = model(positions, views) + + color_o = color_o.reshape(1, num_samples, 4) + color, opacity = torch.split(color_o, [3, 1], -1) + color = torch.sigmoid(color) + opacity = F.softplus(opacity) + + assert not color.isnan().any() + assert not opacity.isnan().any() + + opacity = opacity.squeeze(-1) + deltas = ray_samples.t_values[:, 1:] - ray_samples.t_values[:, :-1] + max_dist = torch.full_like(deltas[:, :1], 1e10) + deltas = torch.cat([deltas, max_dist], dim=-1) + trans = torch.exp(-(opacity * deltas).cumsum(-1)) + + graph.add_sparkline("σ", opacity[0].numpy(), sp.Colors.Red, 3) + graph.add_sparkline("T", trans[0].numpy(), sp.Colors.Blue, 3) + + camera_start = [-1.5, 0, 0] + lookat = [0, 0, 0] + fov = 70 + + def _add_meshes(frame: sp.Frame3D): + for mesh in cubes: + frame.add_mesh(mesh) + + frame.add_mesh(frustum) + frame.add_mesh(camera_image) + + for i in range(num_samples): + start = camera.position + end = ray_samples.positions[0, i] + ray_mesh = scene.create_mesh() + ray_mesh.add_thickline(sp.Colors.Black, start, end, 0.005, 0.005) + + angle = (i * np.pi) / num_samples + + view_rot = sp.Transforms.rotation_about_y(angle) + view_pos = view_rot[:3, :3] @ np.array(camera_start) + view_cam = sp.Camera(view_pos, lookat, fov_y_degrees=fov, aspect_ratio=4/3) + + sample_mesh = scene.create_mesh() + sample_mesh.add_sphere(sp.Colors.White, transform=sp.Transforms.scale(0.03)) + positions = ray_samples.positions[0, :i+1] + colors = color[0, :i+1] + index = opacity[0, :i+1] > 0.1 + positions = positions[index] + colors = colors[index] + sample_mesh.enable_instancing(positions.numpy(), colors=colors.numpy()) + + frame = main.create_frame(camera=view_cam) + _add_meshes(frame) + frame.add_mesh(ray_mesh) + frame.add_mesh(sample_mesh) + + scene.grid("600px", "600px 200px", "600px") + scene.place(main.canvas_id, "1", "1") + scene.place(graph.canvas_id, "2", "1") + scene.link_canvas_events(main, graph) + return scene + + +if __name__ == "__main__": + voxels = ffn.OcTree.load("antinous_octree_7.npz") + dataset = ffn.ImageDataset.load("antinous_400.npz", "train", 256, True, True) + model = ffn.load_model("antinous_800_nerf.pt") + image = dataset.images[17] + row = 190 + col = 240 + camera = dataset.cameras[17] + index = 17 * dataset.sampler.rays_per_camera + row * 400 + col + rays = dataset.sampler.sample([index], None) + scene = rendering_equation(voxels, rays, camera, image, model) + scene.save_as_html("rendering_eq.html", title="Rendering Equation") diff --git a/visualizations/view_angle.py b/visualizations/view_angle.py new file mode 100644 index 0000000..1a709b5 --- /dev/null +++ b/visualizations/view_angle.py @@ -0,0 +1,71 @@ +import os +import cv2 +import numpy as np +import fourier_feature_nets as ffn + + +def _camera_angle(source: ffn.CameraInfo, dest: ffn.CameraInfo) -> float: + angle0 = source.position / 
np.linalg.norm(source.position, keepdims=True) + angle1 = dest.position / np.linalg.norm(dest.position, keepdims=True) + return (angle0 * angle1).sum() + + +def _main(): + output_dir = os.path.join("results","view_angle") + os.makedirs(output_dir, exist_ok=True) + dataset = ffn.ImageDataset.load("trex_400.npz", "train", 256, True, True) + model = ffn.load_model("trex_800_nerf.pt") + image = dataset.images[1].astype(np.float32) / 255 + image = image[..., :3] * image[..., 3:] + image = (image * 255).astype(np.uint8) + row = 310 + col = 137 + camera = dataset.cameras[1] + index = dataset.sampler.rays_per_camera + row * 400 + col + rays = dataset.sampler.sample([index], None) + raycaster = ffn.Raycaster(model) + render = raycaster.render(rays, True) + start = dataset.sampler.starts[index].numpy() + direction = dataset.sampler.directions[index].numpy() + depth = render.depth[0].item() + + position = start + direction * depth + source_cam = dataset.cameras[1] + index = 0 + for camera, image in zip(dataset.cameras, dataset.images): + angle = _camera_angle(source_cam, camera) + if angle < 0.5: + continue + + print(index, angle) + image = image.astype(np.float32) / 255 + image = image[..., :3] * image[..., 3:] + image = (image * 255).astype(np.uint8) + + col, row = camera.project(position[np.newaxis])[0] + col = int(col - 16) + row = int(row - 16) + patch = image[row:row+32, col:col+32] + patch = cv2.resize(patch, (128, 128), cv2.INTER_NEAREST) + + frame = np.zeros((400, 800, 3), np.uint8) + frame[:, :400] = image + frame[136:264, 536:664] = patch + + frame = cv2.rectangle(frame, (col, row), (col+32, row+32), + (255, 255, 255), 2) + frame = cv2.rectangle(frame, (536, 136), (664, 264), + (255, 255, 255), 2) + frame = cv2.line(frame, (col+32, row), (536, 136), + (255, 255, 255), 2) + frame = cv2.line(frame, (col+32, row+32), (536, 264), + (255, 255, 255), 2) + + path = "frame_{:04d}.png".format(index) + path = os.path.join(output_dir, path) + cv2.imwrite(path, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + index += 1 + + +if __name__ == "__main__": + _main() diff --git a/visualizations/volume_raycasting.py b/visualizations/volume_raycasting.py index d51e54a..fb5d532 100644 --- a/visualizations/volume_raycasting.py +++ b/visualizations/volume_raycasting.py @@ -504,7 +504,7 @@ def save_as_html(self, path): if __name__ == "__main__": - dataset = ffn.RayDataset.load("antinous_800.npz", "train", 64, True, False) + dataset = ffn.ImageDataset.load("antinous_800.npz", "train", 64, True, False) voxels = ffn.OcTree.load("antinous_octree_8.npz") anim = VolumeRaycastingAnimation(dataset, voxels, width=1280, height=720) print("Writing scenepic to file...") diff --git a/visualizations/world_to_camera.py b/visualizations/world_to_camera.py index 44aa3fb..04da3fa 100644 --- a/visualizations/world_to_camera.py +++ b/visualizations/world_to_camera.py @@ -170,7 +170,7 @@ def _add_meshes(frame: sp.Frame3D, model_transform: np.ndarray, if __name__ == "__main__": - dataset = ffn.RayDataset.load("antinous_400.npz", "train", 64, True, False) + dataset = ffn.ImageDataset.load("antinous_400.npz", "train", 64, True, False) voxels = ffn.OcTree.load("antinous_octree_8.npz") scene = sp.Scene() world_to_camera(scene, voxels, dataset.cameras[6], dataset.images[6], 800) diff --git a/voxelize_model.py b/voxelize_model.py index 61a01d1..45e433b 100644 --- a/voxelize_model.py +++ b/voxelize_model.py @@ -45,8 +45,8 @@ def _main(): else: opacity_model = None - dataset = ffn.RayDataset.load(args.data_path, "train", 400, 128, False, - 
opacity_model) + dataset = ffn.ImageDataset.load(args.data_path, "train", 400, 128, False, + opacity_model) if dataset is None: return 1 @@ -58,7 +58,6 @@ def _main(): sampler = dataset.sampler model = model.to(args.device) raycaster = ffn.Raycaster(model) - raycaster.to(args.device) num_rays = len(sampler) colors = [] positions = [] @@ -67,14 +66,14 @@ def _main(): for start in range(0, num_rays, args.batch_size): end = min(start + args.batch_size, num_rays) index = list(range(start, end)) - rays = sampler[list(range(start, end))] - color, alpha, depth = raycaster.render(rays.to(args.device), True) - valid = (alpha > args.alpha_threshold).cpu() - colors.append(color[valid].cpu().numpy()) - starts = sampler.starts[index] - dirs = sampler.directions[index] - position = starts + dirs * depth.cpu().unsqueeze(-1) - positions.append(position[valid].cpu().numpy()) + rays = sampler.sample(list(range(start, end)), None) + color, alpha, depth = raycaster.render(rays.to(args.device), True).numpy() + valid = (alpha > args.alpha_threshold) + colors.append(color[valid]) + starts = sampler.starts[index].cpu().numpy() + dirs = sampler.directions[index].cpu().numpy() + position = starts + dirs * depth[..., np.newaxis] + positions.append(position[valid]) bar.next(end - start) bar.finish()
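The updated loop in `voxelize_model.py` turns each rendered ray into a 3D point by stepping the ray origin along its direction by the rendered depth, and keeps only the rays whose alpha clears the threshold. A small self-contained sketch of that back-projection step, under the assumption that inputs are NumPy arrays as in the loop above (the function and variable names here are illustrative, not the script's own):

```python
import numpy as np

def backproject(starts: np.ndarray, dirs: np.ndarray, depth: np.ndarray,
                alpha: np.ndarray, alpha_threshold: float = 0.1) -> np.ndarray:
    """Convert per-ray depths into 3D points, keeping only sufficiently opaque rays.

    starts, dirs: (N, 3) ray origins and unit directions
    depth, alpha: (N,) rendered depth and alpha per ray
    """
    valid = alpha > alpha_threshold
    positions = starts + dirs * depth[..., np.newaxis]
    return positions[valid]

# Toy usage: only the first ray is opaque enough to keep.
starts = np.zeros((2, 3), np.float32)
dirs = np.array([[0, 0, 1], [0, 1, 0]], np.float32)
depth = np.array([2.0, 3.0], np.float32)
alpha = np.array([0.9, 0.05], np.float32)
print(backproject(starts, dirs, depth, alpha))  # [[0. 0. 2.]]
```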