
Computational Displays

Is there a good resource for classifying existing Augmented Reality glasses?

Using your favorite search engine, investigate whether there is a reliable, up-to-date table that helps compare existing Augmented Reality glasses in terms of functionality and technical capabilities (e.g., field of view, resolution, focus cues).

Complex-valued Gaussian splatting for holography

Informative · Practical

Traditional 3D Gaussian Splatting represents a scene as a collection of 3D Gaussian primitives, each with a mean position, covariance (orientation and scale), colour, and opacity. These Gaussians are "splatted" onto the image plane by projecting their 3D covariance into 2D and alpha-compositing them in depth order to produce a rendered image. Complex-valued Gaussian splatting extends this idea to holographic rendering. Instead of compositing real-valued colours, each Gaussian carries a complex amplitude and a phase. When splatted, the contributions are summed as complex fields, and the resulting field is propagated to the hologram plane using the band-limited angular spectrum method. This produces a complex hologram that encodes both amplitude and phase information suitable for driving a holographic display.
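
For intuition, the per-pixel weight of a single projected Gaussian in the splatting step can be sketched as follows. This is a toy illustration with made-up numbers, not odak's API: the weight at each pixel is \(\exp(-\tfrac{1}{2} d^\top \Sigma^{-1} d)\), where \(d\) is the offset from the projected mean and \(\Sigma\) the projected 2D covariance.

```python
import torch

# Toy example: evaluate one projected 2D Gaussian on a 5x5 pixel grid.
mean_2d = torch.tensor([2.0, 2.0])                 # projected centre (pixels)
cov_2d = torch.tensor([[1.5, 0.3], [0.3, 0.8]])    # projected 2x2 covariance
cov_inv = torch.linalg.inv(cov_2d)

ys, xs = torch.meshgrid(torch.arange(5.0), torch.arange(5.0), indexing="ij")
d = torch.stack([xs - mean_2d[0], ys - mean_2d[1]], dim=-1)       # (5, 5, 2) offsets
weight = torch.exp(-0.5 * torch.einsum("hwi,ij,hwj->hw", d, cov_inv, d))
```

The weight peaks at the projected mean and decays with Mahalanobis distance; standard splatting alpha-blends colours with these weights, while the complex-valued variant uses them to scale complex field contributions.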

odak provides a pure PyTorch implementation of this pipeline in odak.learn.wave.complex_gaussians, free of any external dependencies beyond PyTorch itself. The key classes are:

  • Gaussians: Stores and manages the 3D Gaussian primitives (means, quaternion rotations, scales, colours, phases, opacities, and plane assignments).
  • Scene: Combines a set of Gaussians with a camera to render complex holograms via tile-based splatting and angular spectrum propagation.
  • PerspectiveCamera: A lightweight camera model available from odak.learn.tools that stores rotation, translation, focal length, and principal point.
How does complex Gaussian splatting differ from standard Gaussian splatting?

In standard 3D Gaussian splatting, Gaussians are alpha-composited to produce real-valued pixel colours. In the complex-valued variant, each Gaussian contributes a complex field \(A \cdot e^{i\phi}\), where \(A\) is the amplitude (derived from colour and opacity) and \(\phi\) is a learned phase. The key equations remain similar for projection (3D covariance to 2D covariance), but the rendering step sums complex contributions rather than blending colours. The summed complex field is then propagated to the hologram plane using band-limited angular spectrum propagation to generate a hologram.
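
The difference can be made concrete with a toy example (illustrative values only, not odak's API): three Gaussians contribute to one pixel, and their fields are summed coherently as complex numbers rather than alpha-blended.

```python
import torch

# Three Gaussians contributing to a single pixel.
amplitude = torch.tensor([0.8, 0.5, 0.3])             # A, e.g. colour x opacity x weight
phase = torch.tensor([0.0, torch.pi / 2, torch.pi])   # learned phase per Gaussian

field = (amplitude * torch.exp(1j * phase)).sum()     # coherent complex sum
intensity = field.abs() ** 2                          # what a detector would observe
```

Because the phases interfere, the resulting intensity differs from simply summing the amplitudes, which is exactly what makes the representation suitable for holography.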

What is band-limited angular spectrum propagation?

The angular spectrum method propagates a 2D complex field by a distance \(d\) through free space. In the frequency domain, propagation amounts to multiplying the field's Fourier transform by a transfer function:

\[ H(f_x, f_y) = \exp\left(i \cdot d \cdot \sqrt{k^2 - (2\pi f_x)^2 - (2\pi f_y)^2}\right), \]

where \(k = 2\pi / \lambda\) is the wavenumber. The "band-limited" variant applies a frequency mask to the transfer function, suppressing frequencies whose phase would alias at the given sampling rate and propagation distance (as well as evanescent components), which keeps the results physically accurate even over longer propagation distances.
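
A minimal PyTorch sketch of this propagator follows. It is illustrative rather than odak's internal implementation (odak provides this in `odak.learn.wave`); the band limit follows the sampling criterion of Matsushima and Shimobaba (2009), zeroing frequencies whose transfer-function phase would alias.

```python
import torch

# Band-limited angular spectrum propagation for a square complex field.
def band_limited_asm(field, dx, wavelength, distance):
    n = field.shape[-1]                                # assume an n x n field
    f = torch.fft.fftfreq(n, d=dx)                     # spatial frequencies (1/m)
    fx, fy = torch.meshgrid(f, f, indexing="ij")
    k = 2 * torch.pi / wavelength                      # wavenumber
    arg = (k ** 2 - (2 * torch.pi * fx) ** 2 - (2 * torch.pi * fy) ** 2).clamp(min=0.0)
    H = torch.exp(1j * distance * torch.sqrt(arg))     # angular spectrum transfer function
    # Band limit: keep only frequencies satisfying the sampling criterion.
    df = 1.0 / (n * dx)
    f_limit = 1.0 / (wavelength * (1.0 + (2.0 * df * distance) ** 2) ** 0.5)
    H = H * (fx.abs() < f_limit) * (fy.abs() < f_limit)
    return torch.fft.ifft2(torch.fft.fft2(field) * H)

field = torch.zeros(64, 64, dtype=torch.complex64)
field[28:36, 28:36] = 1.0                              # small square aperture
out = band_limited_asm(field, 8e-6, 633e-9, 0.02)      # 8 um pitch, 633 nm, 2 cm
```

Since \(|H| \le 1\) everywhere and the mask only removes energy, the propagated field never gains energy, which is a quick sanity check on the implementation.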

Usage example

The script below demonstrates how to initialise a set of random complex Gaussians, set up a camera, render a hologram, and visualise the result. We keep the example brief so that first-time readers can follow each step.

import sys
import torch
from argparse import Namespace
import odak


def main():
    # 1. Define scene parameters.
    num_points = 128       # number of Gaussian primitives
    num_planes = 1         # number of hologram planes
    img_size = (64, 64)    # rendered image resolution (W, H)
    wavelengths = [633e-9] # red laser wavelength in metres

    args = Namespace(
        num_planes=num_planes,
        wavelengths=wavelengths,
        pixel_pitch=8e-6,           # 8 micron pixel pitch
        distances=[0.02],           # propagation distance to hologram plane
        pad_size=list(img_size),
        aperture_size=-1,           # no aperture
    )

    # 2. Create randomly initialised Gaussians.
    #    We place them in a small box in front of the camera
    #    so they project within the image.
    from odak.learn.wave.complex_gaussians import Gaussians, Scene
    gaussians = Gaussians(
        init_type="random",
        device="cpu",
        num_points=num_points,
        args_prop=args,
    )

    # Override positions to a visible volume in front of the camera.
    with torch.no_grad():
        gaussians.means.data = torch.rand(num_points, 3) * 0.2 - 0.1 # (1)
        gaussians.means.data[:, 2] = gaussians.means.data[:, 2].abs() + 2.0
    print(f"Initialised {len(gaussians)} Gaussians")

    # 3. Set up a perspective camera looking at the origin.
    from odak.learn.tools import PerspectiveCamera
    camera = PerspectiveCamera(
        R=torch.eye(3).unsqueeze(0),
        T=torch.tensor([[0.0, 0.0, 0.0]]),
        focal_length=torch.tensor([500.0, 500.0]),
        principal_point=torch.tensor([32.0, 32.0]),
    )

    # 4. Create a Scene and render the hologram.
    scene = Scene(gaussians, args)
    hologram, plane_field = scene.render(
        camera=camera,
        img_size=img_size,
        tile_size=(32, 32),
    )
    print(f"Hologram shape: {hologram.shape}, dtype: {hologram.dtype}")

    # 5. Extract amplitude and phase from the complex hologram.
    amplitude = odak.learn.wave.calculate_amplitude(hologram[0])
    phase = odak.learn.wave.calculate_phase(hologram[0])

    # 6. Visualise the results.
    positions = gaussians.means.detach().cpu().numpy()
    colors = gaussians.colours.detach().cpu().numpy()

    visualize = True
    if visualize:
        # 3D point cloud of Gaussian positions.
        diagram = odak.visualize.plotly.rayshow(
            columns=1,
            marker_size=5.0,
            subplot_titles=["<b>Gaussian positions</b>"],
        )
        diagram.add_point(positions, color=colors, column=1)
        diagram.show()

        # Hologram amplitude and phase as 2D images.
        amplitude_image = amplitude.detach().unsqueeze(0).unsqueeze(0)
        phase_image = phase.detach().unsqueeze(0).unsqueeze(0)

        detector_amp = odak.visualize.plotly.detectorshow()
        detector_amp.add_field(amplitude_image)
        detector_amp.show()

        detector_phase = odak.visualize.plotly.detectorshow()
        detector_phase.add_field(phase_image)
        detector_phase.show()

    assert hologram.shape[0] == len(wavelengths)
    print("Done.")


if __name__ == "__main__":
    sys.exit(main())
  1. Positions are overridden to a small box (x, y in [-0.1, 0.1], z in [2.0, 2.2]) so that they fall within the camera frustum and produce a visible hologram. With focal_length=500 and principal_point=(32, 32), these Gaussians project near the center of the 64×64 image.

The code above follows a simple pipeline:

  1. Define parameters – number of Gaussians, image size, wavelength, pixel pitch, and propagation distance.
  2. Initialise Gaussians – Gaussians(init_type="random", ...) creates randomly placed primitives with random colours, phases, and opacities.
  3. Set up the camera – PerspectiveCamera from odak.learn.tools defines the view with rotation, translation, focal length, and principal point.
  4. Render – Scene.render() performs depth-sorted tile-based splatting followed by band-limited angular spectrum propagation to produce a complex hologram.
  5. Analyse – odak.learn.wave.calculate_amplitude and odak.learn.wave.calculate_phase extract the amplitude and phase from the complex field.

Let us also examine the key classes provided in odak for this pipeline.

Bases: Module

Complex-valued 3-D Gaussian primitives for holographic rendering.

Each Gaussian is parameterised by a 3-D mean, a rotation quaternion, log-scales, per-channel colour amplitudes, per-channel phases, opacity, and a discrete plane-assignment vector.

Parameters:

  • init_type (str) –
               One of ``"gaussians"`` (load from checkpoint),
               ``"random"`` (random initialisation), or
               ``"point"`` (from a point cloud).
    
  • device (str) –
               Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
    
  • load_path (Optional[str], default: None ) –
               Path to a ``.pth`` checkpoint (required when
               ``init_type="gaussians"``).
    
  • num_points (Optional[int], default: None ) –
               Number of Gaussians (required when
               ``init_type="random"``).
    
  • args_prop (Namespace, default: None ) –
               Must contain at least ``num_planes``.
    
  • pointcloud_data (Optional[dict], default: None ) –
               ``{"positions": Tensor, "colors": Tensor}``
               (required when ``init_type="point"``).
    
  • generate_dense_point (int, default: False ) –
               Number of densification rounds (default: ``0``).
    
  • densepoint_scatter
               Standard deviation of the densification noise
               (default: ``0.01``).
    
  • img_size
               Image size for random init hints.
    
Source code in odak/learn/wave/complex_gaussians.py
class Gaussians(torch.nn.Module):
    """
    Complex-valued 3-D Gaussian primitives for holographic rendering.

    Each Gaussian is parameterised by a 3-D mean, a rotation quaternion,
    log-scales, per-channel colour amplitudes, per-channel phases,
    opacity, and a discrete plane-assignment vector.

    Parameters
    ----------
    init_type        : str
                       One of ``"gaussians"`` (load from checkpoint),
                       ``"random"`` (random initialisation), or
                       ``"point"`` (from a point cloud).
    device           : str
                       Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
    load_path        : str or None, optional
                       Path to a ``.pth`` checkpoint (required when
                       ``init_type="gaussians"``).
    num_points       : int or None, optional
                       Number of Gaussians (required when
                       ``init_type="random"``).
    args_prop        : argparse.Namespace
                       Must contain at least ``num_planes``.
    pointcloud_data  : dict or None, optional
                       ``{"positions": Tensor, "colors": Tensor}``
                       (required when ``init_type="point"``).
    generate_dense_point : int, optional
                       Number of densification rounds (default: ``0``).
    densepoint_scatter   : float, optional
                       Standard deviation of the densification noise
                       (default: ``0.01``).
    img_size         : tuple or None, optional
                       Image size for random init hints.
    """

    def __init__(
        self,
        init_type: str,
        device: str,
        load_path: Optional[str] = None,
        num_points: Optional[int] = None,
        args_prop: Namespace = None,
        pointcloud_data: Optional[dict] = None,
        generate_dense_point=False,
        densepoint_scatter=0.01,
        img_size=None,
    ):
        super(Gaussians, self).__init__()

        self.device = device
        self.num_planes = args_prop.num_planes
        self.generate_dense_point = generate_dense_point
        self.densepoint_scatter = densepoint_scatter
        self.NEAR_PLANE = 1.0
        self.FAR_PLANE = 1000.0

        if init_type == "gaussians":
            if load_path is None:
                raise ValueError("load_path is required for init_type='gaussians'")
            data = self._load_gaussians(load_path)

        elif init_type == "random":
            if num_points is None:
                raise ValueError("num_points is required for init_type='random'")
            data = self._load_random(num_points, img_size)

        elif init_type == "point":
            if pointcloud_data is None:
                raise ValueError("pointcloud_data is required for init_type='point'")
            self.is_outdoor = args_prop.is_outdoor
            data = self._load_point(pointcloud_data)

        else:
            raise ValueError(f"Invalid init_type: {init_type}")

        self.register_parameter(
            "pre_act_quats",
            torch.nn.Parameter(data["pre_act_quats"], requires_grad=False),
        )
        self.register_parameter(
            "means", torch.nn.Parameter(data["means"], requires_grad=False)
        )
        self.register_parameter(
            "pre_act_scales",
            torch.nn.Parameter(data["pre_act_scales"], requires_grad=False),
        )
        self.register_parameter(
            "colours", torch.nn.Parameter(data["colours"], requires_grad=False)
        )
        self.register_parameter(
            "pre_act_phase",
            torch.nn.Parameter(data["pre_act_phase"], requires_grad=False),
        )
        self.register_parameter(
            "pre_act_opacities",
            torch.nn.Parameter(data["pre_act_opacities"], requires_grad=False),
        )
        self.register_parameter(
            "pre_act_plane_assignment",
            torch.nn.Parameter(data["pre_act_plane_assignment"], requires_grad=False),
        )
        self.to(self.device)

    def __len__(self):
        return len(self.means)

    def _load_gaussians(self, ply_path: str):
        if ply_path.endswith(".pth"):
            checkpoint = torch.load(ply_path, map_location="cpu", weights_only=False)
            data = {
                "pre_act_quats": checkpoint["pre_act_quats"]
                .clone()
                .detach()
                .to(torch.float32)
                .contiguous(),
                "means": checkpoint["means"]
                .clone()
                .detach()
                .to(torch.float32)
                .contiguous(),
                "pre_act_scales": checkpoint["pre_act_scales"]
                .clone()
                .detach()
                .to(torch.float32)
                .contiguous(),
                "colours": checkpoint["colours"]
                .clone()
                .detach()
                .to(torch.float32)
                .contiguous(),
                "pre_act_phase": checkpoint["pre_act_phase"]
                .clone()
                .detach()
                .to(torch.float32)
                .contiguous(),
                "pre_act_opacities": checkpoint["pre_act_opacities"]
                .clone()
                .detach()
                .to(torch.float32)
                .contiguous(),
                "pre_act_plane_assignment": checkpoint["pre_act_plane_assignment"]
                .clone()
                .detach()
                .to(torch.float32)
                .contiguous(),
            }
            num = len(data["means"])
            print(f"Loaded Gaussians {num} from checkpoint: {ply_path}")
            return data

    def _load_random(self, num_points: int, image_size=None):
        data = dict()
        means = (torch.rand((num_points, 3)) * 2 - 1).to(torch.float32) * 15.7
        data["means"] = means.to(torch.float32)
        data["colours"] = torch.rand((num_points, 3), dtype=torch.float32)
        quats_norm = torch.randn((num_points, 4), dtype=torch.float32)
        quats_norm = F.normalize(quats_norm, dim=1)
        quats = torch.zeros((num_points, 4), dtype=torch.float32)
        quats[:, 0] = 1.0
        data["pre_act_quats"] = quats + quats_norm * 0.01
        data["pre_act_scales"] = torch.log(
            (torch.rand((num_points, 1), dtype=torch.float32) + 1e-6) * 0.01
        )
        data["pre_act_scales"] = data["pre_act_scales"].repeat(1, 3)
        data["pre_act_phase"] = torch.randn((num_points, 3), dtype=torch.float32)
        data["pre_act_opacities"] = torch.ones((num_points,), dtype=torch.float32)
        data["pre_act_plane_assignment"] = (
            torch.randn((num_points, self.num_planes), dtype=torch.float32) * 10.0
        )

        print(
            f"Loaded Randomly {num_points} gaussians with image size {image_size if image_size else 'default'}"
        )
        return data

    def _load_point(self, pointcloud_data: dict) -> dict:
        positions = pointcloud_data["positions"]
        colors = pointcloud_data["colors"]
        data = {}

        centre = positions.mean(dim=0, keepdim=True)
        distances = torch.norm(positions - centre, dim=1)

        if self.is_outdoor:
            num_points = positions.shape[0]
            num_points_to_keep = int(num_points * 0.98)
            sorted_indices = torch.argsort(distances)
            keep_indices = sorted_indices[:num_points_to_keep]
            print(f"Keeping {num_points_to_keep} points from {num_points} points")
            positions = positions[keep_indices]
            colors = colors[keep_indices]

        if self.generate_dense_point > 0:
            orig_positions = positions
            orig_colors = colors
            for _ in range(self.generate_dense_point):
                offset = torch.randn_like(orig_positions) * self.densepoint_scatter
                positions = torch.cat([positions, orig_positions + offset], dim=0)
                colors = torch.cat([colors, orig_colors], dim=0)

        if self.is_outdoor:
            divide = self.generate_dense_point if self.generate_dense_point > 0 else 1
            divide = 1
            bg_count = int(positions.shape[0] * (0.8 / divide))
            print(f"randomize {bg_count} points for outdoor scene")
            centre = positions.mean(dim=0, keepdim=True)
            centred = positions - centre
            max_dist = torch.norm(centred, dim=1).max().item()

            cov = torch.matmul(centred.T, centred) / centred.shape[0]
            eigenvalues, eigenvectors = torch.linalg.eigh(cov)
            sorted_indices = torch.argsort(eigenvalues, descending=True)
            eigenvectors = eigenvectors[:, sorted_indices]

            main_axis = eigenvectors[:, 0]
            up_direction = eigenvectors[:, 1]
            side_direction = eigenvectors[:, 2]

            pole_threshold = 0.35
            valid_directions = []
            batch_size = bg_count * 3

            while len(valid_directions) < bg_count:
                directions_batch = F.normalize(
                    torch.randn(batch_size, 3, device=positions.device), dim=1
                )
                valid_mask = torch.abs(directions_batch[:, 2]) <= pole_threshold
                valid_batch = directions_batch[valid_mask]
                valid_directions.append(valid_batch)
                all_valid = torch.cat(valid_directions, dim=0)
                if all_valid.size(0) >= bg_count:
                    directions_standard = all_valid[:bg_count]
                    break

            rotation_matrix = torch.stack(
                [main_axis, up_direction, side_direction], dim=1
            )
            directions = torch.matmul(directions_standard, rotation_matrix.T)
            radii = torch.empty(bg_count, device=positions.device).uniform_(
                max_dist * 0.5, max_dist * 0.8
            )
            bg_positions = directions * radii.unsqueeze(1) + centre
            bg_colors = torch.rand((bg_count, 3), dtype=torch.float32).to(
                positions.device
            )
            positions = torch.cat([positions, bg_positions.to(positions.device)], dim=0)
            colors = torch.cat([colors, bg_colors], dim=0)

        total_points = positions.shape[0]
        print(f"Total points in original point cloud: {total_points}")

        data["means"] = positions.to(torch.float32).contiguous()
        data["colours"] = colors.to(torch.float32).contiguous()

        quats_norm = F.normalize(
            torch.randn((total_points, 4), dtype=torch.float32), dim=1
        )
        quats = torch.zeros((total_points, 4), dtype=torch.float32)
        quats[:, 0] = 1.0
        data["pre_act_quats"] = quats + quats_norm * 0.01

        scales = torch.log(
            (torch.rand((total_points, 1), dtype=torch.float32) + 1e-6) * 0.01
        )
        data["pre_act_scales"] = scales.repeat(1, 3)
        data["pre_act_phase"] = torch.randn((total_points, 3), dtype=torch.float32)
        data["pre_act_opacities"] = torch.ones(total_points, dtype=torch.float32)
        data["pre_act_plane_assignment"] = (
            torch.randn((total_points, self.num_planes), dtype=torch.float32) * 10.0
        )

        print(f"Initialized {total_points} Gaussians from point cloud data")
        return data

    def check_if_trainable(self):
        """Raise an exception if any learnable parameter has ``requires_grad=False``."""
        attrs = [
            "means",
            "pre_act_scales",
            "colours",
            "pre_act_phase",
            "pre_act_opacities",
            "pre_act_plane_assignment",
            "pre_act_quats",
        ]
        for attr in attrs:
            param = getattr(self, attr)
            if not getattr(param, "requires_grad", False):
                raise Exception(
                    "Please use function make_trainable to make parameters trainable"
                )

    def compute_cov_3D(self, quats: torch.Tensor, scales: torch.Tensor):
        """
        Compute 3-D covariance matrices from quaternions and scales.

        Parameters
        ----------
        quats  : torch.Tensor
                 Unit quaternions ``(N, 4)`` in ``(w, x, y, z)`` convention.
        scales : torch.Tensor
                 Scale vectors ``(N, 3)``.

        Returns
        -------
        cov_3D : torch.Tensor
                 Covariance matrices ``(N, 3, 3)``.
        """
        Is = torch.eye(scales.size(1), device=scales.device)
        scale_mats = (scales.unsqueeze(2).expand(*scales.size(), scales.size(1))) * Is
        rots = quaternion_to_rotation_matrix(quats)
        cov_3D = torch.matmul(rots, scale_mats)
        cov_3D = torch.matmul(cov_3D, torch.transpose(scale_mats, 1, 2))
        cov_3D = torch.matmul(cov_3D, torch.transpose(rots, 1, 2))
        return cov_3D

    def _compute_jacobian(self, cam_means_3D: torch.Tensor, fx, fy, img_size: Tuple):
        """
        Compute the Jacobian matrix for the perspective projection.

        Parameters
        ----------
        cam_means_3D : torch.Tensor
                       Camera-space 3-D means ``(N, 3)``.
        fx, fy       : float or torch.Tensor
                       Focal lengths in pixels.
        img_size     : tuple of int
                       ``(W, H)`` image dimensions.

        Returns
        -------
        J : torch.Tensor
            Jacobian matrices ``(N, 2, 3)``.
        """
        W, H = img_size
        half_tan_fov_x = 0.5 * W / fx
        half_tan_fov_y = 0.5 * H / fy

        tx = cam_means_3D[:, 0]
        ty = cam_means_3D[:, 1]
        tz = cam_means_3D[:, 2]
        tz2 = tz * tz

        clipping_mask = (tz > self.NEAR_PLANE) & (tz < self.FAR_PLANE)

        lim_x = 1.3 * half_tan_fov_x
        lim_y = 1.3 * half_tan_fov_y

        tx = torch.clamp(tx / tz, -lim_x, lim_x) * tz
        ty = torch.clamp(ty / tz, -lim_y, lim_y) * tz

        J = torch.zeros((len(tx), 2, 3), device=cam_means_3D.device)
        J[:, 0, 0] = fx / tz
        J[:, 1, 1] = fy / tz
        J[:, 0, 2] = -(fx * tx) / tz2
        J[:, 1, 2] = -(fy * ty) / tz2

        clipping_mask = clipping_mask.to(torch.float32).view(-1, 1, 1)
        J = J * clipping_mask

        return J

    def compute_cov_2D(
        self,
        cam_means_3D: torch.Tensor,
        quats: torch.Tensor,
        scales: torch.Tensor,
        fx,
        fy,
        R,
        img_size: Tuple,
    ):
        """
        Compute 2-D projected covariance matrices (Eq. 5 of 3DGS paper).

        Parameters
        ----------
        cam_means_3D : torch.Tensor
                       Camera-space means ``(N, 3)``.
        quats        : torch.Tensor
                       Quaternions ``(N, 4)``.
        scales       : torch.Tensor
                       Scales ``(N, 3)``.
        fx, fy       : float or torch.Tensor
                       Focal lengths.
        R            : torch.Tensor
                       View rotation matrix.
        img_size     : tuple of int
                       ``(W, H)``.

        Returns
        -------
        cov_2D : torch.Tensor
                 2-D covariance matrices ``(N, 2, 2)``.

        References
        ----------
        Kerbl, B. et al. "3D Gaussian Splatting for Real-Time Radiance Field
        Rendering." *SIGGRAPH 2023*.
        """
        J = self._compute_jacobian(cam_means_3D, fx, fy, img_size)
        N = J.shape[0]

        W = R.repeat(N, 1, 1)
        cov_3D = self.compute_cov_3D(quats, scales)

        cov_2D = torch.matmul(J, W)
        cov_2D = torch.matmul(cov_2D, cov_3D)
        cov_2D = torch.matmul(cov_2D, torch.transpose(W, 1, 2))
        cov_2D = torch.matmul(cov_2D, torch.transpose(J, 1, 2))

        cov_2D[:, 0, 0] += 0.3
        cov_2D[:, 1, 1] += 0.3

        return cov_2D

    def compute_means_2D(self, cam_means_3D: torch.Tensor, fx, fy, px, py):
        """
        Project 3-D camera-space points to 2-D pixel coordinates.

        Parameters
        ----------
        cam_means_3D : torch.Tensor
                       Camera-space means ``(N, 3)``.
        fx, fy       : float or torch.Tensor
                       Focal lengths.
        px, py       : float or torch.Tensor
                       Principal-point offsets.

        Returns
        -------
        means_2D : torch.Tensor
                   2-D pixel coordinates ``(N, 2)``.
        """
        clipping_mask = (cam_means_3D[:, 2] > self.NEAR_PLANE) & (
            cam_means_3D[:, 2] < self.FAR_PLANE
        )

        inv_z = 1.0 / cam_means_3D[:, 2].unsqueeze(1)
        cam_means_3D_xy = -cam_means_3D[:, :2] * inv_z

        means_2D = torch.empty((cam_means_3D.shape[0], 2), device=cam_means_3D.device)
        means_2D[:, 0] = fx * cam_means_3D_xy[:, 0] + px
        means_2D[:, 1] = fy * cam_means_3D_xy[:, 1] + py

        large_value = 1e6
        means_2D[~clipping_mask] = large_value

        return means_2D

    @staticmethod
    def invert_cov_2D(cov_2D: torch.Tensor):
        """
        Invert 2×2 covariance matrices.

        Parameters
        ----------
        cov_2D : torch.Tensor
                 Covariance matrices ``(N, 2, 2)``.

        Returns
        -------
        cov_2D_inverse : torch.Tensor
                         Inverse covariance matrices ``(N, 2, 2)``.
        """
        determinants = (
            cov_2D[:, 0, 0] * cov_2D[:, 1, 1] - cov_2D[:, 1, 0] * cov_2D[:, 0, 1]
        )
        determinants = determinants[:, None, None]

        cov_2D_inverse = torch.zeros_like(cov_2D)
        cov_2D_inverse[:, 0, 0] = cov_2D[:, 1, 1]
        cov_2D_inverse[:, 1, 1] = cov_2D[:, 0, 0]
        cov_2D_inverse[:, 0, 1] = -1.0 * cov_2D[:, 0, 1]
        cov_2D_inverse[:, 1, 0] = -1.0 * cov_2D[:, 1, 0]

        cov_2D_inverse = (1.0 / determinants) * cov_2D_inverse
        return cov_2D_inverse

    @staticmethod
    def calculate_gaussian_bounds(means_2D, cov_2D, img_size, confidence=3.0):
        """
        Compute axis-aligned bounding boxes from 2-D covariance.

        Parameters
        ----------
        means_2D   : torch.Tensor
                     2-D positions ``(N, 2)``.
        cov_2D     : torch.Tensor
                     Covariance matrices ``(N, 2, 2)``.
        img_size   : tuple of int
                     ``(W, H)``.
        confidence : float, optional
                     Number of standard deviations (default: ``3.0``).

        Returns
        -------
        bounds : torch.Tensor
                 ``(N, 4)`` with ``[min_x, min_y, max_x, max_y]``.
        """
        var_x = cov_2D[:, 0, 0]
        var_y = cov_2D[:, 1, 1]

        std_x = torch.sqrt(var_x)
        std_y = torch.sqrt(var_y)

        radius_x = confidence * std_x
        radius_y = confidence * std_y

        min_x = means_2D[:, 0] - radius_x
        min_y = means_2D[:, 1] - radius_y
        max_x = means_2D[:, 0] + radius_x
        max_y = means_2D[:, 1] + radius_y

        W, H = img_size
        min_x = torch.clamp(min_x, 0, W - 1)
        min_y = torch.clamp(min_y, 0, H - 1)
        max_x = torch.clamp(max_x, 0, W - 1)
        max_y = torch.clamp(max_y, 0, H - 1)

        bounds = torch.stack([min_x, min_y, max_x, max_y], dim=1)
        return bounds

    @staticmethod
    def apply_activations(
        pre_act_quats,
        pre_act_scales,
        pre_act_phase=None,
        pre_act_opacities=None,
        pre_act_plane_assignment=None,
        step=None,
        max_step=None,
    ):
        """
        Apply non-linear activations to raw Gaussian parameters.

        Parameters
        ----------
        pre_act_quats            : torch.Tensor
        pre_act_scales           : torch.Tensor
        pre_act_phase            : torch.Tensor or None
        pre_act_opacities        : torch.Tensor or None
        pre_act_plane_assignment : torch.Tensor or None
        step, max_step           : int or None

        Returns
        -------
        quats, scales, phase, opacities, plane_probs : torch.Tensor
        """
        scales = torch.exp(pre_act_scales)
        quats = F.normalize(pre_act_quats)
        phase = pre_act_phase % (2.0 * odak.pi)
        opacities = torch.sigmoid(pre_act_opacities)

        ste = StraightThroughEstimator()
        plane_probs = ste(pre_act_plane_assignment)

        return quats, scales, phase, opacities, plane_probs

    def save_gaussians(self, save_path: str):
        """
        Save Gaussian parameters to a ``.pth`` checkpoint.

        Parameters
        ----------
        save_path : str
                    Destination file path.
        """
        state_dict = {
            "pre_act_quats": self.pre_act_quats.cpu(),
            "means": self.means.cpu(),
            "pre_act_scales": self.pre_act_scales.cpu(),
            "colours": self.colours.cpu(),
            "pre_act_phase": self.pre_act_phase.cpu(),
            "pre_act_opacities": self.pre_act_opacities.cpu(),
            "pre_act_plane_assignment": self.pre_act_plane_assignment.cpu(),
        }
        torch.save(state_dict, save_path)
        print(f"Gaussians saved to {save_path}")

apply_activations(pre_act_quats, pre_act_scales, pre_act_phase=None, pre_act_opacities=None, pre_act_plane_assignment=None, step=None, max_step=None) staticmethod

Apply non-linear activations to raw Gaussian parameters.

Parameters:

  • pre_act_quats (Tensor) –
        Raw quaternions, normalised to unit length.
    
  • pre_act_scales (Tensor) –
        Log-scales, mapped through ``exp`` to positive values.
    
  • pre_act_phase (Tensor or None, default: None ) –
        Raw phases, wrapped to ``[0, 2π)``.
    
  • pre_act_opacities (Tensor or None, default: None ) –
        Raw opacities, mapped through ``sigmoid``.
    
  • pre_act_plane_assignment (Tensor or None, default: None ) –
        Plane-assignment logits, discretised with a straight-through estimator.
    
  • step, max_step (int or None, default: None ) –
        Training-step counters for scheduled activations.

Returns:

  • quats, scales, phase, opacities, plane_probs ( Tensor ) –

    Activated Gaussian parameters.
Source code in odak/learn/wave/complex_gaussians.py
@staticmethod
def apply_activations(
    pre_act_quats,
    pre_act_scales,
    pre_act_phase=None,
    pre_act_opacities=None,
    pre_act_plane_assignment=None,
    step=None,
    max_step=None,
):
    """
    Apply non-linear activations to raw Gaussian parameters.

    Parameters
    ----------
    pre_act_quats            : torch.Tensor
    pre_act_scales           : torch.Tensor
    pre_act_phase            : torch.Tensor or None
    pre_act_opacities        : torch.Tensor or None
    pre_act_plane_assignment : torch.Tensor or None
    step, max_step           : int or None

    Returns
    -------
    quats, scales, phase, opacities, plane_probs : torch.Tensor
    """
    scales = torch.exp(pre_act_scales)
    quats = F.normalize(pre_act_quats)
    phase = pre_act_phase % (2.0 * odak.pi)
    opacities = torch.sigmoid(pre_act_opacities)

    ste = StraightThroughEstimator()
    plane_probs = ste(pre_act_plane_assignment)

    return quats, scales, phase, opacities, plane_probs
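As a quick sanity check, the per-parameter activations can be reproduced in a standalone sketch (plain PyTorch, omitting the straight-through estimator for plane assignments):

```python
import math

import torch
import torch.nn.functional as F

# Raw (pre-activation) parameters for a single Gaussian.
pre_act_scales = torch.zeros(1, 3)                    # exp(0) gives a scale of 1
pre_act_quats = torch.tensor([[2.0, 0.0, 0.0, 0.0]])  # non-unit quaternion
pre_act_phase = torch.tensor([[7.0, 7.0, 7.0]])       # outside [0, 2*pi)
pre_act_opacities = torch.tensor([0.0])               # sigmoid(0) gives 0.5

scales = torch.exp(pre_act_scales)            # strictly positive scales
quats = F.normalize(pre_act_quats)            # unit quaternions
phase = pre_act_phase % (2.0 * math.pi)       # phase wrapped to [0, 2*pi)
opacities = torch.sigmoid(pre_act_opacities)  # opacities in (0, 1)
```

Each map is chosen so that unconstrained optimiser updates always yield valid Gaussian parameters (positive scales, unit quaternions, bounded opacities).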

calculate_gaussian_bounds(means_2D, cov_2D, img_size, confidence=3.0) staticmethod

Compute axis-aligned bounding boxes from 2-D covariance.

Parameters:

  • means_2D (Tensor) –
         2-D positions ``(N, 2)``.
    
  • cov_2D (Tensor) –
         Covariance matrices ``(N, 2, 2)``.
    
  • img_size (tuple of int) –
         ``(W, H)``.
    
  • confidence (float, default: 3.0 ) –
         Number of standard deviations (default: ``3.0``).
    

Returns:

  • bounds ( Tensor ) –

    (N, 4) with [min_x, min_y, max_x, max_y].

Source code in odak/learn/wave/complex_gaussians.py
@staticmethod
def calculate_gaussian_bounds(means_2D, cov_2D, img_size, confidence=3.0):
    """
    Compute axis-aligned bounding boxes from 2-D covariance.

    Parameters
    ----------
    means_2D   : torch.Tensor
                 2-D positions ``(N, 2)``.
    cov_2D     : torch.Tensor
                 Covariance matrices ``(N, 2, 2)``.
    img_size   : tuple of int
                 ``(W, H)``.
    confidence : float, optional
                 Number of standard deviations (default: ``3.0``).

    Returns
    -------
    bounds : torch.Tensor
             ``(N, 4)`` with ``[min_x, min_y, max_x, max_y]``.
    """
    var_x = cov_2D[:, 0, 0]
    var_y = cov_2D[:, 1, 1]

    std_x = torch.sqrt(var_x)
    std_y = torch.sqrt(var_y)

    radius_x = confidence * std_x
    radius_y = confidence * std_y

    min_x = means_2D[:, 0] - radius_x
    min_y = means_2D[:, 1] - radius_y
    max_x = means_2D[:, 0] + radius_x
    max_y = means_2D[:, 1] + radius_y

    W, H = img_size
    min_x = torch.clamp(min_x, 0, W - 1)
    min_y = torch.clamp(min_y, 0, H - 1)
    max_x = torch.clamp(max_x, 0, W - 1)
    max_y = torch.clamp(max_y, 0, H - 1)

    bounds = torch.stack([min_x, min_y, max_x, max_y], dim=1)
    return bounds
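A worked example with hand-picked numbers (a standalone sketch, not odak's API): one Gaussian centred at pixel (100, 50) with variances 4 and 9, i.e. standard deviations 2 and 3, on a 256 × 256 image.

```python
import torch

means_2D = torch.tensor([[100.0, 50.0]])
cov_2D = torch.tensor([[[4.0, 0.0],
                        [0.0, 9.0]]])
confidence = 3.0
W, H = 256, 256

radius_x = confidence * cov_2D[:, 0, 0].sqrt()  # 3 * 2 = 6 pixels
radius_y = confidence * cov_2D[:, 1, 1].sqrt()  # 3 * 3 = 9 pixels

# Axis-aligned box spanning +/- confidence standard deviations, clamped to the image.
bounds = torch.stack([
    (means_2D[:, 0] - radius_x).clamp(0, W - 1),
    (means_2D[:, 1] - radius_y).clamp(0, H - 1),
    (means_2D[:, 0] + radius_x).clamp(0, W - 1),
    (means_2D[:, 1] + radius_y).clamp(0, H - 1),
], dim=1)
```

The resulting box ``[94, 41, 106, 59]`` is what the tile loop in ``splat`` later intersects against each tile to cull Gaussians.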

check_if_trainable()

Raise an exception if any learnable parameter has requires_grad=False.

Source code in odak/learn/wave/complex_gaussians.py
def check_if_trainable(self):
    """Raise an exception if any learnable parameter has ``requires_grad=False``."""
    attrs = [
        "means",
        "pre_act_scales",
        "colours",
        "pre_act_phase",
        "pre_act_opacities",
        "pre_act_plane_assignment",
        "pre_act_quats",
    ]
    for attr in attrs:
        param = getattr(self, attr)
        if not getattr(param, "requires_grad", False):
            raise Exception(
                "Please use function make_trainable to make parameters trainable"
            )

compute_cov_2D(cam_means_3D, quats, scales, fx, fy, R, img_size)

Compute 2-D projected covariance matrices (Eq. 5 of 3DGS paper).

Parameters:

  • cam_means_3D (Tensor) –
           Camera-space means ``(N, 3)``.
    
  • quats (Tensor) –
           Quaternions ``(N, 4)``.
    
  • scales (Tensor) –
           Scales ``(N, 3)``.
    
  • fx (float or Tensor) –
           Focal length along x.
    
  • fy (float or Tensor) –
           Focal length along y.
    
  • R (Tensor) –
           View rotation matrix.
    
  • img_size (Tuple) –
           ``(W, H)``.
    

Returns:

  • cov_2D ( Tensor ) –

    2-D covariance matrices (N, 2, 2).

References

Kerbl, B. et al. "3D Gaussian Splatting for Real-Time Radiance Field Rendering." SIGGRAPH 2023.

Source code in odak/learn/wave/complex_gaussians.py
def compute_cov_2D(
    self,
    cam_means_3D: torch.Tensor,
    quats: torch.Tensor,
    scales: torch.Tensor,
    fx,
    fy,
    R,
    img_size: Tuple,
):
    """
    Compute 2-D projected covariance matrices (Eq. 5 of 3DGS paper).

    Parameters
    ----------
    cam_means_3D : torch.Tensor
                   Camera-space means ``(N, 3)``.
    quats        : torch.Tensor
                   Quaternions ``(N, 4)``.
    scales       : torch.Tensor
                   Scales ``(N, 3)``.
    fx, fy       : float or torch.Tensor
                   Focal lengths.
    R            : torch.Tensor
                   View rotation matrix.
    img_size     : tuple of int
                   ``(W, H)``.

    Returns
    -------
    cov_2D : torch.Tensor
             2-D covariance matrices ``(N, 2, 2)``.

    References
    ----------
    Kerbl, B. et al. "3D Gaussian Splatting for Real-Time Radiance Field
    Rendering." *SIGGRAPH 2023*.
    """
    J = self._compute_jacobian(cam_means_3D, fx, fy, img_size)
    N = J.shape[0]

    W = R.repeat(N, 1, 1)
    cov_3D = self.compute_cov_3D(quats, scales)

    cov_2D = torch.matmul(J, W)
    cov_2D = torch.matmul(cov_2D, cov_3D)
    cov_2D = torch.matmul(cov_2D, torch.transpose(W, 1, 2))
    cov_2D = torch.matmul(cov_2D, torch.transpose(J, 1, 2))

    cov_2D[:, 0, 0] += 0.3
    cov_2D[:, 1, 1] += 0.3

    return cov_2D
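For a Gaussian on the optical axis the projection simplifies and can be checked by hand. With identity rotation, the perspective Jacobian of the projection reduces to ``J = diag(±fx/z, ±fy/z)`` (the sign from the negated projection cancels in ``J Σ Jᵀ``), so the 2-D covariance is ``diag(fx² σx² / z², fy² σy² / z²)`` plus the 0.3-pixel low-pass floor added above. A standalone sketch — the diagonal Jacobian is an assumption that only holds on-axis:

```python
import torch

fx, fy, z = 100.0, 100.0, 2.0
sigma = torch.tensor([0.1, 0.2, 0.5])    # 3-D standard deviations
cov_3D = torch.diag(sigma ** 2)          # identity rotation, diagonal covariance

# On-axis perspective Jacobian (2 x 3); the depth derivatives vanish at x = y = 0.
J = torch.tensor([[fx / z, 0.0, 0.0],
                  [0.0, fy / z, 0.0]])

cov_2D = J @ cov_3D @ J.T
cov_2D = cov_2D + 0.3 * torch.eye(2)     # low-pass floor, as in compute_cov_2D
```

Note how the depth scale ``σz = 0.5`` drops out entirely: depth extent does not widen an on-axis splat, only ``σx`` and ``σy`` scaled by ``f/z`` do.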

compute_cov_3D(quats, scales)

Compute 3-D covariance matrices from quaternions and scales.

Parameters:

  • quats (Tensor) –
     Unit quaternions ``(N, 4)`` in ``(w, x, y, z)`` convention.
    
  • scales (Tensor) –
     Scale vectors ``(N, 3)``.
    

Returns:

  • cov_3D ( Tensor ) –

    Covariance matrices (N, 3, 3).

Source code in odak/learn/wave/complex_gaussians.py
def compute_cov_3D(self, quats: torch.Tensor, scales: torch.Tensor):
    """
    Compute 3-D covariance matrices from quaternions and scales.

    Parameters
    ----------
    quats  : torch.Tensor
             Unit quaternions ``(N, 4)`` in ``(w, x, y, z)`` convention.
    scales : torch.Tensor
             Scale vectors ``(N, 3)``.

    Returns
    -------
    cov_3D : torch.Tensor
             Covariance matrices ``(N, 3, 3)``.
    """
    Is = torch.eye(scales.size(1), device=scales.device)
    scale_mats = (scales.unsqueeze(2).expand(*scales.size(), scales.size(1))) * Is
    rots = quaternion_to_rotation_matrix(quats)
    cov_3D = torch.matmul(rots, scale_mats)
    cov_3D = torch.matmul(cov_3D, torch.transpose(scale_mats, 1, 2))
    cov_3D = torch.matmul(cov_3D, torch.transpose(rots, 1, 2))
    return cov_3D
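With the identity quaternion ``(1, 0, 0, 0)`` the rotation matrix is the identity, and the factorisation ``Σ = R S Sᵀ Rᵀ`` used above collapses to a diagonal of squared scales — a quick standalone check:

```python
import torch

scales = torch.tensor([[2.0, 3.0, 4.0]])  # (N, 3)
R = torch.eye(3).unsqueeze(0)             # identity rotation, i.e. quat (1, 0, 0, 0)

S = torch.diag_embed(scales)              # (N, 3, 3) diagonal scale matrices
cov_3D = R @ S @ S.transpose(1, 2) @ R.transpose(1, 2)
```

The ``R S Sᵀ Rᵀ`` form guarantees that the learned covariance is symmetric positive semi-definite for any quaternion and scale values, which a directly parameterised 3 × 3 matrix would not.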

compute_means_2D(cam_means_3D, fx, fy, px, py)

Project 3-D camera-space points to 2-D pixel coordinates.

Parameters:

  • cam_means_3D (Tensor) –
           Camera-space means ``(N, 3)``.
    
  • fx (float or Tensor) –
           Focal length along x.
    
  • fy (float or Tensor) –
           Focal length along y.
    
  • px (float or Tensor) –
           Principal-point offset along x.
    
  • py (float or Tensor) –
           Principal-point offset along y.
    

Returns:

  • means_2D ( Tensor ) –

    2-D pixel coordinates (N, 2).

Source code in odak/learn/wave/complex_gaussians.py
def compute_means_2D(self, cam_means_3D: torch.Tensor, fx, fy, px, py):
    """
    Project 3-D camera-space points to 2-D pixel coordinates.

    Parameters
    ----------
    cam_means_3D : torch.Tensor
                   Camera-space means ``(N, 3)``.
    fx, fy       : float or torch.Tensor
                   Focal lengths.
    px, py       : float or torch.Tensor
                   Principal-point offsets.

    Returns
    -------
    means_2D : torch.Tensor
               2-D pixel coordinates ``(N, 2)``.
    """
    clipping_mask = (cam_means_3D[:, 2] > self.NEAR_PLANE) & (
        cam_means_3D[:, 2] < self.FAR_PLANE
    )

    inv_z = 1.0 / cam_means_3D[:, 2].unsqueeze(1)
    cam_means_3D_xy = -cam_means_3D[:, :2] * inv_z

    means_2D = torch.empty((cam_means_3D.shape[0], 2), device=cam_means_3D.device)
    means_2D[:, 0] = fx * cam_means_3D_xy[:, 0] + px
    means_2D[:, 1] = fy * cam_means_3D_xy[:, 1] + py

    large_value = 1e6
    means_2D[~clipping_mask] = large_value

    return means_2D
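A minimal numeric sketch of the projection (note the negated ``x/y`` before dividing by depth, matching the source above):

```python
import torch

fx, fy, px, py = 100.0, 100.0, 128.0, 128.0
cam_point = torch.tensor([[1.0, 0.5, 2.0]])  # camera-space (x, y, z)

xy = -cam_point[:, :2] / cam_point[:, 2:3]   # sign flip, as in compute_means_2D
u = fx * xy[:, 0] + px                       # 100 * (-0.5) + 128 = 78
v = fy * xy[:, 1] + py                       # 100 * (-0.25) + 128 = 103
```

Points outside the near/far depth range are not dropped but pushed to a large sentinel coordinate (``1e6``), so they fall outside every tile's bounds during splatting.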

invert_cov_2D(cov_2D) staticmethod

Invert 2×2 covariance matrices.

Parameters:

  • cov_2D (Tensor) –
     Covariance matrices ``(N, 2, 2)``.
    

Returns:

  • cov_2D_inverse ( Tensor ) –

    Inverse covariance matrices (N, 2, 2).

Source code in odak/learn/wave/complex_gaussians.py
@staticmethod
def invert_cov_2D(cov_2D: torch.Tensor):
    """
    Invert 2×2 covariance matrices.

    Parameters
    ----------
    cov_2D : torch.Tensor
             Covariance matrices ``(N, 2, 2)``.

    Returns
    -------
    cov_2D_inverse : torch.Tensor
                     Inverse covariance matrices ``(N, 2, 2)``.
    """
    determinants = (
        cov_2D[:, 0, 0] * cov_2D[:, 1, 1] - cov_2D[:, 1, 0] * cov_2D[:, 0, 1]
    )
    determinants = determinants[:, None, None]

    cov_2D_inverse = torch.zeros_like(cov_2D)
    cov_2D_inverse[:, 0, 0] = cov_2D[:, 1, 1]
    cov_2D_inverse[:, 1, 1] = cov_2D[:, 0, 0]
    cov_2D_inverse[:, 0, 1] = -1.0 * cov_2D[:, 0, 1]
    cov_2D_inverse[:, 1, 0] = -1.0 * cov_2D[:, 1, 0]

    cov_2D_inverse = (1.0 / determinants) * cov_2D_inverse
    return cov_2D_inverse
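The closed-form adjugate-over-determinant inverse used above can be verified against ``torch.linalg.inv`` in a standalone sketch:

```python
import torch

cov_2D = torch.tensor([[[4.0, 1.0],
                        [1.0, 3.0]]])

# 2x2 inverse via adjugate / determinant, mirroring invert_cov_2D.
det = cov_2D[:, 0, 0] * cov_2D[:, 1, 1] - cov_2D[:, 0, 1] * cov_2D[:, 1, 0]
adj = torch.zeros_like(cov_2D)
adj[:, 0, 0] = cov_2D[:, 1, 1]
adj[:, 1, 1] = cov_2D[:, 0, 0]
adj[:, 0, 1] = -cov_2D[:, 0, 1]
adj[:, 1, 0] = -cov_2D[:, 1, 0]
cov_2D_inverse = adj / det[:, None, None]
```

The closed form avoids a batched general-purpose solve for what is always a 2 × 2 matrix, which is considerably cheaper inside the per-tile splatting loop.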

save_gaussians(save_path)

Save Gaussian parameters to a .pth checkpoint.

Parameters:

  • save_path (str) –
        Destination file path.
    
Source code in odak/learn/wave/complex_gaussians.py
def save_gaussians(self, save_path: str):
    """
    Save Gaussian parameters to a ``.pth`` checkpoint.

    Parameters
    ----------
    save_path : str
                Destination file path.
    """
    state_dict = {
        "pre_act_quats": self.pre_act_quats.cpu(),
        "means": self.means.cpu(),
        "pre_act_scales": self.pre_act_scales.cpu(),
        "colours": self.colours.cpu(),
        "pre_act_phase": self.pre_act_phase.cpu(),
        "pre_act_opacities": self.pre_act_opacities.cpu(),
        "pre_act_plane_assignment": self.pre_act_plane_assignment.cpu(),
    }
    torch.save(state_dict, save_path)
    print(f"Gaussians saved to {save_path}")

Wave-based rendering scene for complex-valued Gaussian splatting.

Combines a set of :class:Gaussians with a camera model to produce holographic fields via tile-based splatting and band-limited angular-spectrum propagation.

Parameters:

  • gaussians (Gaussians) –
        The Gaussian primitives.
    
  • args_prop (Namespace) –
        Must contain ``wavelengths``, ``pixel_pitch``,
        ``distances``, ``pad_size``, and ``aperture_size``.
    
Source code in odak/learn/wave/complex_gaussians.py
class Scene:
    """
    Wave-based rendering scene for complex-valued Gaussian splatting.

    Combines a set of :class:`Gaussians` with a camera model to produce
    holographic fields via tile-based splatting and band-limited
    angular-spectrum propagation.

    Parameters
    ----------
    gaussians : Gaussians
                The Gaussian primitives.
    args_prop : argparse.Namespace
                Must contain ``wavelengths``, ``pixel_pitch``,
                ``distances``, ``pad_size``, and ``aperture_size``.
    """

    def __init__(self, gaussians: Gaussians, args_prop):
        self.gaussians = gaussians
        self.args_prop = args_prop
        self.device = self.gaussians.device
        self.wavelengths = torch.tensor(
            args_prop.wavelengths, dtype=torch.float32, device=self.device
        )
        self.mean_2D_for_planeprob = None

    def __repr__(self):
        return f"<Scene with {len(self.gaussians)} Gaussians>"

    def compute_transmittance(self, alphas: torch.Tensor):
        """
        Compute transmittance from per-Gaussian alpha values.

        Parameters
        ----------
        alphas : torch.Tensor
                 Alpha (opacity × Gaussian) values ``(N, H, W)``.

        Returns
        -------
        transmittance : torch.Tensor
                        Cumulative transmittance ``(N, H, W)``.
        """
        _, H, W = alphas.shape
        S = torch.ones((1, H, W), device=alphas.device, dtype=alphas.dtype)
        one_minus_alphas = 1.0 - alphas
        one_minus_alphas = torch.cat((S, one_minus_alphas), dim=0)
        transmittance = torch.cumprod(one_minus_alphas, dim=0)[:-1]
        transmittance = torch.where(transmittance < 1e-4, 0.0, transmittance)
        return transmittance

    def compute_depth_values(self, camera: PerspectiveCamera):
        """
        Compute per-Gaussian depth values in camera space.

        Parameters
        ----------
        camera : PerspectiveCamera

        Returns
        -------
        z_vals : torch.Tensor
                 Depth values ``(N,)``.
        """
        means_3D = self.gaussians.means
        R = camera.R[0] if camera.R.dim() == 3 else camera.R
        T = camera.T[0] if camera.T.dim() == 2 else camera.T
        means_cam = means_3D @ R + T
        z_vals = means_cam[:, -1]
        return z_vals

    def calculate_gaussian_directions(self, means_3D, camera):
        """
        Compute unit direction vectors from camera centre to each Gaussian.

        Parameters
        ----------
        means_3D : torch.Tensor
                   3-D positions ``(N, 3)``.
        camera   : PerspectiveCamera

        Returns
        -------
        gaussian_dirs : torch.Tensor
                        Unit direction vectors ``(N, 3)``.
        """
        N = means_3D.shape[0]
        camera_centers = camera.get_camera_center().repeat(N, 1)
        gaussian_dirs = means_3D - camera_centers
        gaussian_dirs = F.normalize(gaussian_dirs)
        return gaussian_dirs

    def get_idxs_to_filter_and_sort(self, z_vals: torch.Tensor):
        """
        Sort Gaussians by depth and filter those behind the camera.

        Parameters
        ----------
        z_vals : torch.Tensor
                 Depth values ``(N,)``.

        Returns
        -------
        idxs : torch.Tensor
               Sorted indices with ``z >= 0``.
        """
        sorted_vals, indices = torch.sort(z_vals)
        mask = sorted_vals >= 0
        idxs = torch.masked_select(indices, mask).to(torch.int64)
        return idxs

    def splat(
        self,
        camera: PerspectiveCamera,
        means_3D: torch.Tensor,
        z_vals: torch.Tensor,
        quats: torch.Tensor,
        scales: torch.Tensor,
        colours: torch.Tensor,
        phase: torch.Tensor,
        opacities: torch.Tensor,
        plane_probs: torch.Tensor,
        wavelengths: torch.Tensor,
        img_size: Tuple = (256, 256),
        tile_size: Tuple = (64, 64),
    ):
        """
        Multi-channel wave-based tile splatting and propagation.

        Parameters
        ----------
        camera       : PerspectiveCamera
        means_3D     : torch.Tensor ``(N, 3)``
        z_vals       : torch.Tensor ``(N,)``
        quats        : torch.Tensor ``(N, 4)``
        scales       : torch.Tensor ``(N, 3)``
        colours      : torch.Tensor ``(N, 3)``
        phase        : torch.Tensor ``(N, 3)``
        opacities    : torch.Tensor ``(N,)``
        plane_probs  : torch.Tensor ``(N, num_planes)``
        wavelengths  : torch.Tensor ``(C,)``
        img_size     : tuple of int
        tile_size    : tuple of int

        Returns
        -------
        hologram_complex : torch.Tensor
                           Complex hologram ``(C, H, W)``.
        plane_fields     : torch.Tensor
                           Per-plane fields ``(P, C, H, W)``.
        """
        W, H = img_size
        device = means_3D.device
        num_planes = plane_probs.shape[1]

        if isinstance(wavelengths, list):
            wavelengths = torch.tensor(wavelengths, device=device, dtype=torch.float32)

        R = camera.R
        fx, fy = camera.focal_length.flatten()
        px, py = camera.principal_point.flatten()

        if tile_size[0] <= 0 or tile_size[1] <= 0:
            tile_size = (64, 64)

        num_channels = len(wavelengths)

        cam_means_3D = camera.transform_world_to_camera_space(means_3D)

        means_2D = self.gaussians.compute_means_2D(cam_means_3D, fx, fy, px, py)
        self.mean_2D_for_planeprob = means_2D
        cov_2D = self.gaussians.compute_cov_2D(
            cam_means_3D, quats, scales, fx, fy, R, img_size
        )
        gaussian_bounds = self.gaussians.calculate_gaussian_bounds(
            means_2D, cov_2D, img_size
        )
        plane_fields = torch.zeros(
            (num_planes, num_channels, H, W), dtype=torch.complex64, device=device
        )

        tile_w, tile_h = tile_size
        x_tiles = math.ceil(W / tile_w)
        y_tiles = math.ceil(H / tile_h)

        for y_idx in range(y_tiles):
            for x_idx in range(x_tiles):
                x = x_idx * tile_w
                y = y_idx * tile_h
                actual_tile_w = min(tile_w, W - x)
                actual_tile_h = min(tile_h, H - y)
                x_min, y_min = x, y
                x_max = x + actual_tile_w - 1
                y_max = y + actual_tile_h - 1

                in_x_range = (gaussian_bounds[:, 0] <= x_max) & (
                    gaussian_bounds[:, 2] >= x_min
                )
                in_y_range = (gaussian_bounds[:, 1] <= y_max) & (
                    gaussian_bounds[:, 3] >= y_min
                )
                gaussian_indices = torch.where(in_x_range & in_y_range)[0]

                tile_plane_fields = self.splat_tile(
                    R,
                    fx,
                    fy,
                    px,
                    py,
                    cam_means_3D,
                    z_vals,
                    quats,
                    scales,
                    colours,
                    phase,
                    opacities,
                    plane_probs,
                    x,
                    y,
                    (actual_tile_w, actual_tile_h),
                    gaussian_indices,
                    img_size,
                    wavelengths,
                )
                plane_fields[
                    :, :, y : y + actual_tile_h, x : x + actual_tile_w
                ] += tile_plane_fields

        hologram_complex_planes = []
        for p in range(num_planes):
            plane_hologram = []
            for c, plane_field_c in enumerate(plane_fields[p]):
                wavelength_val = float(wavelengths[c].cpu().item())
                hologram_complex_c = _bandlimited_angular_spectrum_propagation(
                    plane_field_c,
                    wavelength=wavelength_val,
                    pixel_pitch=self.args_prop.pixel_pitch,
                    distance=-self.args_prop.distances[p],
                    size=self.args_prop.pad_size,
                    aperture_size=self.args_prop.aperture_size,
                )
                plane_hologram.append(hologram_complex_c)
            hologram_complex_planes.append(torch.stack(plane_hologram, dim=0))

        hologram_complex = sum(hologram_complex_planes)
        return hologram_complex, plane_fields

    def splat_tile(
        self,
        R,
        fx,
        fy,
        px,
        py,
        cam_means_3D,
        z_vals,
        quats,
        scales,
        colours,
        phase,
        opacities,
        plane_probs,
        tile_x,
        tile_y,
        tile_size,
        gaussian_indices,
        img_size,
        wavelengths,
    ):
        """
        Render a single tile for all planes (pure PyTorch).

        Parameters
        ----------
        R              : torch.Tensor
                         Rotation matrix.
        fx, fy, px, py : float or torch.Tensor
                         Camera intrinsics.
        cam_means_3D   : torch.Tensor ``(N, 3)``
        z_vals         : torch.Tensor ``(N,)``
        quats, scales, colours, phase, opacities : torch.Tensor
        plane_probs    : torch.Tensor ``(N, P)``
        tile_x, tile_y : int
        tile_size      : tuple of int
                         ``(tile_w, tile_h)`` for this tile.
        gaussian_indices : torch.Tensor
                         Indices of Gaussians overlapping this tile.
        img_size       : tuple of int
                         Full image ``(W, H)``.
        wavelengths    : torch.Tensor ``(C,)``

        Returns
        -------
        result : torch.Tensor
                 ``(P, C, tile_h, tile_w)`` complex field for this tile.
        """
        device = cam_means_3D.device
        W, H = img_size
        tile_w, tile_h = tile_size
        num_planes = plane_probs.shape[1]

        tile_plane_fields = []
        for _ in range(num_planes):
            tile_plane_fields.append(
                torch.zeros(
                    (len(wavelengths), tile_h, tile_w),
                    device=device,
                    dtype=torch.complex64,
                )
            )

        if gaussian_indices.numel() == 0:
            return torch.stack(tile_plane_fields, dim=0)

        xs, ys = torch.meshgrid(
            torch.arange(tile_x, tile_x + tile_w, device=device),
            torch.arange(tile_y, tile_y + tile_h, device=device),
            indexing="xy",
        )
        points_2D = torch.stack([xs.flatten(), ys.flatten()], dim=1)

        tile_means_3D = cam_means_3D[gaussian_indices]
        valid_mask = (tile_means_3D[:, 2] > self.gaussians.NEAR_PLANE) & (
            tile_means_3D[:, 2] < self.gaussians.FAR_PLANE
        )
        if not valid_mask.any():
            return torch.stack(tile_plane_fields, dim=0)

        tile_means_3D = tile_means_3D[valid_mask]
        valid_gaussian_indices = gaussian_indices[valid_mask]

        tile_means_2D = self.gaussians.compute_means_2D(tile_means_3D, fx, fy, px, py)

        tile_plane_probs = plane_probs[valid_gaussian_indices]

        tile_means_2D = tile_means_2D.unsqueeze(1)
        diff = points_2D.unsqueeze(0) - tile_means_2D

        tile_cov_2D = self.gaussians.compute_cov_2D(
            tile_means_3D,
            quats[valid_gaussian_indices],
            scales[valid_gaussian_indices],
            fx,
            fy,
            R,
            img_size,
        )
        cov_inv = self.gaussians.invert_cov_2D(tile_cov_2D)

        term = torch.bmm(diff, cov_inv)
        term = (term * diff).sum(dim=-1)
        term = term.view(-1, tile_h, tile_w)

        gauss_exp = torch.exp(-0.5 * term)
        tile_opacities = opacities[valid_gaussian_indices].view(-1, 1, 1)
        base_alphas = tile_opacities * gauss_exp

        for plane_idx in range(num_planes):
            plane_mask = tile_plane_probs[:, plane_idx].view(-1, 1, 1)
            plane_alphas = base_alphas * plane_mask
            plane_alphas_reshaped = plane_alphas.reshape(-1, tile_h, tile_w)
            transmittance = self.compute_transmittance(plane_alphas_reshaped)

            for c in range(len(wavelengths)):
                colours_c = colours[valid_gaussian_indices, c].view(-1, 1, 1)
                phase_c = phase[valid_gaussian_indices, c].view(-1, 1, 1)
                tile_plane_fields[plane_idx][c] = torch.sum(
                    colours_c * plane_alphas * transmittance * torch.exp(1j * phase_c),
                    dim=0,
                )

        result = torch.stack(tile_plane_fields, dim=0)
        return result

    def render(
        self,
        camera: PerspectiveCamera,
        img_size: Tuple = (-1, -1),
        bg_colour: Tuple = (0.0, 0.0, 0.0),
        tile_size: Tuple = (64, 64),
        step=-1,
        max_step=-1,
    ):
        """
        Render a complex hologram from the current Gaussians.

        Parameters
        ----------
        camera    : PerspectiveCamera
        img_size  : tuple of int
                    ``(W, H)``.
        bg_colour : tuple of float
                    Background colour (unused in wave rendering).
        tile_size : tuple of int
                    Tile dimensions for splatting.
        step      : int
                    Current training step (for scheduled activations).
        max_step  : int
                    Maximum training step.

        Returns
        -------
        hologram_complex : torch.Tensor
                           Complex hologram ``(C, H, W)``.
        plane_field      : torch.Tensor
                           Per-plane complex fields
                           ``(P, C, H, W)``.
        """
        z_vals = self.compute_depth_values(camera)

        cam_means_3D = camera.transform_world_to_camera_space(self.gaussians.means)
        visible_mask = (cam_means_3D[:, 2] > self.gaussians.NEAR_PLANE) & (
            cam_means_3D[:, 2] < self.gaussians.FAR_PLANE
        )
        valid_indices = torch.where(visible_mask)[0]

        idxs = self.get_idxs_to_filter_and_sort(z_vals[valid_indices])
        idxs = valid_indices[idxs]

        pre_act_quats = self.gaussians.pre_act_quats[idxs]
        pre_act_scales = self.gaussians.pre_act_scales[idxs]
        pre_act_phase = self.gaussians.pre_act_phase[idxs]
        pre_act_opacities = self.gaussians.pre_act_opacities[idxs]
        pre_act_plane_assignment = self.gaussians.pre_act_plane_assignment[idxs]

        z_vals = z_vals[idxs]
        means_3D = self.gaussians.means[idxs]
        colours = self.gaussians.colours[idxs]

        quats, scales, phase_val, opacities, plane_probs = (
            self.gaussians.apply_activations(
                pre_act_quats,
                pre_act_scales,
                pre_act_phase,
                pre_act_opacities,
                pre_act_plane_assignment,
                step,
                max_step,
            )
        )

        hologram_complex, plane_field = self.splat(
            camera,
            means_3D,
            z_vals,
            quats,
            scales,
            colours,
            phase_val,
            opacities,
            plane_probs,
            self.wavelengths,
            img_size,
            tile_size,
        )

        return hologram_complex, plane_field
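After splatting, ``splat`` propagates each per-plane field to the hologram plane with odak's internal ``_bandlimited_angular_spectrum_propagation``. The transfer-function idea behind it can be illustrated with a minimal, *non*-band-limited sketch (an assumption-laden toy, not odak's implementation — it clamps evanescent components instead of band-limiting the kernel):

```python
import math

import torch

def angular_spectrum(field, wavelength, pixel_pitch, distance):
    # Minimal angular spectrum sketch: multiply the field's spectrum by the
    # free-space transfer function exp(2j*pi*d*fz) and transform back.
    ny, nx = field.shape
    fx = torch.fft.fftfreq(nx, d=pixel_pitch)
    fy = torch.fft.fftfreq(ny, d=pixel_pitch)
    FY, FX = torch.meshgrid(fy, fx, indexing="ij")
    fz_sq = (1.0 / wavelength) ** 2 - FX ** 2 - FY ** 2
    kernel = torch.exp(2j * math.pi * distance * torch.sqrt(fz_sq.clamp(min=0.0)))
    return torch.fft.ifft2(torch.fft.fft2(field) * kernel)

# Propagating forward then backward over the same distance recovers the field.
field = torch.randn(64, 64, dtype=torch.complex64)
forward = angular_spectrum(field, 532e-9, 8e-6, 1e-3)
round_trip = angular_spectrum(forward, 532e-9, 8e-6, -1e-3)
```

This invertibility is why ``splat`` can use a negative distance (``-self.args_prop.distances[p]``) to move each plane's field back to the hologram plane before summing the planes into one complex hologram.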

calculate_gaussian_directions(means_3D, camera)

Compute unit direction vectors from camera centre to each Gaussian.

Parameters:

  • means_3D (Tensor) –
       3-D positions ``(N, 3)``.
    
  • camera (PerspectiveCamera) –

Returns:

  • gaussian_dirs ( Tensor ) –

    Unit direction vectors (N, 3).

Source code in odak/learn/wave/complex_gaussians.py
def calculate_gaussian_directions(self, means_3D, camera):
    """
    Compute unit direction vectors from camera centre to each Gaussian.

    Parameters
    ----------
    means_3D : torch.Tensor
               3-D positions ``(N, 3)``.
    camera   : PerspectiveCamera

    Returns
    -------
    gaussian_dirs : torch.Tensor
                    Unit direction vectors ``(N, 3)``.
    """
    N = means_3D.shape[0]
    camera_centers = camera.get_camera_center().repeat(N, 1)
    gaussian_dirs = means_3D - camera_centers
    gaussian_dirs = F.normalize(gaussian_dirs)
    return gaussian_dirs

compute_depth_values(camera)

Compute per-Gaussian depth values in camera space.

Parameters:

  • camera (PerspectiveCamera) –

Returns:

  • z_vals ( Tensor ) –

    Depth values (N,).

Source code in odak/learn/wave/complex_gaussians.py
def compute_depth_values(self, camera: PerspectiveCamera):
    """
    Compute per-Gaussian depth values in camera space.

    Parameters
    ----------
    camera : PerspectiveCamera

    Returns
    -------
    z_vals : torch.Tensor
             Depth values ``(N,)``.
    """
    means_3D = self.gaussians.means
    R = camera.R[0] if camera.R.dim() == 3 else camera.R
    T = camera.T[0] if camera.T.dim() == 2 else camera.T
    means_cam = means_3D @ R + T
    z_vals = means_cam[:, -1]
    return z_vals

compute_transmittance(alphas)

Compute transmittance from per-Gaussian alpha values.

Parameters:

  • alphas (Tensor) –
     Alpha (opacity × Gaussian) values ``(N, H, W)``.
    

Returns:

  • transmittance ( Tensor ) –

    Cumulative transmittance (N, H, W).

Source code in odak/learn/wave/complex_gaussians.py
def compute_transmittance(self, alphas: torch.Tensor):
    """
    Compute transmittance from per-Gaussian alpha values.

    Parameters
    ----------
    alphas : torch.Tensor
             Alpha (opacity × Gaussian) values ``(N, H, W)``.

    Returns
    -------
    transmittance : torch.Tensor
                    Cumulative transmittance ``(N, H, W)``.
    """
    _, H, W = alphas.shape
    S = torch.ones((1, H, W), device=alphas.device, dtype=alphas.dtype)
    one_minus_alphas = 1.0 - alphas
    one_minus_alphas = torch.cat((S, one_minus_alphas), dim=0)
    transmittance = torch.cumprod(one_minus_alphas, dim=0)[:-1]
    transmittance = torch.where(transmittance < 1e-4, 0.0, transmittance)
    return transmittance
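The transmittance is the usual front-to-back product of ``(1 - alpha)``, shifted by one slot so the front-most Gaussian sees full transmittance. A single-pixel numeric sketch of that recurrence:

```python
import torch

# Sketch of compute_transmittance on one pixel: T_i = prod_{j < i} (1 - a_j).
alphas = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)  # (N, H, W) with H = W = 1
ones = torch.ones((1, 1, 1))
transmittance = torch.cumprod(
    torch.cat((ones, 1.0 - alphas), dim=0), dim=0
)[:-1]
# Front Gaussian sees 1.0, the second 0.5, the third 0.25.
```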

get_idxs_to_filter_and_sort(z_vals)

Sort Gaussians by depth and filter those behind the camera.

Parameters:

  • z_vals (Tensor) –
     Depth values ``(N,)``.
    

Returns:

  • idxs ( Tensor ) –

    Sorted indices with z >= 0.

Source code in odak/learn/wave/complex_gaussians.py
def get_idxs_to_filter_and_sort(self, z_vals: torch.Tensor):
    """
    Sort Gaussians by depth and filter those behind the camera.

    Parameters
    ----------
    z_vals : torch.Tensor
             Depth values ``(N,)``.

    Returns
    -------
    idxs : torch.Tensor
           Sorted indices with ``z >= 0``.
    """
    sorted_vals, indices = torch.sort(z_vals)
    mask = sorted_vals >= 0
    idxs = torch.masked_select(indices, mask).to(torch.int64)
    return idxs
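The two steps, depth sort then filter, can be reproduced in isolation:

```python
import torch

# Sketch of get_idxs_to_filter_and_sort: depth-sort, then drop z < 0.
z_vals = torch.tensor([2.0, -1.0, 0.5])
sorted_vals, indices = torch.sort(z_vals)   # values [-1.0, 0.5, 2.0]
idxs = indices[sorted_vals >= 0]            # keeps index 2, then index 0
```

The returned indices are ordered front to back, which is the order `compute_transmittance` expects.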

render(camera, img_size=(-1, -1), bg_colour=(0.0, 0.0, 0.0), tile_size=(64, 64), step=-1, max_step=-1)

Render a complex hologram from the current Gaussians.

Parameters:

  • camera (PerspectiveCamera) –
  • img_size (Tuple, default: (-1, -1) ) –
        ``(W, H)``.
    
  • bg_colour (tuple of float, default: (0.0, 0.0, 0.0) ) –
        Background colour (unused in wave rendering).
    
  • tile_size (tuple of int, default: (64, 64) ) –
        Tile dimensions for splatting.
    
  • step (int, default: -1 ) –
        Current training step (for scheduled activations).
    
  • max_step (int, default: -1 ) –
        Maximum training step.
    

Returns:

  • hologram_complex ( Tensor ) –

    Complex hologram (C, H, W).

  • plane_field ( Tensor ) –

    Per-plane complex fields (P, C, H, W).

Source code in odak/learn/wave/complex_gaussians.py
def render(
    self,
    camera: PerspectiveCamera,
    img_size: Tuple = (-1, -1),
    bg_colour: Tuple = (0.0, 0.0, 0.0),
    tile_size: Tuple = (64, 64),
    step=-1,
    max_step=-1,
):
    """
    Render a complex hologram from the current Gaussians.

    Parameters
    ----------
    camera    : PerspectiveCamera
    img_size  : tuple of int
                ``(W, H)``.
    bg_colour : tuple of float
                Background colour (unused in wave rendering).
    tile_size : tuple of int
                Tile dimensions for splatting.
    step      : int
                Current training step (for scheduled activations).
    max_step  : int
                Maximum training step.

    Returns
    -------
    hologram_complex : torch.Tensor
                       Complex hologram ``(C, H, W)``.
    plane_field      : torch.Tensor
                       Per-plane complex fields
                       ``(P, C, H, W)``.
    """
    z_vals = self.compute_depth_values(camera)

    cam_means_3D = camera.transform_world_to_camera_space(self.gaussians.means)
    visible_mask = (cam_means_3D[:, 2] > self.gaussians.NEAR_PLANE) & (
        cam_means_3D[:, 2] < self.gaussians.FAR_PLANE
    )
    valid_indices = torch.where(visible_mask)[0]

    idxs = self.get_idxs_to_filter_and_sort(z_vals[valid_indices])
    idxs = valid_indices[idxs]

    pre_act_quats = self.gaussians.pre_act_quats[idxs]
    pre_act_scales = self.gaussians.pre_act_scales[idxs]
    pre_act_phase = self.gaussians.pre_act_phase[idxs]
    pre_act_opacities = self.gaussians.pre_act_opacities[idxs]
    pre_act_plane_assignment = self.gaussians.pre_act_plane_assignment[idxs]

    z_vals = z_vals[idxs]
    means_3D = self.gaussians.means[idxs]
    colours = self.gaussians.colours[idxs]

    quats, scales, phase_val, opacities, plane_probs = (
        self.gaussians.apply_activations(
            pre_act_quats,
            pre_act_scales,
            pre_act_phase,
            pre_act_opacities,
            pre_act_plane_assignment,
            step,
            max_step,
        )
    )

    hologram_complex, plane_field = self.splat(
        camera,
        means_3D,
        z_vals,
        quats,
        scales,
        colours,
        phase_val,
        opacities,
        plane_probs,
        self.wavelengths,
        img_size,
        tile_size,
    )

    return hologram_complex, plane_field
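Before splatting, `render` masks Gaussians to the near/far depth range, depth-sorts the survivors, applies activations, and accumulates their contributions as complex fields rather than blended colours. A toy, odak-free sketch of that flow for a single pixel (the clip-plane values are made up for illustration):

```python
import torch

# Toy sketch of the render flow: visibility mask, depth sort, complex summation.
NEAR_PLANE, FAR_PLANE = 0.1, 10.0           # assumed clip planes, not odak defaults
z_vals = torch.tensor([5.0, -2.0, 1.0])     # one Gaussian sits behind the camera
amplitudes = torch.tensor([0.8, 0.3, 0.5])
phases = torch.tensor([0.0, 1.0, torch.pi / 2])

visible = (z_vals > NEAR_PLANE) & (z_vals < FAR_PLANE)
order = torch.argsort(z_vals[visible])      # front-to-back ordering
amp, ph = amplitudes[visible][order], phases[visible][order]

# Complex contributions A * exp(i * phi) are summed, not alpha-blended.
field = torch.sum(amp * torch.exp(1j * ph))
```

The ordering does not change a plain sum, but it matters once per-Gaussian transmittance is folded in, as in the full `splat_tile` below.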

splat(camera, means_3D, z_vals, quats, scales, colours, phase, opacities, plane_probs, wavelengths, img_size=(256, 256), tile_size=(64, 64))

Multi-channel wave-based tile splatting and propagation.

Parameters:

  • camera (PerspectiveCamera) –
  • means_3D (Tensor) –
  • z_vals (Tensor) –
  • quats (Tensor) –
  • scales (Tensor) –
  • colours (Tensor) –
  • phase (Tensor) –
  • opacities (Tensor) –
  • plane_probs (Tensor) –
  • wavelengths (Tensor) –
  • img_size (Tuple, default: (256, 256) ) –
  • tile_size (Tuple, default: (64, 64) ) –

Returns:

  • hologram_complex ( Tensor ) –

    Complex hologram (C, H, W).

  • plane_fields ( Tensor ) –

    Per-plane fields (P, C, H, W).

Source code in odak/learn/wave/complex_gaussians.py
def splat(
    self,
    camera: PerspectiveCamera,
    means_3D: torch.Tensor,
    z_vals: torch.Tensor,
    quats: torch.Tensor,
    scales: torch.Tensor,
    colours: torch.Tensor,
    phase: torch.Tensor,
    opacities: torch.Tensor,
    plane_probs: torch.Tensor,
    wavelengths: torch.Tensor,
    img_size: Tuple = (256, 256),
    tile_size: Tuple = (64, 64),
):
    """
    Multi-channel wave-based tile splatting and propagation.

    Parameters
    ----------
    camera       : PerspectiveCamera
    means_3D     : torch.Tensor ``(N, 3)``
    z_vals       : torch.Tensor ``(N,)``
    quats        : torch.Tensor ``(N, 4)``
    scales       : torch.Tensor ``(N, 3)``
    colours      : torch.Tensor ``(N, 3)``
    phase        : torch.Tensor ``(N, 3)``
    opacities    : torch.Tensor ``(N,)``
    plane_probs  : torch.Tensor ``(N, num_planes)``
    wavelengths  : torch.Tensor ``(C,)``
    img_size     : tuple of int
    tile_size    : tuple of int

    Returns
    -------
    hologram_complex : torch.Tensor
                       Complex hologram ``(C, H, W)``.
    plane_fields     : torch.Tensor
                       Per-plane fields ``(P, C, H, W)``.
    """
    W, H = img_size
    device = means_3D.device
    num_planes = plane_probs.shape[1]

    if isinstance(wavelengths, list):
        wavelengths = torch.tensor(wavelengths, device=device, dtype=torch.float32)

    R = camera.R
    fx, fy = camera.focal_length.flatten()
    px, py = camera.principal_point.flatten()

    if tile_size[0] <= 0 or tile_size[1] <= 0:
        tile_size = (64, 64)

    num_channels = len(wavelengths)

    cam_means_3D = camera.transform_world_to_camera_space(means_3D)

    means_2D = self.gaussians.compute_means_2D(cam_means_3D, fx, fy, px, py)
    self.mean_2D_for_planeprob = means_2D
    cov_2D = self.gaussians.compute_cov_2D(
        cam_means_3D, quats, scales, fx, fy, R, img_size
    )
    gaussian_bounds = self.gaussians.calculate_gaussian_bounds(
        means_2D, cov_2D, img_size
    )
    plane_fields = torch.zeros(
        (num_planes, num_channels, H, W), dtype=torch.complex64, device=device
    )

    tile_w, tile_h = tile_size
    x_tiles = math.ceil(W / tile_w)
    y_tiles = math.ceil(H / tile_h)

    for y_idx in range(y_tiles):
        for x_idx in range(x_tiles):
            x = x_idx * tile_w
            y = y_idx * tile_h
            actual_tile_w = min(tile_w, W - x)
            actual_tile_h = min(tile_h, H - y)
            x_min, y_min = x, y
            x_max = x + actual_tile_w - 1
            y_max = y + actual_tile_h - 1

            in_x_range = (gaussian_bounds[:, 0] <= x_max) & (
                gaussian_bounds[:, 2] >= x_min
            )
            in_y_range = (gaussian_bounds[:, 1] <= y_max) & (
                gaussian_bounds[:, 3] >= y_min
            )
            gaussian_indices = torch.where(in_x_range & in_y_range)[0]

            tile_plane_fields = self.splat_tile(
                R,
                fx,
                fy,
                px,
                py,
                cam_means_3D,
                z_vals,
                quats,
                scales,
                colours,
                phase,
                opacities,
                plane_probs,
                x,
                y,
                (actual_tile_w, actual_tile_h),
                gaussian_indices,
                img_size,
                wavelengths,
            )
            plane_fields[
                :, :, y : y + actual_tile_h, x : x + actual_tile_w
            ] += tile_plane_fields

    hologram_complex_planes = []
    for p in range(num_planes):
        plane_hologram = []
        for c, plane_field_c in enumerate(plane_fields[p]):
            wavelength_val = float(wavelengths[c].cpu().item())
            hologram_complex_c = _bandlimited_angular_spectrum_propagation(
                plane_field_c,
                wavelength=wavelength_val,
                pixel_pitch=self.args_prop.pixel_pitch,
                distance=-self.args_prop.distances[p],
                size=self.args_prop.pad_size,
                aperture_size=self.args_prop.aperture_size,
            )
            plane_hologram.append(hologram_complex_c)
        hologram_complex_planes.append(torch.stack(plane_hologram, dim=0))

    hologram_complex = sum(hologram_complex_planes)
    return hologram_complex, plane_fields
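Each per-plane field is propagated to the hologram plane one channel at a time. odak wraps this in `_bandlimited_angular_spectrum_propagation`, which also handles padding and aperture band-limiting; for intuition only, a bare-bones (non band-limited, unpadded) angular spectrum transfer function can be sketched as follows, with illustrative wavelength and pixel-pitch values:

```python
import torch

# Bare-bones angular spectrum propagation sketch (no band-limiting or padding).
def angular_spectrum(field, wavelength, pixel_pitch, distance):
    H, W = field.shape
    fx = torch.fft.fftfreq(W, d=pixel_pitch)
    fy = torch.fft.fftfreq(H, d=pixel_pitch)
    FY, FX = torch.meshgrid(fy, fx, indexing="ij")
    arg = 1.0 / wavelength ** 2 - FX ** 2 - FY ** 2
    kz = 2.0 * torch.pi * torch.sqrt(torch.clamp(arg, min=0.0))
    # Zero out evanescent components, propagate the rest in frequency space.
    transfer = torch.exp(1j * kz * distance) * (arg > 0).float()
    return torch.fft.ifft2(torch.fft.fft2(field) * transfer)

field = torch.randn(64, 64, dtype=torch.complex64)
propagated = angular_spectrum(field, 515e-9, 8e-6, 1e-3)
```

Propagating by `distance` and then by `-distance` recovers the input field when no frequencies are evanescent at the chosen sampling, which makes a convenient sanity check.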

splat_tile(R, fx, fy, px, py, cam_means_3D, z_vals, quats, scales, colours, phase, opacities, plane_probs, tile_x, tile_y, tile_size, gaussian_indices, img_size, wavelengths)

Render a single tile for all planes (pure PyTorch).

Parameters:

  • R (Tensor) –
             Rotation matrix.
    
  • fx (float or Tensor) –
             Camera intrinsics.
    
  • fy (float or Tensor) –
             Camera intrinsics.
    
  • px (float or Tensor) –
             Camera intrinsics.
    
  • py (float or Tensor) –
             Camera intrinsics.
    
  • cam_means_3D (Tensor) –
             Camera-space means ``(N, 3)``.
    
  • z_vals (Tensor) –
             Depth values ``(N,)``.
    
  • quats (Tensor) –
  • scales (Tensor) –
  • colours (Tensor) –
  • phase (Tensor) –
  • opacities (Tensor) –
  • plane_probs (Tensor) –
             Plane assignment probabilities ``(N, P)``.
    
  • tile_x (int) –
  • tile_y (int) –
  • tile_size (tuple of int) –
             ``(tile_w, tile_h)`` for this tile.
    
  • gaussian_indices (Tensor) –
             Indices of Gaussians overlapping this tile.
    
  • img_size (tuple of int) –
             Full image ``(W, H)``.
    
  • wavelengths (Tensor) –
             Wavelengths ``(C,)``.

Returns:

  • result ( Tensor ) –

    (P, C, tile_h, tile_w) complex field for this tile.

Source code in odak/learn/wave/complex_gaussians.py
def splat_tile(
    self,
    R,
    fx,
    fy,
    px,
    py,
    cam_means_3D,
    z_vals,
    quats,
    scales,
    colours,
    phase,
    opacities,
    plane_probs,
    tile_x,
    tile_y,
    tile_size,
    gaussian_indices,
    img_size,
    wavelengths,
):
    """
    Render a single tile for all planes (pure PyTorch).

    Parameters
    ----------
    R              : torch.Tensor
                     Rotation matrix.
    fx, fy, px, py : float or torch.Tensor
                     Camera intrinsics.
    cam_means_3D   : torch.Tensor ``(N, 3)``
    z_vals         : torch.Tensor ``(N,)``
    quats, scales, colours, phase, opacities : torch.Tensor
    plane_probs    : torch.Tensor ``(N, P)``
    tile_x, tile_y : int
    tile_size      : tuple of int
                     ``(tile_w, tile_h)`` for this tile.
    gaussian_indices : torch.Tensor
                     Indices of Gaussians overlapping this tile.
    img_size       : tuple of int
                     Full image ``(W, H)``.
    wavelengths    : torch.Tensor ``(C,)``

    Returns
    -------
    result : torch.Tensor
             ``(P, C, tile_h, tile_w)`` complex field for this tile.
    """
    device = cam_means_3D.device
    W, H = img_size
    tile_w, tile_h = tile_size
    num_planes = plane_probs.shape[1]

    tile_plane_fields = []
    for _ in range(num_planes):
        tile_plane_fields.append(
            torch.zeros(
                (len(wavelengths), tile_h, tile_w),
                device=device,
                dtype=torch.complex64,
            )
        )

    if gaussian_indices.numel() == 0:
        return torch.stack(tile_plane_fields, dim=0)

    xs, ys = torch.meshgrid(
        torch.arange(tile_x, tile_x + tile_w, device=device),
        torch.arange(tile_y, tile_y + tile_h, device=device),
        indexing="xy",
    )
    points_2D = torch.stack([xs.flatten(), ys.flatten()], dim=1)

    tile_means_3D = cam_means_3D[gaussian_indices]
    valid_mask = (tile_means_3D[:, 2] > self.gaussians.NEAR_PLANE) & (
        tile_means_3D[:, 2] < self.gaussians.FAR_PLANE
    )
    if not valid_mask.any():
        return torch.stack(tile_plane_fields, dim=0)

    tile_means_3D = tile_means_3D[valid_mask]
    valid_gaussian_indices = gaussian_indices[valid_mask]

    tile_means_2D = self.gaussians.compute_means_2D(tile_means_3D, fx, fy, px, py)

    tile_plane_probs = plane_probs[valid_gaussian_indices]

    tile_means_2D = tile_means_2D.unsqueeze(1)
    diff = points_2D.unsqueeze(0) - tile_means_2D

    tile_cov_2D = self.gaussians.compute_cov_2D(
        tile_means_3D,
        quats[valid_gaussian_indices],
        scales[valid_gaussian_indices],
        fx,
        fy,
        R,
        img_size,
    )
    cov_inv = self.gaussians.invert_cov_2D(tile_cov_2D)

    term = torch.bmm(diff, cov_inv)
    term = (term * diff).sum(dim=-1)
    term = term.view(-1, tile_h, tile_w)

    gauss_exp = torch.exp(-0.5 * term)
    tile_opacities = opacities[valid_gaussian_indices].view(-1, 1, 1)
    base_alphas = tile_opacities * gauss_exp

    for plane_idx in range(num_planes):
        plane_mask = tile_plane_probs[:, plane_idx].view(-1, 1, 1)
        plane_alphas = base_alphas * plane_mask
        plane_alphas_reshaped = plane_alphas.reshape(-1, tile_h, tile_w)
        transmittance = self.compute_transmittance(plane_alphas_reshaped)

        for c in range(len(wavelengths)):
            colours_c = colours[valid_gaussian_indices, c].view(-1, 1, 1)
            phase_c = phase[valid_gaussian_indices, c].view(-1, 1, 1)
            tile_plane_fields[plane_idx][c] = torch.sum(
                colours_c * plane_alphas * transmittance * torch.exp(1j * phase_c),
                dim=0,
            )

    result = torch.stack(tile_plane_fields, dim=0)
    return result
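The per-pixel weight computed inside `splat_tile` is the standard anisotropic Gaussian \(\exp(-\tfrac{1}{2} d^\top \Sigma^{-1} d)\) scaled by opacity. A one-Gaussian sketch on a 5×5 tile, with a hand-picked covariance for illustration:

```python
import torch

# One-Gaussian sketch of the per-pixel alpha evaluated in splat_tile.
mean_2D = torch.tensor([2.0, 2.0])
cov_inv = torch.linalg.inv(torch.tensor([[2.0, 0.0], [0.0, 2.0]]))
opacity = 0.9

xs, ys = torch.meshgrid(torch.arange(5.0), torch.arange(5.0), indexing="xy")
points_2D = torch.stack([xs.flatten(), ys.flatten()], dim=1)

diff = points_2D - mean_2D
term = ((diff @ cov_inv) * diff).sum(dim=-1)       # d^T Sigma^{-1} d per pixel
alphas = opacity * torch.exp(-0.5 * term).view(5, 5)
# The peak sits at the mean, where alpha equals the opacity.
```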

PerspectiveCamera(R, T, focal_length, principal_point, device='cpu')

A lightweight perspective camera model.

Stores camera intrinsics and extrinsics and provides coordinate-transform utilities.

Parameters:

  • R (Tensor) –
             Rotation matrix, shape ``(3, 3)`` or ``(1, 3, 3)``.
    
  • T (Tensor) –
             Translation vector, shape ``(3,)`` or ``(1, 3)``.
    
  • focal_length (Tensor) –
             Focal lengths ``(fx, fy)``, shape ``(2,)`` or ``(1, 2)``.
    
  • principal_point (Tensor) –
             Principal point ``(px, py)``, shape ``(2,)`` or ``(1, 2)``.
    
  • device (device or str, default: 'cpu' ) –
             Device for all tensors.
    
Source code in odak/learn/tools/camera.py
class PerspectiveCamera:
    """
    A lightweight perspective camera model.

    Stores camera intrinsics and extrinsics and provides
    coordinate-transform utilities.

    Parameters
    ----------
    R              : torch.Tensor
                     Rotation matrix, shape ``(3, 3)`` or ``(1, 3, 3)``.
    T              : torch.Tensor
                     Translation vector, shape ``(3,)`` or ``(1, 3)``.
    focal_length   : torch.Tensor
                     Focal lengths ``(fx, fy)``, shape ``(2,)`` or ``(1, 2)``.
    principal_point: torch.Tensor
                     Principal point ``(px, py)``, shape ``(2,)`` or ``(1, 2)``.
    device         : torch.device or str, optional
                     Device for all tensors (default: ``"cpu"``).
    """

    def __init__(self, R, T, focal_length, principal_point, device="cpu"):
        self.device = torch.device(device)
        self.R = (
            R.to(self.device)
            if isinstance(R, torch.Tensor)
            else torch.tensor(R, dtype=torch.float32, device=self.device)
        )
        self.T = (
            T.to(self.device)
            if isinstance(T, torch.Tensor)
            else torch.tensor(T, dtype=torch.float32, device=self.device)
        )
        self.focal_length = (
            focal_length.to(self.device)
            if isinstance(focal_length, torch.Tensor)
            else torch.tensor(focal_length, dtype=torch.float32, device=self.device)
        )
        self.principal_point = (
            principal_point.to(self.device)
            if isinstance(principal_point, torch.Tensor)
            else torch.tensor(principal_point, dtype=torch.float32, device=self.device)
        )

    def transform_world_to_camera_space(self, points):
        """
        Transform world-space points into camera space.

        Follows the convention: ``X_cam = X_world @ R + T``.

        Parameters
        ----------
        points : torch.Tensor
                 World-space points, shape ``(N, 3)``.

        Returns
        -------
        cam_points : torch.Tensor
                     Camera-space points, shape ``(N, 3)``.
        """
        R = self.R[0] if self.R.dim() == 3 else self.R
        T = self.T[0] if self.T.dim() == 2 else self.T
        return points @ R + T

    def get_camera_center(self):
        """
        Compute the camera centre in world coordinates.

        Returns
        -------
        center : torch.Tensor
                 Camera centre, shape ``(1, 3)``.
        """
        R = self.R[0] if self.R.dim() == 3 else self.R
        T = self.T[0] if self.T.dim() == 2 else self.T
        center = -T @ R.transpose(0, 1)
        return center.unsqueeze(0)

get_camera_center()

Compute the camera centre in world coordinates.

Returns:

  • center ( Tensor ) –

    Camera centre, shape (1, 3).

Source code in odak/learn/tools/camera.py
def get_camera_center(self):
    """
    Compute the camera centre in world coordinates.

    Returns
    -------
    center : torch.Tensor
             Camera centre, shape ``(1, 3)``.
    """
    R = self.R[0] if self.R.dim() == 3 else self.R
    T = self.T[0] if self.T.dim() == 2 else self.T
    center = -T @ R.transpose(0, 1)
    return center.unsqueeze(0)

transform_world_to_camera_space(points)

Transform world-space points into camera space.

Follows the convention: X_cam = X_world @ R + T.

Parameters:

  • points (Tensor) –
     World-space points, shape ``(N, 3)``.
    

Returns:

  • cam_points ( Tensor ) –

    Camera-space points, shape (N, 3).

Source code in odak/learn/tools/camera.py
def transform_world_to_camera_space(self, points):
    """
    Transform world-space points into camera space.

    Follows the convention: ``X_cam = X_world @ R + T``.

    Parameters
    ----------
    points : torch.Tensor
             World-space points, shape ``(N, 3)``.

    Returns
    -------
    cam_points : torch.Tensor
                 Camera-space points, shape ``(N, 3)``.
    """
    R = self.R[0] if self.R.dim() == 3 else self.R
    T = self.T[0] if self.T.dim() == 2 else self.T
    return points @ R + T
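A quick consistency check between the two utilities: the camera centre is, by definition, the point that maps to the origin under ``X @ R + T``, hence ``C = -T @ R^T`` for an orthonormal ``R``. A standalone sketch (random z-axis rotation, not odak code):

```python
import math
import torch

# The camera centre C satisfies C @ R + T = 0, i.e. C = -T @ R^T
# (for an orthonormal R, since R^T @ R = I).
c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.0],
                  [s, c, 0.0],
                  [0.0, 0.0, 1.0]])
T = torch.tensor([1.0, -2.0, 3.0])

center = -T @ R.transpose(0, 1)
residual = center @ R + T   # transforming the centre should give ~0
```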
Can I load Gaussians from a trained checkpoint instead of random initialisation?

Yes. Use init_type="gaussians" with a path to a .pth checkpoint:

gaussians = Gaussians(
    init_type="gaussians",
    device="cuda",
    load_path="path/to/checkpoint.pth",
    args_prop=args,
)

The checkpoint should contain the keys means, pre_act_quats, pre_act_scales, colours, pre_act_phase, pre_act_opacities, and pre_act_plane_assignment.
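A hypothetical sketch of assembling such a checkpoint by hand, with shapes taken from the documentation above; the Gaussian count and number of planes here are arbitrary choices, not odak defaults:

```python
import torch

# Hypothetical compatible checkpoint; N and num_planes are assumptions.
N, num_planes = 128, 4
checkpoint = {
    "means": torch.randn(N, 3),
    "pre_act_quats": torch.randn(N, 4),
    "pre_act_scales": torch.randn(N, 3),
    "colours": torch.rand(N, 3),
    "pre_act_phase": torch.randn(N, 3),
    "pre_act_opacities": torch.randn(N),
    "pre_act_plane_assignment": torch.randn(N, num_planes),
}
torch.save(checkpoint, "checkpoint.pth")  # ready for init_type="gaussians"
```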

Can I initialise from a point cloud?

Yes. Use init_type="point" with a dictionary containing positions and colors tensors:

pointcloud_data = {
    "positions": positions_tensor,  # (N, 3)
    "colors": colors_tensor,        # (N, 3)
}
gaussians = Gaussians(
    init_type="point",
    device="cpu",
    pointcloud_data=pointcloud_data,
    args_prop=args,
)