odak.learn.wave

angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution with the Angular Spectrum method for beam propagation.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def angular_spectrum(
                     field,
                     k,
                     distance,
                     dx,
                     wavelength,
                     zero_padding = False,
                     aperture = 1.
                    ):
    """
    A definition to calculate convolution with Angular Spectrum method for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
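
Example: a minimal usage sketch for propagating a toy field with the Angular Spectrum method. The resolution, pixel pitch, distance, and random phase below are illustrative placeholders, and the wavenumber is computed directly as 2π/λ rather than through odak.wave.wavenumber.

import torch
from odak.learn.wave import angular_spectrum

wavelength = 515e-9                       # wavelength in meters
dx = 8e-6                                 # pixel pitch in meters
distance = 5e-3                           # propagation distance in meters
k = 2 * torch.pi / wavelength             # wavenumber

# Toy input: unit amplitude with a random phase profile.
field = torch.exp(1j * 2 * torch.pi * torch.rand(512, 512))

result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False)
print(result.shape, result.dtype)         # torch.Size([512, 512]) torch.complex64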

band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate bandlimited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field
               A complex field.
               The expected size is [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def band_limited_angular_spectrum(
                                  field,
                                  k,
                                  distance,
                                  dx,
                                  wavelength,
                                  zero_padding = False,
                                  aperture = 1.
                                 ):
    """
    A definition to calculate bandlimited angular spectrum based beam propagation. For more 
    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673`.

    Parameters
    ----------
    field            : torch.complex
                       A complex field.
                       The expected size is [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Bandlimited Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result

custom(field, kernel, zero_padding=False, aperture=1.0)

A definition to calculate convolution-based beam propagation with a user-supplied complex kernel.

Parameters:

  • field
               Complex field [m x n].
    
  • kernel
               Custom complex kernel for beam propagation.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def custom(
           field,
           kernel,
           zero_padding = False,
           aperture = 1.
          ):
    """
    A definition to calculate convolution based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    kernel           : torch.complex
                       Custom complex kernel for beam propagation.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    if type(kernel) == type(None):
        H = torch.ones(field.shape).to(field.device)
    else:
        H = kernel * aperture
    U1 = torch.fft.fftshift(torch.fft.fft2(field)) * aperture
    if zero_padding == False:
        U2 = H * U1
    elif zero_padding == True:
        U2 = zero_pad(H * U1)
    result = torch.fft.ifft2(torch.fft.ifftshift(U2))
    return result
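
Example: a hedged sketch showing how custom pairs with a precomputed kernel from get_propagation_kernel, so the same Fourier-domain kernel can be reused across many fields. The optical parameters below are placeholders.

import torch
from odak.learn.wave import custom, get_propagation_kernel

wavelength = 515e-9
dx = 8e-6
distance = 1e-2

# Build the Fourier-domain kernel once.
H = get_propagation_kernel(
                           nu = 256,
                           nv = 256,
                           dx = dx,
                           wavelength = wavelength,
                           distance = distance,
                           propagation_type = 'Angular Spectrum'
                          )

# Reuse it for any 256 x 256 complex field.
field = torch.exp(1j * 2 * torch.pi * torch.rand(256, 256))
result = custom(field, H, zero_padding = False)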

fraunhofer(field, k, distance, dx, wavelength)

A definition to calculate light transport using the Fraunhofer approximation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def fraunhofer(
               field,
               k,
               distance,
               dx,
               wavelength
              ):
    """
    A definition to calculate light transport using the Fraunhofer approximation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).
    """
    nv, nu = field.shape[-1], field.shape[-2]
    x = torch.linspace(-nv*dx/2, nv*dx/2, nv, dtype=torch.float32)
    y = torch.linspace(-nu*dx/2, nu*dx/2, nu, dtype=torch.float32)
    Y, X = torch.meshgrid(y, x, indexing='ij')
    Z = torch.pow(X, 2) + torch.pow(Y, 2)
    c = 1. / (1j * wavelength * distance) * torch.exp(1j * k * 0.5 / distance * Z)
    c = c.to(field.device)
    result = c * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(field))) * dx ** 2
    return result
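
Example: a minimal sketch computing the far-field diffraction pattern of a square aperture under the Fraunhofer approximation; the aperture size and distance are illustrative, and the wavenumber is computed as 2π/λ.

import torch
from odak.learn.wave import fraunhofer

wavelength = 515e-9
dx = 8e-6
distance = 0.5                            # far-field distance in meters
k = 2 * torch.pi / wavelength

# Square aperture embedded in a 512 x 512 grid.
field = torch.zeros(512, 512, dtype = torch.complex64)
field[192:320, 192:320] = 1.

result = fraunhofer(field, k, distance, dx, wavelength)
intensity = result.abs() ** 2             # diffraction pattern intensity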

gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel')

Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

Parameters:

  • field
               Complex field (MxN).
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • slm_range
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Type of the propagation (see odak.learn.wave.propagate_beam).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

  • reconstruction ( cfloat ) –

    Calculated reconstruction using calculated hologram.

Source code in odak/learn/wave/classical.py
def gerchberg_saxton(
                     field,
                     n_iterations,
                     distance,
                     dx,
                     wavelength,
                     slm_range = 6.28,
                     propagation_type = 'Transfer Function Fresnel'
                    ):
    """
    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

    Parameters
    ----------
    field            : torch.cfloat
                       Complex field (MxN).
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    slm_range        : float
                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    propagation_type : str
                       Type of the propagation (see odak.learn.wave.propagate_beam).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    reconstruction   : torch.cfloat
                       Calculated reconstruction using calculated hologram. 
    """
    k = wavenumber(wavelength)
    reconstruction = field
    for i in range(n_iterations):
        hologram = propagate_beam(
            reconstruction, k, -distance, dx, wavelength, propagation_type)
        reconstruction = propagate_beam(
            hologram, k, distance, dx, wavelength, propagation_type)
        reconstruction = set_amplitude(reconstruction, field)
    reconstruction = propagate_beam(
        hologram, k, distance, dx, wavelength, propagation_type)
    return hologram, reconstruction
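
Example: a hedged sketch retrieving a hologram for a target amplitude with ten Gerchberg-Saxton iterations. The random target amplitude is purely illustrative, and the propagation type is an assumed choice among the options of odak.learn.wave.propagate_beam.

import torch
from odak.learn.wave import gerchberg_saxton

wavelength = 515e-9
dx = 8e-6
distance = 1e-2

# Target amplitude (e.g., a normalized grayscale image) lifted to a complex field with zero phase.
target_amplitude = torch.rand(512, 512)
target_field = target_amplitude * torch.exp(1j * torch.zeros(512, 512))

hologram, reconstruction = gerchberg_saxton(
                                            target_field,
                                            n_iterations = 10,
                                            distance = distance,
                                            dx = dx,
                                            wavelength = wavelength,
                                            propagation_type = 'Bandlimited Angular Spectrum'
                                           )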

get_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_angular_spectrum_kernel(
                                nu,
                                nv,
                                dx = 8e-6,
                                wavelength = 515e-9,
                                distance = 0.,
                                device = torch.device('cpu')
                               ):
    """
    Helper function for odak.learn.wave.angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : float
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
    H = H.to(device)
    return H
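
Example: a minimal sketch retrieving the transfer function and applying it manually with FFTs; this mirrors what custom does internally (without zero padding or an aperture), and the grid parameters are placeholders.

import torch
from odak.learn.wave import get_angular_spectrum_kernel

dx, wavelength, distance = 8e-6, 515e-9, 5e-3
field = torch.exp(1j * 2 * torch.pi * torch.rand(512, 512))

# Fourier-domain transfer function for this grid.
H = get_angular_spectrum_kernel(512, 512, dx = dx, wavelength = wavelength, distance = distance)

# Manual application, equivalent to custom(field, H).
U1 = torch.fft.fftshift(torch.fft.fft2(field))
result = torch.fft.ifft2(torch.fft.ifftshift(H * U1))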

get_band_limited_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.band_limited_angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_band_limited_angular_spectrum_kernel(
                                             nu,
                                             nv,
                                             dx = 8e-6,
                                             wavelength = 515e-9,
                                             distance = 0.,
                                             device = torch.device('cpu')
                                            ):
    """
    Helper function for odak.learn.wave.band_limited_angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    x = dx * float(nu)
    y = dx * float(nv)
    fx = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * x),
                         1 / (2 * dx) - 0.5 / (2 * x),
                         nu,
                         dtype = torch.float32,
                         device = device
                        )
    fy = torch.linspace(
                        -1 / (2 * dx) + 0.5 / (2 * y),
                        1 / (2 * dx) - 0.5 / (2 * y),
                        nv,
                        dtype = torch.float32,
                        device = device
                       )
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    HH_exp = 2 * torch.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))
    distance = torch.tensor([distance], device = device)
    H_exp = torch.mul(HH_exp, distance)
    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength
    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength
    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()
    H = generate_complex_field(H_filter, H_exp)
    return H

get_impulse_response_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), scale=1, aperture_samples=[20, 20, 5, 5])

Helper function for odak.learn.wave.impulse_response_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • scale
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    
  • aperture_samples
                  Number of samples to represent a rectangular pixel. The first two are for the XY axes of hologram plane pixels, and the last two are for image plane pixels.
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_impulse_response_fresnel_kernel(
                                        nu,
                                        nv,
                                        dx = 8e-6,
                                        wavelength = 515e-9,
                                        distance = 0.,
                                        device = torch.device('cpu'),
                                        scale = 1,
                                        aperture_samples = [20, 20, 5, 5]
                                       ):
    """
    Helper function for odak.learn.wave.impulse_response_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    scale              : int
                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    aperture_samples   : list
                         Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.

    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    k = wavenumber(wavelength)
    distance = torch.as_tensor(distance, device = device)
    length_x, length_y = (torch.tensor(dx * nu, device = device), torch.tensor(dx * nv, device = device))
    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
    X, Y = torch.meshgrid(x, y, indexing = 'ij')
    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device)
    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device)
    h = torch.zeros(nu * scale, nv * scale, dtype = torch.complex64, device = device)
    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device)
    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device)
    for wx in tqdm(wxs):
        for wy in wys:
            for px in pxs:
                for py in pys:
                    r = (X + px - wx) ** 2 + (Y + py - wy) ** 2
                    h += 1. / (1j * wavelength * distance) * torch.exp(1j * k / (2 * distance) * r) 
    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
    return H
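
Example: a hedged sketch; the nested aperture sampling loop makes this kernel costly, so resolutions and sample counts are kept intentionally small here, and all values are illustrative placeholders.

from odak.learn.wave import get_impulse_response_fresnel_kernel

H = get_impulse_response_fresnel_kernel(
                                        nu = 128,
                                        nv = 128,
                                        dx = 8e-6,
                                        wavelength = 515e-9,
                                        distance = 5e-3,
                                        scale = 2,                      # doubles the kernel resolution
                                        aperture_samples = [8, 8, 2, 2]
                                       )
print(H.shape)                                                          # torch.Size([256, 256])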

get_incoherent_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.incoherent_angular_spectrum.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_incoherent_angular_spectrum_kernel(
                                           nu,
                                           nv,
                                           dx = 8e-6,
                                           wavelength = 515e-9,
                                           distance = 0.,
                                           device = torch.device('cpu')
                                          ):
    """
    Helper function for odak.learn.wave.angular_spectrum.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : float
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
    H_ptime = correlation_2d(H, H)
    H = H_ptime.to(device)
    return H

get_light_kernels(wavelengths, distances, pixel_pitches, resolution=[1080, 1920], resolution_factor=1, samples=[50, 50, 5, 5], propagation_type='Bandlimited Angular Spectrum', kernel_type='spatial', device=torch.device('cpu'))

Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

Parameters:

  • wavelengths
                 A list of wavelengths.
    
  • distances
                 A list of propagation distances.
    
  • pixel_pitches
                 A list of pixel_pitches.
    
  • resolution
                 Resolution of the light transport kernel.
    
  • resolution_factor
                  If `Impulse Response Fresnel` propagation is used, this resolution factor can be set larger than one, leading to higher-resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_kernel().
    
  • samples
                 If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_kernel().
    
  • propagation_type
                 Propagation type. For more, see odak.learn.wave.propagate_beam().
    
  • kernel_type
                  If set to `spatial`, light transport kernels will be provided in the spatial domain; if set to `fourier`, they will be provided in the Fourier domain.
    
  • device
                 Device used for computation (i.e., cpu, cuda).
    

Returns:

  • light_kernels_amplitude ( tensor ) –

    Amplitudes of the light kernels generated [w x d x p x m x n].

  • light_kernels_phase ( tensor ) –

    Phases of the light kernels generated [w x d x p x m x n].

  • light_kernels_complex ( tensor ) –

    Complex light kernels generated [w x d x p x m x n].

  • light_parameters ( tensor ) –

    Parameters of each pixel in light_kernels* [w x d x p x m x n x 5]. Last dimension contains wavelengths, distances, pixel pitches, and X and Y locations, in that order.

Source code in odak/learn/wave/classical.py
def get_light_kernels(
                      wavelengths,
                      distances,
                      pixel_pitches,
                      resolution = [1080, 1920],
                      resolution_factor = 1,
                      samples = [50, 50, 5, 5],
                      propagation_type = 'Bandlimited Angular Spectrum',
                      kernel_type = 'spatial',
                      device = torch.device('cpu')
                     ):
    """
    Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

    Parameters
    ----------
    wavelengths        : list
                         A list of wavelengths.
    distances          : list
                         A list of propagation distances.
    pixel_pitches      : list
                         A list of pixel_pitches.
    resolution         : list
                         Resolution of the light transport kernel.
    resolution_factor  : int
                         If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_kernel().
    samples            : list
                         If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_kernel().
    propagation_type   : str
                         Propagation type. For more, see odak.learn.wave.propagate_beam().
    kernel_type        : str
                         If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.
    device             : torch.device
                         Device used for computation (i.e., cpu, cuda).

    Returns
    -------
    light_kernels_amplitude : torch.tensor
                              Amplitudes of the light kernels generated [w x d x p x m x n].
    light_kernels_phase     : torch.tensor
                              Phases of the light kernels generated [w x d x p x m x n].
    light_kernels_complex   : torch.tensor
                              Complex light kernels generated [w x d x p x m x n].
    light_parameters        : torch.tensor
                              Parameters of each pixel in light_kernels* [w x d x p x m x n x 5].  Last dimension contains, wavelengths, distances, pixel pitches, X and Y locations in order.
    """
    if propagation_type != 'Impulse Response Fresnel' and propagation_type != 'Seperable Impulse Response Fresnel':
        resolution_factor = 1
    light_kernels_complex = torch.zeros(            
                                        len(wavelengths),
                                        len(distances),
                                        len(pixel_pitches),
                                        resolution[0] * resolution_factor,
                                        resolution[1] * resolution_factor,
                                        dtype = torch.complex64,
                                        device = device
                                       )
    light_parameters = torch.zeros(
                                   len(wavelengths),
                                   len(distances),
                                   len(pixel_pitches),
                                   resolution[0] * resolution_factor,
                                   resolution[1] * resolution_factor,
                                   5,
                                   dtype = torch.float32,
                                   device = device
                                  )
    for wavelength_id, distance_id, pixel_pitch_id in itertools.product(
                                                                        range(len(wavelengths)),
                                                                        range(len(distances)),
                                                                        range(len(pixel_pitches)),
                                                                       ):
        pixel_pitch = pixel_pitches[pixel_pitch_id]
        wavelength = wavelengths[wavelength_id]
        distance = distances[distance_id]
        kernel_fourier = get_propagation_kernel(
                                                nu = resolution[0],
                                                nv = resolution[1],
                                                dx = pixel_pitch,
                                                wavelength = wavelength,
                                                distance = distance,
                                                device = device,
                                                propagation_type = propagation_type,
                                                scale = resolution_factor,
                                                samples = samples
                                               )
        if kernel_type == 'spatial':
            kernel = torch.fft.ifftshift(torch.fft.ifft2(kernel_fourier))
        elif kernel_type == 'fourier':
            kernel = kernel_fourier
        else:
            logging.warning('Unknown kernel type requested.')
            raise ValueError('Unknown kernel type requested.')
        kernel_amplitude = calculate_amplitude(kernel)
        kernel_phase = calculate_phase(kernel) % (2 * torch.pi)
        light_kernels_complex[wavelength_id, distance_id, pixel_pitch_id] = kernel
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 0] = wavelength
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 1] = distance
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 2] = pixel_pitch
        lims = [
                resolution[0] // 2 * pixel_pitch,
                resolution[1] // 2 * pixel_pitch 
               ]
        x = torch.linspace(-lims[0], lims[0], resolution[0] * resolution_factor, device = device)
        y = torch.linspace(-lims[1], lims[1], resolution[1] * resolution_factor, device = device)        
        X, Y = torch.meshgrid(x, y, indexing = 'ij')
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 3] = X
        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 4] = Y
    light_kernels_amplitude = calculate_amplitude(light_kernels_complex)
    light_kernels_phase = calculate_phase(light_kernels_complex) % (2. * torch.pi)
    return light_kernels_amplitude, light_kernels_phase, light_kernels_complex, light_parameters
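
Example: a hedged sketch requesting spatial kernels for a small sweep of optical configurations; the listed wavelengths, distances, and pixel pitch are placeholders.

from odak.learn.wave import get_light_kernels

amplitudes, phases, kernels, parameters = get_light_kernels(
                                                            wavelengths = [473e-9, 515e-9, 639e-9],
                                                            distances = [5e-3, 1e-2],
                                                            pixel_pitches = [8e-6],
                                                            resolution = [256, 256],
                                                            propagation_type = 'Bandlimited Angular Spectrum',
                                                            kernel_type = 'spatial'
                                                           )
print(kernels.shape)       # [3 x 2 x 1 x 256 x 256] complex kernels
print(parameters.shape)    # [3 x 2 x 1 x 256 x 256 x 5] per-pixel parameters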

get_point_wise_impulse_response_fresnel_kernel(aperture_points, aperture_field, target_points, resolution, resolution_factor=1, wavelength=5.15e-07, distance=0.0, randomization=False, device=torch.device('cpu'))

This function is a freeform point spread function calculation routine for an aperture defined with a complex field, aperture_field, and locations in space, aperture_points. The point spread function is calculated over provided points, target_points. The final result is reshaped to follow the provided resolution.

Parameters:

  • aperture_points
                       Points representing an aperture in Euler space (XYZ) [m x 3].
    
  • aperture_field
                       Complex field for each point provided by `aperture_points` [1 x m].
    
  • target_points
                        Target points where the propagated field will be calculated [n x 3]; only the lateral XY coordinates are used.
    
  • resolution
                       Final resolution that the propagated field will be reshaped [X x Y].
    
  • resolution_factor
                        Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field).
    
  • wavelength
                       Wavelength in meters.
    
  • randomization
                        If set to `True`, this will help generate a noisy response roughly approximating a real-life case where imperfections occur.
    
  • distance
                       Distance in meters.
    

Returns:

  • h ( complex64 ) –

    Complex field in spatial domain.

Source code in odak/learn/wave/classical.py
def get_point_wise_impulse_response_fresnel_kernel(
                                                   aperture_points,
                                                   aperture_field,
                                                   target_points,
                                                   resolution,
                                                   resolution_factor = 1,
                                                   wavelength = 515e-9,
                                                   distance = 0.,
                                                   randomization = False,
                                                   device = torch.device('cpu')
                                                  ):
    """
    This function is a freeform point spread function calculation routine for an aperture defined with a complex field, `aperture_field`, and locations in space, `aperture_points`.
    The point spread function is calculated over provided points, `target_points`.
    The final result is reshaped to follow the provided `resolution`.

    Parameters
    ----------
    aperture_points          : torch.tensor
                               Points representing an aperture in Euler space (XYZ) [m x 3].
    aperture_field           : torch.tensor
                               Complex field for each point provided by `aperture_points` [1 x m].
    target_points            : torch.tensor
                               Target points where the propagated field will be calculated [n x 1].
    resolution               : list
                               Final resolution that the propagated field will be reshaped [X x Y].
    resolution_factor        : int
                               Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field.
    wavelength               : float
                               Wavelength in meters.
    randomization            : bool
                               If set `True`, this will help generate a noisy response roughly approximating a real life case, where imperfections occur.
    distance                 : float
                               Distance in meters.

    Returns
    -------
    h                        : float
                               Complex field in spatial domain.
    """
    device = aperture_field.device
    k = wavenumber(wavelength)
    if randomization:
        pp = [
              aperture_points[:, 0].max() - aperture_points[:, 0].min(),
              aperture_points[:, 1].max() - aperture_points[:, 1].min()
             ]
        target_points[:, 0] = target_points[:, 0] - torch.randn(target_points[:, 0].shape) * pp[0]
        target_points[:, 1] = target_points[:, 1] - torch.randn(target_points[:, 1].shape) * pp[1]
    deltaX = aperture_points[:, 0].unsqueeze(0) - target_points[:, 0].unsqueeze(-1)
    deltaY = aperture_points[:, 1].unsqueeze(0) - target_points[:, 1].unsqueeze(-1)
    r = deltaX ** 2 + deltaY ** 2
    h = torch.exp(1j * k / (2 * distance) * r) * aperture_field
    h = torch.sum(h, dim = 1).reshape(resolution[0] * resolution_factor, resolution[1] * resolution_factor)
    h = 1. / (1j * wavelength * distance) * h
    return h
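
Example: a hedged sketch computing a point spread function for a small square aperture sampled as discrete points; all of the geometry below is an illustrative assumption, and only the lateral XY coordinates of the points are used by this routine.

import torch
from odak.learn.wave import get_point_wise_impulse_response_fresnel_kernel

wavelength = 515e-9
distance = 5e-3
dx = 8e-6
resolution = [64, 64]

# Aperture: a 32 x 32 grid of point sources with uniform complex amplitude.
ax = torch.linspace(-16 * dx, 16 * dx, 32)
AX, AY = torch.meshgrid(ax, ax, indexing = 'ij')
aperture_points = torch.stack([AX.flatten(), AY.flatten(), torch.zeros_like(AX.flatten())], dim = -1)
aperture_field = torch.ones(1, aperture_points.shape[0], dtype = torch.complex64)

# Target: image-plane samples matching the requested resolution (the Z column is not used here).
tx = torch.linspace(-resolution[0] // 2 * dx, resolution[0] // 2 * dx, resolution[0])
ty = torch.linspace(-resolution[1] // 2 * dx, resolution[1] // 2 * dx, resolution[1])
TX, TY = torch.meshgrid(tx, ty, indexing = 'ij')
target_points = torch.stack([TX.flatten(), TY.flatten(), torch.zeros_like(TX.flatten())], dim = -1)

h = get_point_wise_impulse_response_fresnel_kernel(
                                                   aperture_points,
                                                   aperture_field,
                                                   target_points,
                                                   resolution,
                                                   wavelength = wavelength,
                                                   distance = distance
                                                  )
psf = h.abs() ** 2    # point spread function as intensity, [64 x 64]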

get_propagation_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), propagation_type='Bandlimited Angular Spectrum', scale=1, samples=[20, 20, 5, 5])

Get propagation kernel for the propagation type.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • propagation_type
                 Propagation type.
                  The options are `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Transfer Function Fresnel`, `Impulse Response Fresnel`, `Seperable Impulse Response Fresnel`, and `Incoherent Angular Spectrum`.
    
  • scale
                 Scale factor for scaled beam propagation.
    
  • samples
                  When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.
    

Returns:

  • kernel ( tensor ) –

    Complex kernel for the given propagation type.

Source code in odak/learn/wave/classical.py
def get_propagation_kernel(
                           nu, 
                           nv, 
                           dx = 8e-6, 
                           wavelength = 515e-9, 
                           distance = 0., 
                           device = torch.device('cpu'), 
                           propagation_type = 'Bandlimited Angular Spectrum', 
                           scale = 1,
                           samples = [20, 20, 5, 5]
                          ):
    """
    Get propagation kernel for the propagation type.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    propagation_type   : str
                         Propagation type.
                         The options are `Angular Spectrum`, `Bandlimited Angular Spectrum` and `Transfer Function Fresnel`.
    scale              : int
                         Scale factor for scaled beam propagation.
    samples            : list
                         When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for a hologram pixel and second two is for an image plane pixel.


    Returns
    -------
    kernel             : torch.tensor
                         Complex kernel for the given propagation type.
    """                                                      
    logging.warning('Requested propagation kernel size for {} method with {} m distance, {} m pixel pitch, {} m wavelength, {} x {} resolutions, x{} scale and {} samples.'.format(propagation_type, distance, dx, wavelength, nu, nv, scale, samples))
    if propagation_type == 'Bandlimited Angular Spectrum':
        kernel = get_band_limited_angular_spectrum_kernel(
                                                          nu = nu,
                                                          nv = nv,
                                                          dx = dx,
                                                          wavelength = wavelength,
                                                          distance = distance,
                                                          device = device
                                                         )
    elif propagation_type == 'Angular Spectrum':
        kernel = get_angular_spectrum_kernel(
                                             nu = nu,
                                             nv = nv,
                                             dx = dx,
                                             wavelength = wavelength,
                                             distance = distance,
                                             device = device
                                            )
    elif propagation_type == 'Transfer Function Fresnel':
        kernel = get_transfer_function_fresnel_kernel(
                                                      nu = nu,
                                                      nv = nv,
                                                      dx = dx,
                                                      wavelength = wavelength,
                                                      distance = distance,
                                                      device = device
                                                     )
    elif propagation_type == 'Impulse Response Fresnel':
        kernel = get_impulse_response_fresnel_kernel(
                                                     nu = nu, 
                                                     nv = nv, 
                                                     dx = dx, 
                                                     wavelength = wavelength,
                                                     distance = distance,
                                                     device =  device,
                                                     scale = scale,
                                                     aperture_samples = samples
                                                    )
    elif propagation_type == 'Incoherent Angular Spectrum':
        kernel = get_incoherent_angular_spectrum_kernel(
                                                        nu = nu,
                                                        nv = nv, 
                                                        dx = dx, 
                                                        wavelength = wavelength, 
                                                        distance = distance,
                                                        device = device
                                                       )
    elif propagation_type == 'Seperable Impulse Response Fresnel':
        kernel, _, _, _ = get_seperable_impulse_response_fresnel_kernel(
                                                                        nu = nu,
                                                                        nv = nv,
                                                                        dx = dx,
                                                                        wavelength = wavelength,
                                                                        distance = distance,
                                                                        device = device,
                                                                        scale = scale,
                                                                        aperture_samples = samples
                                                                       )
    else:
        logging.warning('Propagation type not recognized')
        assert True == False
    return kernel
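
Example: a brief sketch showing that the same call dispatches to the kernel helpers documented above; the listed propagation types and grid parameters are illustrative.

from odak.learn.wave import get_propagation_kernel

for propagation_type in ['Angular Spectrum', 'Bandlimited Angular Spectrum', 'Transfer Function Fresnel']:
    H = get_propagation_kernel(
                               nu = 256,
                               nv = 256,
                               dx = 8e-6,
                               wavelength = 515e-9,
                               distance = 1e-2,
                               propagation_type = propagation_type
                              )
    print(propagation_type, H.shape, H.dtype)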

get_seperable_impulse_response_fresnel_kernel(nu, nv, dx=3.74e-06, wavelength=5.15e-07, distance=0.0, scale=1, aperture_samples=[50, 50, 5, 5], device=torch.device('cpu'))

Returns the impulse response Fresnel kernel in separable form.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    
  • scale
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    
  • aperture_samples
                  Number of samples to represent a rectangular pixel. The first two are for the XY axes of hologram plane pixels, and the last two are for image plane pixels.
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

  • h ( complex64 ) –

    Complex kernel in spatial domain.

  • h_x ( complex64 ) –

    1D complex kernel in spatial domain along X axis.

  • h_y ( complex64 ) –

    1D complex kernel in spatial domain along Y axis.

Source code in odak/learn/wave/classical.py
def get_seperable_impulse_response_fresnel_kernel(
                                                  nu,
                                                  nv,
                                                  dx = 3.74e-6,
                                                  wavelength = 515e-9,
                                                  distance = 0.,
                                                  scale = 1,
                                                  aperture_samples = [50, 50, 5, 5],
                                                  device = torch.device('cpu')
                                                 ):
    """
    Returns impulse response fresnel kernel in separable form.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().
    scale              : int
                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    aperture_samples   : list
                         Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.

    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    h                  : torch.complex64
                         Complex kernel in spatial domain.
    h_x                : torch.complex64
                         1D complex kernel in spatial domain along X axis.
    h_y                : torch.complex64
                         1D complex kernel in spatial domain along Y axis.
    """
    k = wavenumber(wavelength)
    distance = torch.as_tensor(distance, device = device)
    length_x, length_y = (
                          torch.tensor(dx * nu, device = device),
                          torch.tensor(dx * nv, device = device)
                         )
    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device).unsqueeze(0).unsqueeze(0)
    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device).unsqueeze(0).unsqueeze(-1)
    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device).unsqueeze(0).unsqueeze(-1)
    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device).unsqueeze(0).unsqueeze(0)
    wxs = (wxs - pxs).reshape(1, -1).unsqueeze(-1)
    wys = (wys - pys).reshape(1, -1).unsqueeze(1)

    X = x.unsqueeze(-1).unsqueeze(-1)
    Y = y[y.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
    r_x = (X + wxs) ** 2
    r_y = (Y + wys) ** 2
    r = r_x + r_y
    h_x = torch.exp(1j * k / (2 * distance) * r)
    h_x = torch.sum(h_x, axis = (1, 2))

    if nu != nv:
        X = x[x.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
        Y = y.unsqueeze(-1).unsqueeze(-1)
        r_x = (X + wxs) ** 2
        r_y = (Y + wys) ** 2
        r = r_x + r_y
        h_y = torch.exp(1j * k * r / (2 * distance))
        h_y = torch.sum(h_y, axis = (1, 2))
    else:
        h_y = h_x.detach().clone()
    h = torch.exp(1j * k * distance) / (1j * wavelength * distance) * h_x.unsqueeze(1) * h_y.unsqueeze(0)
    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
    return H, h, h_x, h_y
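
Example: a minimal sketch; besides the Fourier-domain kernel, the separable form also exposes the 1D spatial kernels used to build the 2D spatial kernel. The parameters below are placeholders.

from odak.learn.wave import get_seperable_impulse_response_fresnel_kernel

H, h, h_x, h_y = get_seperable_impulse_response_fresnel_kernel(
                                                               nu = 256,
                                                               nv = 256,
                                                               dx = 3.74e-6,
                                                               wavelength = 515e-9,
                                                               distance = 5e-3,
                                                               scale = 1,
                                                               aperture_samples = [20, 20, 5, 5]
                                                              )
print(H.shape, h.shape, h_x.shape, h_y.shape)    # 2D kernels and their 1D spatial factors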

get_transfer_function_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))

Helper function for odak.learn.wave.transfer_function_fresnel.

Parameters:

  • nu
                 Resolution at X axis in pixels.
    
  • nv
                 Resolution at Y axis in pixels.
    
  • dx
                 Pixel pitch in meters.
    
  • wavelength
                 Wavelength in meters.
    
  • distance
                 Distance in meters.
    
  • device
                 Device, for more see torch.device().
    

Returns:

  • H ( complex64 ) –

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_transfer_function_fresnel_kernel(
                                         nu,
                                         nv,
                                         dx = 8e-6,
                                         wavelength = 515e-9,
                                         distance = 0.,
                                         device = torch.device('cpu')
                                        ):
    """
    Helper function for odak.learn.wave.transfer_function_fresnel.

    Parameters
    ----------
    nu                 : int
                         Resolution at X axis in pixels.
    nv                 : int
                         Resolution at Y axis in pixels.
    dx                 : float
                         Pixel pitch in meters.
    wavelength         : float
                         Wavelength in meters.
    distance           : float
                         Distance in meters.
    device             : torch.device
                         Device, for more see torch.device().


    Returns
    -------
    H                  : torch.complex64
                         Complex kernel in Fourier domain.
    """
    distance = torch.tensor([distance]).to(device)
    fx = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nu, dtype = torch.float32, device = device)
    fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device)
    FY, FX = torch.meshgrid(fx, fy, indexing = 'ij')
    k = wavenumber(wavelength)
    H = torch.exp(-1j * distance * (k - torch.pi * wavelength * (FX ** 2 + FY ** 2)))
    return H

impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])

A definition to calculate convolution-based Fresnel approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as the input field [m x n].
    
  • scale
               Resolution factor to scale generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def impulse_response_fresnel(
                             field,
                             k,
                             distance,
                             dx,
                             wavelength,
                             zero_padding = False,
                             aperture = 1.,
                             scale = 1,
                             samples = [20, 20, 5, 5]
                            ):
    """
    A definition to calculate convolution based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].
    scale            : int
                       Resolution factor to scale generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Impulse Response Fresnel',
                               device = field.device,
                               scale = scale,
                               samples = samples
                              )
    if scale > 1:
        field_amplitude = calculate_amplitude(field)
        field_phase = calculate_phase(field)
        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
        field_scale_phase = torch.zeros_like(field_scale_amplitude)
        field_scale_amplitude[::scale, ::scale] = field_amplitude
        field_scale_phase[::scale, ::scale] = field_phase
        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
    else:
        field_scale = field
    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
    return result
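
A minimal usage sketch follows; the import path mirrors the source file noted above, k is computed with odak.wave.wavenumber as the parameter description suggests, and the wavelength, pixel pitch, distance, and field size are illustrative assumptions.

import torch
from odak.learn.wave.classical import impulse_response_fresnel
from odak.wave import wavenumber

wavelength = 515e-9                                            # illustrative wavelength (m)
dx = 8e-6                                                      # illustrative pixel pitch (m)
distance = 5e-3                                                # illustrative propagation distance (m)
k = wavenumber(wavelength)                                     # k = 2 * pi / wavelength
field = torch.exp(1j * 2. * torch.pi * torch.rand(512, 512))   # unit-amplitude, random-phase field
result = impulse_response_fresnel(field, k, distance, dx, wavelength)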

incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate incoherent beam propagation with Angular Spectrum method.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def incoherent_angular_spectrum(
                                field,
                                k,
                                distance,
                                dx,
                                wavelength,
                                zero_padding = False,
                                aperture = 1.
                               ):
    """
    A definition to calculate incoherent beam propagation with Angular Spectrum method.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Incoherent Angular Spectrum',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
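
A minimal usage sketch (the import path follows the source file above; wavelength, pixel pitch, distance and field size are illustrative assumptions):

import torch
from odak.learn.wave.classical import incoherent_angular_spectrum
from odak.wave import wavenumber

wavelength = 515e-9
dx = 8e-6
distance = 1e-2
k = wavenumber(wavelength)
field = torch.exp(1j * 2. * torch.pi * torch.rand(512, 512))   # unit-amplitude, random-phase field
result = incoherent_angular_spectrum(field, k, distance, dx, wavelength)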

point_wise(target, wavelength, distance, dx, device, lens_size=401)

Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

Parameters:

  • target
               Float input target to be converted into a hologram (target values should be between zero and one).
    
  • wavelength
               Wavelength of the electric field.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • device
               Device type (cuda or cpu).
    
  • lens_size
               Size of the lens for masking sub-holograms (in pixels).
    

Returns:

  • hologram ( cfloat ) –

    Calculated complex hologram.

Source code in odak/learn/wave/classical.py
def point_wise(
               target,
               wavelength,
               distance,
               dx,
               device,
               lens_size=401
              ):
    """
    Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

    Parameters
    ----------
    target           : torch.float
                       Float input target to be converted into a hologram (target values should be between zero and one).
    wavelength       : float
                       Wavelength of the electric field.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    device           : torch.device
                       Device type (cuda or cpu).
    lens_size        : int
                       Size of the lens for masking sub-holograms (in pixels).

    Returns
    -------
    hologram         : torch.cfloat
                       Calculated complex hologram.
    """
    target = zero_pad(target)
    nx, ny = target.shape
    k = wavenumber(wavelength)
    ones = torch.ones(target.shape, requires_grad=False).to(device)
    x = torch.linspace(-nx/2, nx/2, nx).to(device)
    y = torch.linspace(-ny/2, ny/2, ny).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    Z = (X**2+Y**2)**0.5
    mask = (torch.abs(Z) <= lens_size)
    mask[mask > 1] = 1
    fz = quadratic_phase_function(nx, ny, k, focal=-distance, dx=dx).to(device)
    A = torch.nan_to_num(target**0.5, nan=0.0)
    fz = mask*fz
    FA = torch.fft.fft2(torch.fft.fftshift(A))
    FFZ = torch.fft.fft2(torch.fft.fftshift(fz))
    H = torch.mul(FA, FFZ)
    hologram = torch.fft.ifftshift(torch.fft.ifft2(H))
    hologram = crop_center(hologram)
    return hologram
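
A minimal usage sketch (the target below is an illustrative random image in [0, 1]; the wavelength, distance, pixel pitch, and lens size are also illustrative assumptions):

import torch
from odak.learn.wave.classical import point_wise

device = torch.device('cpu')
target = torch.rand(256, 256)          # target amplitude, values kept between zero and one
wavelength = 515e-9
distance = 1e-2
dx = 8e-6
hologram = point_wise(target, wavelength, distance, dx, device, lens_size = 201)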

propagate_beam(field, k, distance, dx, wavelength, propagation_type='Bandlimited Angular Spectrum', kernel=None, zero_padding=[True, False, True], aperture=1.0, scale=1, samples=[20, 20, 5, 5])

Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

Parameters:

  • field
               Complex field [m x n].
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • propagation_type (str, default: 'Bandlimited Angular Spectrum' ) –
               Type of the propagation.
               The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.
    
  • kernel
               Custom complex kernel.
    
  • zero_padding
               Zero pads the input field if the first item in the list is set to True.
               Zero pads in the Fourier domain if the second item in the list is set to True.
               Crops the result to half resolution if the third item in the list is set to True.
               Note that in Fraunhofer propagation, setting the second item to True or False has no effect.
    
  • aperture
               Aperture at the Fourier domain, by default [2m x 2n]; its size otherwise depends on `zero_padding`.
               If provided as the floating point value 1, no aperture is applied in the Fourier domain.
    
  • scale
               Resolution factor to scale generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def propagate_beam(
                   field,
                   k,
                   distance,
                   dx,
                   wavelength,
                   propagation_type='Bandlimited Angular Spectrum',
                   kernel = None,
                   zero_padding = [True, False, True],
                   aperture = 1.,
                   scale = 1,
                   samples = [20, 20, 5, 5]
                  ):
    """
    Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

    Parameters
    ----------
    field            : torch.complex
                       Complex field [m x n].
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    propagation_type : str
                       Type of the propagation.
                       The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.
    kernel           : torch.complex
                       Custom complex kernel.
    zero_padding     : list
                       Zero pads the input field if the first item in the list is set to True.
                       Zero pads in the Fourier domain if the second item in the list is set to True.
                       Crops the result to half resolution if the third item in the list is set to True.
                       Note that in Fraunhofer propagation, setting the second item to True or False has no effect.
    aperture         : torch.tensor
                       Aperture at the Fourier domain, by default [2m x 2n]; its size otherwise depends on `zero_padding`.
                       If provided as the floating point value 1, no aperture is applied in the Fourier domain.
    scale            : int
                       Resolution factor to scale generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field [m x n].
    """
    if zero_padding[0]:
        field = zero_pad(field)
    if propagation_type == 'Angular Spectrum':
        result = angular_spectrum(
                                  field = field,
                                  k = k,
                                  distance = distance,
                                  dx = dx,
                                  wavelength = wavelength,
                                  zero_padding = zero_padding[1],
                                  aperture = aperture
                                 )
    elif propagation_type == 'Bandlimited Angular Spectrum':
        result = band_limited_angular_spectrum(
                                               field = field,
                                               k = k,
                                               distance = distance,
                                               dx = dx,
                                               wavelength = wavelength,
                                               zero_padding = zero_padding[1],
                                               aperture = aperture
                                              )
    elif propagation_type == 'Impulse Response Fresnel':
        result = impulse_response_fresnel(
                                          field = field,
                                          k = k,
                                          distance = distance,
                                          dx = dx,
                                          wavelength = wavelength,
                                          zero_padding = zero_padding[1],
                                          aperture = aperture,
                                          scale = scale,
                                          samples = samples
                                         )
    elif propagation_type == 'Seperable Impulse Response Fresnel':
        result = seperable_impulse_response_fresnel(
                                                    field = field,
                                                    k = k,
                                                    distance = distance,
                                                    dx = dx,
                                                    wavelength = wavelength,
                                                    zero_padding = zero_padding[1],
                                                    aperture = aperture,
                                                    scale = scale,
                                                    samples = samples
                                                   )
    elif propagation_type == 'Transfer Function Fresnel':
        result = transfer_function_fresnel(
                                           field = field,
                                           k = k,
                                           distance = distance,
                                           dx = dx,
                                           wavelength = wavelength,
                                           zero_padding = zero_padding[1],
                                           aperture = aperture
                                          )
    elif propagation_type == 'custom':
        result = custom(
                        field = field,
                        kernel = kernel,
                        zero_padding = zero_padding[1],
                        aperture = aperture
                       )
    elif propagation_type == 'Fraunhofer':
        result = fraunhofer(
                            field = field,
                            k = k,
                            distance = distance,
                            dx = dx,
                            wavelength = wavelength
                           )
    elif propagation_type == 'Incoherent Angular Spectrum':
        result = incoherent_angular_spectrum(
                                             field = field,
                                             k = k,
                                             distance = distance,
                                             dx = dx,
                                             wavelength = wavelength,
                                             zero_padding = zero_padding[1],
                                             aperture = aperture
                                            )
    else:
        logging.warning('Propagation type not recognized')
        assert True == False
    if zero_padding[2]:
        result = crop_center(result)
    return result
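
A minimal usage sketch; odak.learn.wave.propagate_beam is referenced elsewhere on this page, and the optical parameters and field below are illustrative assumptions. The default zero_padding list pads the input, skips Fourier-domain padding, and crops the result back to the input resolution.

import torch
from odak.learn.wave import propagate_beam
from odak.wave import wavenumber

wavelength = 515e-9
dx = 8e-6
distance = 5e-3
k = wavenumber(wavelength)
field = torch.exp(1j * 2. * torch.pi * torch.rand(512, 512))
result = propagate_beam(
                        field,
                        k,
                        distance,
                        dx,
                        wavelength,
                        propagation_type = 'Bandlimited Angular Spectrum',
                        zero_padding = [True, False, True]
                       )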

seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])

A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the separable property.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    
  • scale
               Resolution factor to scale generated kernel.
    
  • samples
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def seperable_impulse_response_fresnel(
                                       field,
                                       k,
                                       distance,
                                       dx,
                                       wavelength,
                                       zero_padding = False,
                                       aperture = 1.,
                                       scale = 1,
                                       samples = [20, 20, 5, 5]
                                      ):
    """
    A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the separable property.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].
    scale            : int
                       Resolution factor to scale generated kernel.
    samples          : list
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram-plane pixel and the last two are for an image-plane pixel.

    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Seperable Impulse Response Fresnel',
                               device = field.device,
                               scale = scale,
                               samples = samples
                              )
    if scale > 1:
        field_amplitude = calculate_amplitude(field)
        field_phase = calculate_phase(field)
        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
        field_scale_phase = torch.zeros_like(field_scale_amplitude)
        field_scale_amplitude[::scale, ::scale] = field_amplitude
        field_scale_phase[::scale, ::scale] = field_phase
        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
    else:
        field_scale = field
    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
    return result
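
A minimal usage sketch; as above, the optical parameters are illustrative, and scale = 2 is only meant to exercise the kernel upsampling path (the function interleaves the input field on a grid scaled by the same factor).

import torch
from odak.learn.wave.classical import seperable_impulse_response_fresnel
from odak.wave import wavenumber

wavelength = 515e-9
dx = 8e-6
distance = 5e-3
k = wavenumber(wavelength)
field = torch.exp(1j * 2. * torch.pi * torch.rand(256, 256))
result = seperable_impulse_response_fresnel(
                                            field,
                                            k,
                                            distance,
                                            dx,
                                            wavelength,
                                            scale = 2,
                                            samples = [20, 20, 5, 5]
                                           )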

shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None)

Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Coded following the implementation at https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207 and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

Parameters:

  • phase
               Phase value of a phase-only hologram.
    
  • depth_shift
               Distance in meters.
    
  • pixel_pitch
               Pixel pitch size in meters.
    
  • wavelength
               Wavelength of light.
    
  • propagation_type (str, default: 'Transfer Function Fresnel' ) –
               Beam propagation type. For more see odak.learn.wave.propagate_beam().
    
  • kernel_length
               Kernel length for the Gaussian blur kernel.
    
  • sigma
               Standard deviation for the Gaussian blur kernel.
    
  • amplitude
               Amplitude value of a complex hologram.
    
Source code in odak/learn/wave/classical.py
def shift_w_double_phase(
                         phase,
                         depth_shift,
                         pixel_pitch,
                         wavelength,
                         propagation_type = 'Transfer Function Fresnel',
                         kernel_length = 4,
                         sigma = 0.5,
                         amplitude = None
                        ):
    """
    Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Coded following the implementation [here](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

    Parameters
    ----------
    phase            : torch.tensor
                       Phase value of a phase-only hologram.
    depth_shift      : float
                       Distance in meters.
    pixel_pitch      : float
                       Pixel pitch size in meters.
    wavelength       : float
                       Wavelength of light.
    propagation_type : str
                       Beam propagation type. For more see odak.learn.wave.propagate_beam().
    kernel_length    : int
                       Kernel length for the Gaussian blur kernel.
    sigma            : float
                       Standard deviation for the Gaussian blur kernel.
    amplitude        : torch.tensor
                       Amplitude value of a complex hologram.
    """
    if type(amplitude) == type(None):
        amplitude = torch.ones_like(phase)
    hologram = generate_complex_field(amplitude, phase)
    k = wavenumber(wavelength)
    hologram_padded = zero_pad(hologram)
    shifted_field_padded = propagate_beam(
                                          hologram_padded,
                                          k,
                                          depth_shift,
                                          pixel_pitch,
                                          wavelength,
                                          propagation_type
                                         )
    shifted_field = crop_center(shifted_field_padded)
    phase_shift = torch.exp(torch.tensor([-2 * torch.pi * depth_shift / wavelength]).to(phase.device))
    shift = torch.cos(phase_shift) + 1j * torch.sin(phase_shift)
    shifted_complex_hologram = shifted_field * shift

    if kernel_length > 0 and sigma >0:
        blur_kernel = generate_2d_gaussian(
                                           [kernel_length, kernel_length],
                                           [sigma, sigma]
                                          ).to(phase.device)
        blur_kernel = blur_kernel.unsqueeze(0)
        blur_kernel = blur_kernel.unsqueeze(0)
        field_imag = torch.imag(shifted_complex_hologram)
        field_real = torch.real(shifted_complex_hologram)
        field_imag = field_imag.unsqueeze(0)
        field_imag = field_imag.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_real = field_real.unsqueeze(0)
        field_imag = torch.nn.functional.conv2d(field_imag, blur_kernel, padding='same')
        field_real = torch.nn.functional.conv2d(field_real, blur_kernel, padding='same')
        shifted_complex_hologram = torch.complex(field_real, field_imag)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)
        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)

    shifted_amplitude = calculate_amplitude(shifted_complex_hologram)
    shifted_amplitude = shifted_amplitude / torch.amax(shifted_amplitude, [0,1])

    shifted_phase = calculate_phase(shifted_complex_hologram)
    phase_zero_mean = shifted_phase - torch.mean(shifted_phase)

    phase_offset = torch.arccos(shifted_amplitude)
    phase_low = phase_zero_mean - phase_offset
    phase_high = phase_zero_mean + phase_offset

    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only
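
A minimal usage sketch (the phase map, depth shift, pixel pitch, and wavelength below are illustrative assumptions; the amplitude defaults to ones when omitted):

import torch
from odak.learn.wave.classical import shift_w_double_phase

phase = 2. * torch.pi * torch.rand(512, 512)     # illustrative phase-only hologram
depth_shift = 1e-3                               # shift by 1 mm
pixel_pitch = 8e-6
wavelength = 515e-9
phase_only = shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength)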

stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type='Bandlimited Angular Spectrum', n_iteration=100, loss_function=None, learning_rate=0.1)

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

Parameters:

  • target
                        Target field amplitude [m x n].
                        Keep the target values between zero and one.
    
  • wavelength
                         Wavelength of the electric field.
    
  • distance
                         Hologram plane distance with respect to the SLM plane.
    
  • pixel_pitch
                        SLM pixel pitch in meters.
    
  • propagation_type
                        Type of the propagation (see odak.learn.wave.propagate_beam()).
    
  • n_iteration
                         Number of iterations.
    
  • loss_function
                         If None, it is set to L2 (MSE) loss.
    
  • learning_rate
                        Learning rate.
    

Returns:

  • hologram ( Tensor ) –

    Phase only hologram as torch array

  • reconstruction_intensity ( Tensor ) –

    Reconstruction as torch array

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(
                                target,
                                wavelength,
                                distance,
                                pixel_pitch,
                                propagation_type = 'Bandlimited Angular Spectrum',
                                n_iteration = 100,
                                loss_function = None,
                                learning_rate = 0.1
                               ):
    """
    Definition to generate phase and reconstruction from target image via stochastic gradient descent.

    Parameters
    ----------
    target                    : torch.Tensor
                                Target field amplitude [m x n].
                                Keep the target values between zero and one.
    wavelength                : double
                                Wavelength of the electric field.
    distance                  : double
                                Hologram plane distance with respect to the SLM plane.
    pixel_pitch               : float
                                SLM pixel pitch in meters.
    propagation_type          : str
                                Type of the propagation (see odak.learn.wave.propagate_beam()).
    n_iteration               : int
                                Number of iterations.
    loss_function             : function
                                If None, it is set to L2 (MSE) loss.
    learning_rate             : float
                                Learning rate.

    Returns
    -------
    hologram                  : torch.Tensor
                                Phase only hologram as torch array

    reconstruction_intensity  : torch.Tensor
                                Reconstruction as torch array

    """
    phase = torch.randn_like(target, requires_grad = True)
    k = wavenumber(wavelength)
    optimizer = torch.optim.Adam([phase], lr = learning_rate)
    if type(loss_function) == type(None):
        loss_function = torch.nn.MSELoss()
    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)
    for i in t:
        optimizer.zero_grad()
        hologram = generate_complex_field(1., phase)
        reconstruction = propagate_beam(
                                        hologram, 
                                        k, 
                                        distance, 
                                        pixel_pitch, 
                                        wavelength, 
                                        propagation_type, 
                                        zero_padding = [True, False, True]
                                       )
        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
        loss = loss_function(reconstruction_intensity, target)
        description = "Loss:{:.4f}".format(loss.item())
        loss.backward(retain_graph = True)
        optimizer.step()
        t.set_description(description)
    logging.warning(description)
    torch.no_grad()
    hologram = generate_complex_field(1., phase)
    reconstruction = propagate_beam(
                                    hologram, 
                                    k, 
                                    distance, 
                                    pixel_pitch, 
                                    wavelength, 
                                    propagation_type, 
                                    zero_padding = [True, False, True]
                                   )
    return hologram, reconstruction
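
A minimal usage sketch (the target image and optical parameters are illustrative assumptions; fewer iterations than the default are used to keep the run short):

import torch
from odak.learn.wave.classical import stochastic_gradient_descent

target = torch.rand(512, 512)                    # target amplitude, values between zero and one
wavelength = 515e-9
distance = 5e-3
pixel_pitch = 8e-6
hologram, reconstruction = stochastic_gradient_descent(
                                                       target,
                                                       wavelength,
                                                       distance,
                                                       pixel_pitch,
                                                       n_iteration = 50,
                                                       learning_rate = 0.1
                                                      )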

transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)

A definition to calculate convolution based Fresnel approximation for beam propagation.

Parameters:

  • field
               Complex field (MxN).
    
  • k
               Wave number of a wave, see odak.wave.wavenumber for more.
    
  • distance
               Propagation distance.
    
  • dx
               Size of one single pixel in the field grid (in meters).
    
  • wavelength
               Wavelength of the electric field.
    
  • zero_padding
               Zero pad in Fourier domain.
    
  • aperture
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
               The default is one, but an aperture could be as large as input field [m x n].
    

Returns:

  • result ( complex ) –

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def transfer_function_fresnel(
                              field,
                              k,
                              distance,
                              dx,
                              wavelength,
                              zero_padding = False,
                              aperture = 1.
                             ):
    """
    A definition to calculate convolution based Fresnel approximation for beam propagation.

    Parameters
    ----------
    field            : torch.complex
                       Complex field (MxN).
    k                : odak.wave.wavenumber
                       Wave number of a wave, see odak.wave.wavenumber for more.
    distance         : float
                       Propagation distance.
    dx               : float
                       Size of one single pixel in the field grid (in meters).
    wavelength       : float
                       Wavelength of the electric field.
    zero_padding     : bool
                       Zero pad in Fourier domain.
    aperture         : torch.tensor
                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
                       The default is one, but an aperture could be as large as input field [m x n].


    Returns
    -------
    result           : torch.complex
                       Final complex field (MxN).

    """
    H = get_propagation_kernel(
                               nu = field.shape[-2], 
                               nv = field.shape[-1], 
                               dx = dx, 
                               wavelength = wavelength, 
                               distance = distance, 
                               propagation_type = 'Transfer Function Fresnel',
                               device = field.device
                              )
    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
    return result
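
A minimal usage sketch with the same illustrative optical parameters assumed in the other propagation examples on this page:

import torch
from odak.learn.wave.classical import transfer_function_fresnel
from odak.wave import wavenumber

wavelength = 515e-9
dx = 8e-6
distance = 5e-3
k = wavenumber(wavelength)
field = torch.exp(1j * 2. * torch.pi * torch.rand(512, 512))
result = transfer_function_fresnel(field, k, distance, dx, wavelength)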

blazed_grating(nx, ny, levels=2, axis='x')

A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • levels
           Number of phase levels (the grating period in pixels).
    
  • axis
           Axis of the blazed grating. It could be `x` or `y`.
    
Source code in odak/learn/wave/lens.py
def blazed_grating(nx, ny, levels = 2, axis = 'x'):
    """
    A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.


    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    levels       : int
                   Number of phase levels (the grating period in pixels).
    axis         : str
                   Axis of the blazed grating. It could be `x` or `y`.

    """
    if levels < 2:
        levels = 2
    x = (torch.abs(torch.arange(-nx, 0)) % levels) / levels * (2 * np.pi)
    y = (torch.abs(torch.arange(-ny, 0)) % levels) / levels * (2 * np.pi)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'x':
        blazed_grating = torch.exp(1j * X)
    elif axis == 'y':
        blazed_grating = torch.exp(1j * Y)
    return blazed_grating
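
A minimal usage sketch (the grating size and number of levels are illustrative assumptions); the returned tensor is a complex field whose phase ramps from 0 to 2π every `levels` pixels along the chosen axis:

from odak.learn.wave.lens import blazed_grating

grating = blazed_grating(nx = 512, ny = 512, levels = 8, axis = 'x')
# `grating` can be multiplied with a complex field to steer it along X.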

linear_grating(nx, ny, every=2, add=None, axis='x')

A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • every
          Apply the `add` phase value every `every` pixels.
    
  • add
         Angle to be added.
    
  • axis
          Axis: either `x`, `y`, or `xy` (both).
    

Returns:

  • field ( tensor ) –

    Linear grating term.

Source code in odak/learn/wave/lens.py
def linear_grating(nx, ny, every = 2, add = None, axis = 'x'):
    """
    A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    every      : int
                 Apply the `add` phase value every `every` pixels.
    add        : float
                 Angle to be added.
    axis       : string
                 Axis: either `x`, `y`, or `xy` (both).

    Returns
    ----------
    field      : torch.tensor
                 Linear grating term.
    """
    if isinstance(add, type(None)):
        add = np.pi
    grating = torch.zeros((nx, ny), dtype=torch.complex64)
    if axis == 'x':
        grating[::every, :] = torch.exp(torch.tensor(1j*add))
    if axis == 'y':
        grating[:, ::every] = torch.exp(torch.tensor(1j*add))
    if axis == 'xy':
        checker = np.indices((nx, ny)).sum(axis=0) % every
        checker = torch.from_numpy(checker)
        checker += 1
        checker = checker % 2
        grating = torch.exp(1j*checker*add)
    return grating
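
A minimal usage sketch (sizes are illustrative); with axis = 'xy' the result is a checkerboard-like binary phase grating with a default phase step of π:

from odak.learn.wave.lens import linear_grating

grating_x = linear_grating(nx = 512, ny = 512, every = 2, axis = 'x')    # stripes along X
grating_xy = linear_grating(nx = 512, ny = 512, every = 2, axis = 'xy')  # checkerboard pattern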

prism_grating(nx, ny, k, angle, dx=0.001, axis='x', phase_offset=0.0)

A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics Express 16.22 (2008): 18275-18287 for more.

Parameters:

  • nx
           Size of the output along X.
    
  • ny
           Size of the output along Y.
    
  • k
           See odak.wave.wavenumber for more.
    
  • angle
           Tilt angle of the prism in degrees.
    
  • dx
           Pixel pitch.
    
  • axis
           Axis of the prism.
    
  • phase_offset (float, default: 0.0 ) –
           Phase offset in degrees. Default is zero.
    

Returns:

  • prism ( tensor ) –

    Generated phase function for a prism.

Source code in odak/learn/wave/lens.py
def prism_grating(nx, ny, k, angle, dx = 0.001, axis = 'x', phase_offset = 0.):
    """
    A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics Express 16.22 (2008): 18275-18287 for more.

    Parameters
    ----------
    nx           : int
                   Size of the output along X.
    ny           : int
                   Size of the output along Y.
    k            : odak.wave.wavenumber
                   See odak.wave.wavenumber for more.
    angle        : float
                   Tilt angle of the prism in degrees.
    dx           : float
                   Pixel pitch.
    axis         : str
                   Axis of the prism.
    phase_offset : float
                   Phase offset in degrees. Default is zero.

    Returns
    ----------
    prism        : torch.tensor
                   Generated phase function for a prism.
    """
    angle = torch.deg2rad(torch.tensor([angle]))
    phase_offset = torch.deg2rad(torch.tensor([phase_offset]))
    x = torch.arange(0, nx) * dx
    y = torch.arange(0, ny) * dx
    X, Y = torch.meshgrid(x, y, indexing='ij')
    if axis == 'y':
        phase = k * torch.sin(angle) * Y + phase_offset
        prism = torch.exp(-1j * phase)
    elif axis == 'x':
        phase = k * torch.sin(angle) * X + phase_offset
        prism = torch.exp(-1j * phase)
    return prism
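
A minimal usage sketch (the wavenumber, tilt angle, and pixel pitch are illustrative assumptions); multiplying a hologram with the returned phase function tilts the outgoing beam along the chosen axis:

from odak.learn.wave.lens import prism_grating
from odak.wave import wavenumber

wavelength = 515e-9
k = wavenumber(wavelength)
prism = prism_grating(nx = 512, ny = 512, k = k, angle = 0.1, dx = 8e-6, axis = 'x')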

quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx
         Size of the output along X.
    
  • ny
         Size of the output along Y.
    
  • k
         See odak.wave.wavenumber for more.
    
  • focal
         Focal length of the quadratic phase function.
    
  • dx
         Pixel pitch.
    
  • offset
         Deviation from the center along X and Y axes.
    

Returns:

  • qpf ( tensor ) –

    Generated quadratic phase function.

Source code in odak/learn/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
    """ 
    A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

    Parameters
    ----------
    nx         : int
                 Size of the output along X.
    ny         : int
                 Size of the output along Y.
    k          : odak.wave.wavenumber
                 See odak.wave.wavenumber for more.
    focal      : float
                 Focal length of the quadratic phase function.
    dx         : float
                 Pixel pitch.
    offset     : list
                 Deviation from the center along X and Y axes.

    Returns
    -------
    qpf        : torch.tensor
                 Generated quadratic phase function.
    """
    size = [nx, ny]
    x = torch.linspace(-size[0] * dx / 2, size[0] * dx / 2, size[0]) - offset[1] * dx
    y = torch.linspace(-size[1] * dx / 2, size[1] * dx / 2, size[1]) - offset[0] * dx
    X, Y = torch.meshgrid(x, y, indexing = 'ij')
    Z = X ** 2 + Y ** 2
    qpf = torch.exp(-0.5j * k / focal * Z)
    return qpf
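
A minimal usage sketch (wavelength, focal length, and pixel pitch are illustrative assumptions); multiplying a field with the returned phase function acts like passing it through a thin lens:

from odak.learn.wave.lens import quadratic_phase_function
from odak.wave import wavenumber

wavelength = 515e-9
k = wavenumber(wavelength)
lens_phase = quadratic_phase_function(nx = 512, ny = 512, k = k, focal = 0.3, dx = 8e-6)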

multiplane_loss

Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

Source code in odak/learn/wave/loss.py
class multiplane_loss():
    """
    Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.
    """

    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
                 multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image      : torch.tensor
                            Color target image [3 x m x n].
        target_depth      : torch.tensor
                            Monochrome target depth, same resolution as target_image.
        target_blur_size  : int
                            Maximum target blur size.
        blur_ratio        : float
                            Blur ratio, a value between zero and one.
        number_of_planes  : int
                            Number of planes.
        weights           : list
                            Weights of the loss function.
        multiplier        : float
                            Multiplier to multiply with targets.
        scheme            : str
                            The type of the loss, `naive` without defocus or `defocus` with defocus.
        reduction         : str
                            Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        device            : torch.device
                            Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.weights          = weights
        self.reduction        = reduction
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.loss_function = torch.nn.MSELoss(reduction = self.reduction)

    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size ]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None, inject_noise = False, noise_ratio = 1e-3):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Number of the plane under test.
        inject_noise  : bool
                        When True, noise is added on the targets at the given `noise_ratio`.
        noise_ratio   : float
                        Noise ratio.

        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        l2 = self.weights[0] * self.loss_function(image, target)
        if isinstance(plane_id, type(None)):
            mask = self.masks
        else:
            mask= self.masks[plane_id, :]
        if inject_noise:
            target = target + torch.randn_like(target) * noise_ratio * (target.max() - target.min())
        l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
        l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
        loss = l2 + l2_mask + l2_cor
        return loss
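
A minimal usage sketch (the target image, depth map, and plane and parameter choices below are illustrative assumptions):

import torch
from odak.learn.wave.loss import multiplane_loss

target_image = torch.rand(3, 256, 256)              # illustrative color target [3 x m x n]
target_depth = torch.rand(256, 256)                 # illustrative depth map, values in [0, 1]
loss_function = multiplane_loss(
                                target_image,
                                target_depth,
                                number_of_planes = 4,
                                scheme = 'defocus'
                               )
targets, focus_target, depth = loss_function.get_targets()
image = torch.rand(3, 256, 256)                     # e.g., a simulated reconstruction at plane zero
loss = loss_function(image, targets[0], plane_id = 0)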

__call__(image, target, plane_id=None, inject_noise=False, noise_ratio=0.001)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Number of the plane under test.
    
  • inject_noise
            When True, noise is added on the targets at the given `noise_ratio`.
    
  • noise_ratio
            Noise ratio.
    

Returns:

  • loss ( tensor ) –

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None, inject_noise = False, noise_ratio = 1e-3):
    """
    Calculates the multiplane loss against a given target.

    Parameters
    ----------
    image         : torch.tensor
                    Image to compare with a target [3 x m x n].
    target        : torch.tensor
                    Target image for comparison [3 x m x n].
    plane_id      : int
                    Number of the plane under test.
    inject_noise  : bool
                    When True, noise is added on the targets at the given `noise_ratio`.
    noise_ratio   : float
                    Noise ratio.

    Returns
    -------
    loss          : torch.tensor
                    Computed loss.
    """
    l2 = self.weights[0] * self.loss_function(image, target)
    if isinstance(plane_id, type(None)):
        mask = self.masks
    else:
        mask= self.masks[plane_id, :]
    if inject_noise:
        target = target + torch.randn_like(target) * noise_ratio * (target.max() - target.min())
    l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
    l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
    loss = l2 + l2_mask + l2_cor
    return loss

__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, weights=[1.0, 2.1, 0.6], multiplier=1.0, scheme='defocus', reduction='mean', device=torch.device('cpu'))

Parameters:

  • target_image
                Color target image [3 x m x n].
    
  • target_depth
                Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                Maximum target blur size.
    
  • blur_ratio
                Blur ratio, a value between zero and one.
    
  • number_of_planes
                Number of planes.
    
  • weights
                Weights of the loss function.
    
  • multiplier
                 Multiplier to multiply with targets.
    
  • scheme
                The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • reduction
                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • device
                Device to be used (e.g., cuda, cpu, opencl).
    
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
             target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
             multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
    """
    Parameters
    ----------
    target_image      : torch.tensor
                        Color target image [3 x m x n].
    target_depth      : torch.tensor
                        Monochrome target depth, same resolution as target_image.
    target_blur_size  : int
                        Maximum target blur size.
    blur_ratio        : float
                        Blur ratio, a value between zero and one.
    number_of_planes  : int
                        Number of planes.
    weights           : list
                        Weights of the loss function.
    multiplier        : float
                        Multiplier to multiply with targets.
    scheme            : str
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    reduction         : str
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    device            : torch.device
                        Device to be used (e.g., cuda, cpu, opencl).
    """
    self.device = device
    self.target_image     = target_image.float().to(self.device)
    self.target_depth     = target_depth.float().to(self.device)
    self.target_blur_size = target_blur_size
    if self.target_blur_size % 2 == 0:
        self.target_blur_size += 1
    self.number_of_planes = number_of_planes
    self.multiplier       = multiplier
    self.weights          = weights
    self.reduction        = reduction
    self.blur_ratio       = blur_ratio
    self.set_targets()
    if scheme == 'defocus':
        self.add_defocus_blur()
    self.loss_function = torch.nn.MSELoss(reduction = self.reduction)

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):
    """
    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
    """
    kernel_length = [self.target_blur_size, self.target_blur_size ]
    for ch in range(self.target_image.shape[0]):
        targets_cache = self.targets[:, ch].detach().clone()
        target = torch.sum(targets_cache, axis = 0)
        for i in range(self.number_of_planes):
            defocus = torch.zeros_like(targets_cache[i])
            for j in range(self.number_of_planes):
                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                if torch.sum(targets_cache[j]) > 0:
                    if i == j:
                        nsigma = [0., 0.]
                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                    kernel = kernel / torch.sum(kernel)
                    kernel = kernel.unsqueeze(0).unsqueeze(0)
                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
            self.targets[i, ch] = defocus
    self.targets = self.targets.detach().clone() * self.multiplier

get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):
    """
    Returns
    -------
    targets           : torch.tensor
                        Returns a copy of the targets.
    target_depth      : torch.tensor
                        Returns a copy of the normalized quantized depth map.

    """
    divider = self.number_of_planes - 1
    if divider == 0:
        divider = 1
    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider

set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 
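
The slicing itself amounts to quantizing the normalized depth map into plane indices and building one binary mask per plane. A minimal sketch of that step in plain torch (the function name slice_depth and the random inputs are illustrative):

import torch


def slice_depth(target_image, target_depth, number_of_planes):
    # target_image: [c x m x n]; target_depth: [m x n], normalized to [0, 1].
    depth_levels = torch.round(target_depth * (number_of_planes - 1))
    targets = torch.zeros(number_of_planes, *target_image.shape)
    masks = torch.zeros_like(targets)
    for i in range(number_of_planes):
        mask = (depth_levels == i).float()
        masks[i] = mask.unsqueeze(0).expand_as(target_image)
        targets[i] = target_image * mask
    return targets, masks


image = torch.rand(3, 64, 64)
depth = torch.rand(64, 64)
targets, masks = slice_depth(image, depth, number_of_planes = 4)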

perceptual_multiplane_loss

Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for the defocused parts of an image.

Source code in odak/learn/wave/loss.py
class perceptual_multiplane_loss():
    """
    Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for the defocused parts of an image.
    """

    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
                 target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
                 base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
                 additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
        """
        Parameters
        ----------
        target_image            : torch.tensor
                                    Color target image [3 x m x n].
        target_depth            : torch.tensor
                                    Monochrome target depth, same resolution as target_image.
        target_blur_size        : int
                                    Maximum target blur size.
        blur_ratio              : float
                                    Blur ratio, a value between zero and one.
        number_of_planes        : int
                                    Number of planes.
        multiplier              : float
                                    Multiplier to multiply with targets.
        scheme                  : str
                                    The type of the loss, `naive` without defocus or `defocus` with defocus.
        base_loss_weights       : dict
                                    Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
        additional_loss_weights : dict
                                    Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
        reduction               : str
                                    Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
        return_components       : bool
                                    If True (False by default), returns the components of the loss as a dict.
        device                  : torch.device
                                    Device to be used (e.g., cuda, cpu, opencl).
        """
        self.device = device
        self.target_image     = target_image.float().to(self.device)
        self.target_depth     = target_depth.float().to(self.device)
        self.target_blur_size = target_blur_size
        if self.target_blur_size % 2 == 0:
            self.target_blur_size += 1
        self.number_of_planes = number_of_planes
        self.multiplier       = multiplier
        self.reduction        = reduction
        if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
            logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
            self.reduction = 'mean'
        self.blur_ratio       = blur_ratio
        self.set_targets()
        if scheme == 'defocus':
            self.add_defocus_blur()
        self.base_loss_weights = base_loss_weights
        self.additional_loss_weights = additional_loss_weights
        self.return_components = return_components
        self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
        self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
        for key in self.additional_loss_weights.keys():
            if self.additional_loss_weights[key]:
                if key == 'cvvdp':
                    self.cvvdp = CVVDP(device = device)
                if key == 'fvvdp':
                    self.fvvdp = FVVDP()
                if key == 'lpips':
                    self.lpips = LPIPS()
                if key == 'psnr':
                    self.psnr = PSNR()
                if key == 'ssim':
                    self.ssim = SSIM()
                if key == 'msssim':
                    self.msssim = MSSSIM()

    def get_targets(self):
        """
        Returns
        -------
        targets           : torch.tensor
                            Returns a copy of the targets.
        focus_target      : torch.tensor
                            Returns a copy of the all-in-focus target.
        target_depth      : torch.tensor
                            Returns a copy of the normalized quantized depth map.

        """
        divider = self.number_of_planes - 1
        if divider == 0:
            divider = 1
        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider


    def set_targets(self):
        """
        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
        """
        self.target_depth = self.target_depth * (self.number_of_planes - 1)
        self.target_depth = torch.round(self.target_depth, decimals = 0)
        self.targets      = torch.zeros(
                                        self.number_of_planes,
                                        self.target_image.shape[0],
                                        self.target_image.shape[1],
                                        self.target_image.shape[2],
                                        requires_grad = False,
                                        device = self.device
                                       )
        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
        self.masks        = torch.zeros_like(self.targets)
        for i in range(self.number_of_planes):
            for ch in range(self.target_image.shape[0]):
                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
                new_target = self.target_image[ch] * mask
                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
                self.masks[i, ch] = mask.detach().clone() 


    def add_defocus_blur(self):
        """
        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
        """
        kernel_length = [self.target_blur_size, self.target_blur_size ]
        for ch in range(self.target_image.shape[0]):
            targets_cache = self.targets[:, ch].detach().clone()
            target = torch.sum(targets_cache, axis = 0)
            for i in range(self.number_of_planes):
                defocus = torch.zeros_like(targets_cache[i])
                for j in range(self.number_of_planes):
                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                    if torch.sum(targets_cache[j]) > 0:
                        if i == j:
                            nsigma = [0., 0.]
                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                        kernel = kernel / torch.sum(kernel)
                        kernel = kernel.unsqueeze(0).unsqueeze(0)
                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
                self.targets[i, ch] = defocus
        self.targets = self.targets.detach().clone() * self.multiplier


    def __call__(self, image, target, plane_id = None, inject_noise = False, noise_ratio = 1e-3):
        """
        Calculates the multiplane loss against a given target.

        Parameters
        ----------
        image         : torch.tensor
                        Image to compare with a target [3 x m x n].
        target        : torch.tensor
                        Target image for comparison [3 x m x n].
        plane_id      : int
                        Number of the plane under test.
        inject_noise  : bool
                        When True, noise is added to the target at the given `noise_ratio`.
        noise_ratio   : float
                        Noise ratio.


        Returns
        -------
        loss          : torch.tensor
                        Computed loss.
        """
        loss_components = {}
        if isinstance(plane_id, type(None)):
            mask = self.masks
        else:
            mask= self.masks[plane_id, :]
        if inject_noise:
            target = target + torch.randn_like(target) * noise_ratio * (target.max() - target.min())
        l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
        l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
        l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
        loss_components['l2'] = l2
        loss_components['l2_mask'] = l2_mask
        loss_components['l2_cor'] = l2_cor
        loss = l2 + l2_mask + l2_cor

        l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
        l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
        l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
        loss_components['l1'] = l1
        loss_components['l1_mask'] = l1_mask
        loss_components['l1_cor'] = l1_cor
        loss += l1 + l1_mask + l1_cor

        for key in self.additional_loss_weights.keys():
            if self.additional_loss_weights[key]:
                if key == 'cvvdp':
                    loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
                    loss_components['cvvdp'] = loss_cvvdp
                    loss += loss_cvvdp
                if key == 'fvvdp':
                    loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
                    loss_components['fvvdp'] = loss_fvvdp
                    loss += loss_fvvdp
                if key == 'lpips':
                    loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
                    loss_components['lpips'] = loss_lpips
                    loss += loss_lpips
                if key == 'psnr':
                    loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
                    loss_components['psnr'] = loss_psnr
                    loss += loss_psnr
                if key == 'ssim':
                    loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
                    loss_components['ssim'] = loss_ssim
                    loss += loss_ssim
                if key == 'msssim':
                    loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
                    loss_components['msssim'] = loss_msssim
                    loss += loss_msssim
        if self.return_components:
            return loss, loss_components
        return loss
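
A hedged usage sketch is given below; the tensor sizes are illustrative, the import path follows the source note above (odak/learn/wave/loss.py), and additional perceptual metrics are disabled so the example has no extra dependencies. This is only one plausible way to call the class, not a prescribed workflow.

import torch
from odak.learn.wave.loss import perceptual_multiplane_loss

target_image = torch.rand(3, 128, 128)    # color target [3 x m x n]
target_depth = torch.rand(128, 128)       # normalized depth map [m x n]
loss_function = perceptual_multiplane_loss(
                                           target_image,
                                           target_depth,
                                           number_of_planes = 4,
                                           additional_loss_weights = {},
                                           device = torch.device('cpu')
                                          )
targets, focus_target, depth = loss_function.get_targets()
reconstruction = torch.rand_like(targets[0])  # stand-in for a simulated reconstruction at plane 0
loss = loss_function(reconstruction, targets[0], plane_id = 0)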

__call__(image, target, plane_id=None, inject_noise=False, noise_ratio=0.001)

Calculates the multiplane loss against a given target.

Parameters:

  • image
            Image to compare with a target [3 x m x n].
    
  • target
            Target image for comparison [3 x m x n].
    
  • plane_id
            Number of the plane under test.
    
  • inject_noise
            When True, noise is added to the target at the given `noise_ratio`.
    
  • noise_ratio
            Noise ratio.
    

Returns:

  • loss ( tensor ) –

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None, inject_noise = False, noise_ratio = 1e-3):
    """
    Calculates the multiplane loss against a given target.

    Parameters
    ----------
    image         : torch.tensor
                    Image to compare with a target [3 x m x n].
    target        : torch.tensor
                    Target image for comparison [3 x m x n].
    plane_id      : int
                    Number of the plane under test.
    inject_noise  : bool
                    When True, noise is added to the target at the given `noise_ratio`.
    noise_ratio   : float
                    Noise ratio.


    Returns
    -------
    loss          : torch.tensor
                    Computed loss.
    """
    loss_components = {}
    if isinstance(plane_id, type(None)):
        mask = self.masks
    else:
        mask= self.masks[plane_id, :]
    if inject_noise:
        target = target + torch.randn_like(target) * noise_ratio * (target.max() - target.min())
    l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
    l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
    l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
    loss_components['l2'] = l2
    loss_components['l2_mask'] = l2_mask
    loss_components['l2_cor'] = l2_cor
    loss = l2 + l2_mask + l2_cor

    l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
    l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
    l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
    loss_components['l1'] = l1
    loss_components['l1_mask'] = l1_mask
    loss_components['l1_cor'] = l1_cor
    loss += l1 + l1_mask + l1_cor

    for key in self.additional_loss_weights.keys():
        if self.additional_loss_weights[key]:
            if key == 'cvvdp':
                loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
                loss_components['cvvdp'] = loss_cvvdp
                loss += loss_cvvdp
            if key == 'fvvdp':
                loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
                loss_components['fvvdp'] = loss_fvvdp
                loss += loss_fvvdp
            if key == 'lpips':
                loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
                loss_components['lpips'] = loss_lpips
                loss += loss_lpips
            if key == 'psnr':
                loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
                loss_components['psnr'] = loss_psnr
                loss += loss_psnr
            if key == 'ssim':
                loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
                loss_components['ssim'] = loss_ssim
                loss += loss_ssim
            if key == 'msssim':
                loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
                loss_components['msssim'] = loss_msssim
                loss += loss_msssim
    if self.return_components:
        return loss, loss_components
    return loss

__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, multiplier=1.0, scheme='defocus', base_loss_weights={'base_l2_loss': 1.0, 'loss_l2_mask': 1.0, 'loss_l2_cor': 1.0, 'base_l1_loss': 1.0, 'loss_l1_mask': 1.0, 'loss_l1_cor': 1.0}, additional_loss_weights={'cvvdp': 1.0}, reduction='mean', return_components=False, device=torch.device('cpu'))

Parameters:

  • target_image
                        Color target image [3 x m x n].
    
  • target_depth
                        Monochrome target depth, same resolution as target_image.
    
  • target_blur_size
                        Maximum target blur size.
    
  • blur_ratio
                        Blur ratio, a value between zero and one.
    
  • number_of_planes
                        Number of planes.
    
  • multiplier
                        Multiplier to multiply with targets.
    
  • scheme
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    
  • base_loss_weights
                        Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
    
  • additional_loss_weights (dict, default: {'cvvdp': 1.0} ) –
                        Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
    
  • reduction
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    
  • return_components
                        If True (False by default), returns the components of the loss as a dict.
    
  • device
                        Device to be used (e.g., cuda, cpu, opencl).
    
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
             target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
             base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
             additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
    """
    Parameters
    ----------
    target_image            : torch.tensor
                                Color target image [3 x m x n].
    target_depth            : torch.tensor
                                Monochrome target depth, same resolution as target_image.
    target_blur_size        : int
                                Maximum target blur size.
    blur_ratio              : float
                                Blur ratio, a value between zero and one.
    number_of_planes        : int
                                Number of planes.
    multiplier              : float
                                Multiplier to multiply with targets.
    scheme                  : str
                                The type of the loss, `naive` without defocus or `defocus` with defocus.
    base_loss_weights       : dict
                                Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
    additional_loss_weights : dict
                                Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
    reduction               : str
                                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    return_components       : bool
                                If True (False by default), returns the components of the loss as a dict.
    device                  : torch.device
                                Device to be used (e.g., cuda, cpu, opencl).
    """
    self.device = device
    self.target_image     = target_image.float().to(self.device)
    self.target_depth     = target_depth.float().to(self.device)
    self.target_blur_size = target_blur_size
    if self.target_blur_size % 2 == 0:
        self.target_blur_size += 1
    self.number_of_planes = number_of_planes
    self.multiplier       = multiplier
    self.reduction        = reduction
    if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
        logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
        self.reduction = 'mean'
    self.blur_ratio       = blur_ratio
    self.set_targets()
    if scheme == 'defocus':
        self.add_defocus_blur()
    self.base_loss_weights = base_loss_weights
    self.additional_loss_weights = additional_loss_weights
    self.return_components = return_components
    self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
    self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
    for key in self.additional_loss_weights.keys():
        if self.additional_loss_weights[key]:
            if key == 'cvvdp':
                self.cvvdp = CVVDP(device = device)
            if key == 'fvvdp':
                self.fvvdp = FVVDP()
            if key == 'lpips':
                self.lpips = LPIPS()
            if key == 'psnr':
                self.psnr = PSNR()
            if key == 'ssim':
                self.ssim = SSIM()
            if key == 'msssim':
                self.msssim = MSSSIM()

add_defocus_blur()

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):
    """
    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
    """
    kernel_length = [self.target_blur_size, self.target_blur_size ]
    for ch in range(self.target_image.shape[0]):
        targets_cache = self.targets[:, ch].detach().clone()
        target = torch.sum(targets_cache, axis = 0)
        for i in range(self.number_of_planes):
            defocus = torch.zeros_like(targets_cache[i])
            for j in range(self.number_of_planes):
                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
                if torch.sum(targets_cache[j]) > 0:
                    if i == j:
                        nsigma = [0., 0.]
                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
                    kernel = kernel / torch.sum(kernel)
                    kernel = kernel.unsqueeze(0).unsqueeze(0)
                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
            self.targets[i, ch] = defocus
    self.targets = self.targets.detach().clone() * self.multiplier

get_targets()

Returns:

  • targets ( tensor ) –

    Returns a copy of the targets.

  • focus_target ( tensor ) –

    Returns a copy of the all-in-focus target composited from the plane targets.

  • target_depth ( tensor ) –

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):
    """
    Returns
    -------
    targets           : torch.tensor
                        Returns a copy of the targets.
    focus_target      : torch.tensor
                        Returns a copy of the all-in-focus target.
    target_depth      : torch.tensor
                        Returns a copy of the normalized quantized depth map.

    """
    divider = self.number_of_planes - 1
    if divider == 0:
        divider = 1
    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider

set_targets()

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):
    """
    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
    """
    self.target_depth = self.target_depth * (self.number_of_planes - 1)
    self.target_depth = torch.round(self.target_depth, decimals = 0)
    self.targets      = torch.zeros(
                                    self.number_of_planes,
                                    self.target_image.shape[0],
                                    self.target_image.shape[1],
                                    self.target_image.shape[2],
                                    requires_grad = False,
                                    device = self.device
                                   )
    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
    self.masks        = torch.zeros_like(self.targets)
    for i in range(self.number_of_planes):
        for ch in range(self.target_image.shape[0]):
            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
            new_target = self.target_image[ch] * mask
            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
            self.masks[i, ch] = mask.detach().clone() 

phase_gradient

Bases: Module

The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

This implements a convolution of the phase with a kernel.

The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.

Source code in odak/learn/wave/loss.py
class phase_gradient(nn.Module):

    """
    The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

    This implements a convolution of the phase with a kernel.

    The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.
    """


    def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
        """
        Parameters
        ----------
        kernel                  : torch.tensor
                                    Convolution filter kernel, 3 by 3 Laplacian kernel by default.
        loss                    : torch.nn.Module
                                    loss function, L2 Loss by default.
        """
        super(phase_gradient, self).__init__()
        self.device = device
        self.loss = loss
        if kernel is None:
            self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
        else:
            if len(kernel.shape) == 4:
                self.kernel = kernel
            else:
                self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
        self.kernel = Variable(self.kernel.to(self.device))


    def forward(self, phase):
        """
        Calculates the phase gradient Loss.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(phase.shape) == 2:
            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
        edge_detect = self.functional_conv2d(phase)
        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
        return loss_value


    def functional_conv2d(self, phase):
        """
        Calculates the gradient of the phase.

        Parameters
        ----------
        phase                  : torch.tensor
                                    Phase of the complex amplitude.

        Returns
        -------

        edge_detect              : torch.tensor
                                    The computed phase gradient.
        """
        edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
        return edge_detect
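
A short usage sketch, assuming the import path follows the source note above; the random phase map is purely illustrative:

import torch
from odak.learn.wave.loss import phase_gradient

phase = torch.rand(256, 256)          # illustrative, noisy phase map
regularizer = phase_gradient()
loss_value = regularizer(phase)       # large for rapidly varying phase, near zero for smooth phase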

__init__(kernel=None, loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel
                        Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    
  • loss
                        loss function, L2 Loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
    """
    Parameters
    ----------
    kernel                  : torch.tensor
                                Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    loss                    : torch.nn.Module
                                loss function, L2 Loss by default.
    """
    super(phase_gradient, self).__init__()
    self.device = device
    self.loss = loss
    if kernel is None:
        self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
    else:
        if len(kernel.shape) == 4:
            self.kernel = kernel
        else:
            self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
    self.kernel = Variable(self.kernel.to(self.device))

forward(phase)

Calculates the phase gradient Loss.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, phase):
    """
    Calculates the phase gradient Loss.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(phase.shape) == 2:
        phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
    edge_detect = self.functional_conv2d(phase)
    loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
    return loss_value

functional_conv2d(phase)

Calculates the gradient of the phase.

Parameters:

  • phase
                        Phase of the complex amplitude.
    

Returns:

  • edge_detect ( tensor ) –

    The computed phase gradient.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, phase):
    """
    Calculates the gradient of the phase.

    Parameters
    ----------
    phase                  : torch.tensor
                                Phase of the complex amplitude.

    Returns
    -------

    edge_detect              : torch.tensor
                                The computed phase gradient.
    """
    edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
    return edge_detect

speckle_contrast

Bases: Module

The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast and sigma and mean are the standard deviation and mean of the intensity.

We refer to the following paper:

Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. doi: 10.1038/s41598-020-75947-0.

Source code in odak/learn/wave/loss.py
class speckle_contrast(nn.Module):

    """
    The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast and sigma and mean are the standard deviation and mean of the intensity.

    We refer to the following paper:

    Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. doi: 10.1038/s41598-020-75947-0.
    """


    def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
        """
        Parameters
        ----------
        kernel_size             : int
                                    Convolution filter kernel size, 11 by 11 average kernel by default.
        step_size               : tuple
                                    Convolution stride in height and width direction.
        loss                    : torch.nn.Module
                                    loss function, L2 Loss by default.
        """
        super(speckle_contrast, self).__init__()
        self.device = device
        self.loss = loss
        self.step_size = step_size
        self.kernel_size = kernel_size
        self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
        self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))


    def forward(self, intensity):
        """
        Calculates the speckle contrast Loss.

        Parameters
        ----------
        intensity               : torch.tensor
                                    intensity of the complex amplitude.

        Returns
        -------

        loss_value              : torch.tensor
                                    The computed loss.
        """

        if len(intensity.shape) == 2:
            intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
        Speckle_C = self.functional_conv2d(intensity)
        loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
        return loss_value


    def functional_conv2d(self, intensity):
        """
        Calculates the speckle contrast of the intensity.

        Parameters
        ----------
        intensity                : torch.tensor
                                    Intensity of the complex field.

        Returns
        -------

        Speckle_C               : torch.tensor
                                    The computed speckle contrast.
        """
        mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
        var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
        Speckle_C = var / mean
        return Speckle_C
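
A short usage sketch, assuming the import path follows the source note above; the random intensity pattern is purely illustrative:

import torch
from odak.learn.wave.loss import speckle_contrast

intensity = torch.rand(512, 512)      # illustrative intensity pattern
regularizer = speckle_contrast(kernel_size = 11, step_size = (11, 11))
loss_value = regularizer(intensity)   # penalizes high local contrast (speckle)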

__init__(kernel_size=11, step_size=(1, 1), loss=nn.MSELoss(), device=torch.device('cpu'))

Parameters:

  • kernel_size
                        Convolution filter kernel size, 11 by 11 average kernel by default.
    
  • step_size
                        Convolution stride in height and width direction.
    
  • loss
                        loss function, L2 Loss by default.
    
Source code in odak/learn/wave/loss.py
def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
    """
    Parameters
    ----------
    kernel_size             : int
                                Convolution filter kernel size, 11 by 11 average kernel by default.
    step_size               : tuple
                                Convolution stride in height and width direction.
    loss                    : torch.nn.Module
                                loss function, L2 Loss by default.
    """
    super(speckle_contrast, self).__init__()
    self.device = device
    self.loss = loss
    self.step_size = step_size
    self.kernel_size = kernel_size
    self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
    self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))

forward(intensity)

Calculates the speckle contrast Loss.

Parameters:

  • intensity
                        intensity of the complex amplitude.
    

Returns:

  • loss_value ( tensor ) –

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, intensity):
    """
    Calculates the speckle contrast Loss.

    Parameters
    ----------
    intensity               : torch.tensor
                                intensity of the complex amplitude.

    Returns
    -------

    loss_value              : torch.tensor
                                The computed loss.
    """

    if len(intensity.shape) == 2:
        intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
    Speckle_C = self.functional_conv2d(intensity)
    loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
    return loss_value

functional_conv2d(intensity)

Calculates the speckle contrast of the intensity.

Parameters:

  • intensity
                        Intensity of the complex field.
    

Returns:

  • Speckle_C ( tensor ) –

    The computed speckle contrast.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, intensity):
    """
    Calculates the speckle contrast of the intensity.

    Parameters
    ----------
    intensity                : torch.tensor
                                Intensity of the complex field.

    Returns
    -------

    Speckle_C               : torch.tensor
                                The computed speckle contrast.
    """
    mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
    var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
    Speckle_C = var / mean
    return Speckle_C

channel_gate

Bases: Module

Channel attention module with various pooling strategies. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class channel_gate(torch.nn.Module):
    """
    Channel attention module with various pooling strategies.
    This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
    """
    def __init__(
                 self, 
                 gate_channels, 
                 reduction_ratio = 16, 
                 pool_types = ['avg', 'max']
                ):
        """
        Initializes the channel gate module.

        Parameters
        ----------
        gate_channels   : int
                          Number of channels of the input feature map.
        reduction_ratio : int
                          Reduction ratio for the intermediate layer.
        pool_types      : list
                          List of pooling operations to apply.
        """
        super().__init__()
        self.gate_channels = gate_channels
        hidden_channels = gate_channels // reduction_ratio
        if hidden_channels == 0:
            hidden_channels = 1
        self.mlp = torch.nn.Sequential(
                                       convolutional_block_attention.Flatten(),
                                       torch.nn.Linear(gate_channels, hidden_channels),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(hidden_channels, gate_channels)
                                      )
        self.pool_types = pool_types


    def forward(self, x):
        """
        Forward pass of the ChannelGate module.

        Applies channel-wise attention to the input tensor.

        Parameters
        ----------
        x            : torch.tensor
                       Input tensor to the ChannelGate module.

        Returns
        -------
        output       : torch.tensor
                       Output tensor after applying channel attention.
        """
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
            elif pool_type == 'max':
                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
            channel_att_raw = self.mlp(pool)
            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        output = x * scale
        return output
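
A short usage sketch, assuming the module path given in the source note above; the feature tensor is illustrative:

import torch
from odak.learn.models.components import channel_gate

features = torch.rand(1, 32, 64, 64)              # [batch x channels x height x width]
gate = channel_gate(gate_channels = 32, reduction_ratio = 8)
weighted = gate(features)                         # same shape, channels re-weighted by attention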

__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'])

Initializes the channel gate module.

Parameters:

  • gate_channels
              Number of channels of the input feature map.
    
  • reduction_ratio (int, default: 16 ) –
              Reduction ratio for the intermediate layer.
    
  • pool_types
              List of pooling operations to apply.
    
Source code in odak/learn/models/components.py
def __init__(
             self, 
             gate_channels, 
             reduction_ratio = 16, 
             pool_types = ['avg', 'max']
            ):
    """
    Initializes the channel gate module.

    Parameters
    ----------
    gate_channels   : int
                      Number of channels of the input feature map.
    reduction_ratio : int
                      Reduction ratio for the intermediate layer.
    pool_types      : list
                      List of pooling operations to apply.
    """
    super().__init__()
    self.gate_channels = gate_channels
    hidden_channels = gate_channels // reduction_ratio
    if hidden_channels == 0:
        hidden_channels = 1
    self.mlp = torch.nn.Sequential(
                                   convolutional_block_attention.Flatten(),
                                   torch.nn.Linear(gate_channels, hidden_channels),
                                   torch.nn.ReLU(),
                                   torch.nn.Linear(hidden_channels, gate_channels)
                                  )
    self.pool_types = pool_types

forward(x)

Forward pass of the ChannelGate module.

Applies channel-wise attention to the input tensor.

Parameters:

  • x
           Input tensor to the ChannelGate module.
    

Returns:

  • output ( tensor ) –

    Output tensor after applying channel attention.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward pass of the ChannelGate module.

    Applies channel-wise attention to the input tensor.

    Parameters
    ----------
    x            : torch.tensor
                   Input tensor to the ChannelGate module.

    Returns
    -------
    output       : torch.tensor
                   Output tensor after applying channel attention.
    """
    channel_att_sum = None
    for pool_type in self.pool_types:
        if pool_type == 'avg':
            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        elif pool_type == 'max':
            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        channel_att_raw = self.mlp(pool)
        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
    output = x * scale
    return output

convolution_layer

Bases: Module

A convolution layer.

Source code in odak/learn/models/components.py
class convolution_layer(torch.nn.Module):
    """
    A convolution layer.
    """
    def __init__(
                 self,
                 input_channels = 2,
                 output_channels = 2,
                 kernel_size = 3,
                 bias = False,
                 stride = 1,
                 normalization = False,
                 activation = torch.nn.ReLU()
                ):
        """
        A convolutional layer class.


        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Kernel size.
        stride          : int
                          Stride of the convolution.
        bias            : bool
                          Set to True to let convolutional layers have bias term.
        normalization   : bool
                          If True, adds a Batch Normalization layer after the convolutional layer.
        activation      : torch.nn
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        """
        super().__init__()
        layers = [
            torch.nn.Conv2d(
                            input_channels,
                            output_channels,
                            kernel_size = kernel_size,
                            stride = stride,
                            padding = kernel_size // 2,
                            bias = bias
                           )
        ]
        if normalization:
            layers.append(torch.nn.BatchNorm2d(output_channels))
        if activation:
            layers.append(activation)
        self.model = torch.nn.Sequential(*layers)


    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x             : torch.tensor
                        Input data.


        Returns
        ----------
        result        : torch.tensor
                        Estimated output.
        """
        result = self.model(x)
        return result
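
A short usage sketch, assuming the module path given in the source note above:

import torch
from odak.learn.models.components import convolution_layer

layer = convolution_layer(
                          input_channels = 3,
                          output_channels = 16,
                          kernel_size = 3,
                          normalization = True
                         )
x = torch.rand(1, 3, 64, 64)
y = layer(x)                                      # [1 x 16 x 64 x 64]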

__init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=False, activation=torch.nn.ReLU())

A convolutional layer class.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int, default: 2 ) –
              Number of output channels.
    
  • kernel_size
              Kernel size.
    
  • stride
              Stride of the convolution.
    
  • bias
              Set to True to let convolutional layers have bias term.
    
  • normalization
              If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels = 2,
             output_channels = 2,
             kernel_size = 3,
             bias = False,
             stride = 1,
             normalization = False,
             activation = torch.nn.ReLU()
            ):
    """
    A convolutional layer class.


    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Kernel size.
    stride          : int
                      Stride of the convolution.
    bias            : bool
                      Set to True to let convolutional layers have bias term.
    normalization   : bool
                      If True, adds a Batch Normalization layer after the convolutional layer.
    activation      : torch.nn
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    """
    super().__init__()
    layers = [
        torch.nn.Conv2d(
                        input_channels,
                        output_channels,
                        kernel_size = kernel_size,
                        stride = stride,
                        padding = kernel_size // 2,
                        bias = bias
                       )
    ]
    if normalization:
        layers.append(torch.nn.BatchNorm2d(output_channels))
    if activation:
        layers.append(activation)
    self.model = torch.nn.Sequential(*layers)

forward(x)

Forward model.

Parameters:

  • x
            Input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x             : torch.tensor
                    Input data.


    Returns
    ----------
    result        : torch.tensor
                    Estimated output.
    """
    result = self.model(x)
    return result

convolutional_block_attention

Bases: Module

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class convolutional_block_attention(torch.nn.Module):
    """
    Convolutional Block Attention Module (CBAM) class. 
    This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
    """
    def __init__(
                 self, 
                 gate_channels, 
                 reduction_ratio = 16, 
                 pool_types = ['avg', 'max'], 
                 no_spatial = False
                ):
        """
        Initializes the convolutional block attention module.

        Parameters
        ----------
        gate_channels   : int
                          Number of channels of the input feature map.
        reduction_ratio : int
                          Reduction ratio for the channel attention.
        pool_types      : list
                          List of pooling operations to apply for channel attention.
        no_spatial      : bool
                          If True, spatial attention is not applied.
        """
        super(convolutional_block_attention, self).__init__()
        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.spatial_gate = spatial_gate()


    class Flatten(torch.nn.Module):
        """
        Flattens the input tensor to a 2D matrix.
        """
        def forward(self, x):
            return x.view(x.size(0), -1)


    def forward(self, x):
        """
        Forward pass of the convolutional block attention module.

        Parameters
        ----------
        x            : torch.tensor
                       Input tensor to the CBAM module.

        Returns
        -------
        x_out        : torch.tensor
                       Output tensor after applying channel and spatial attention.
        """
        x_out = self.channel_gate(x)
        if not self.no_spatial:
            x_out = self.spatial_gate(x_out)
        return x_out
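
A short usage sketch, assuming the module path given in the source note above:

import torch
from odak.learn.models.components import convolutional_block_attention

cbam = convolutional_block_attention(gate_channels = 16)
features = torch.rand(4, 16, 32, 32)
refined = cbam(features)                          # channel attention followed by spatial attention, same shape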

Flatten

Bases: Module

Flattens the input tensor to a 2D matrix.

Source code in odak/learn/models/components.py
class Flatten(torch.nn.Module):
    """
    Flattens the input tensor to a 2D matrix.
    """
    def forward(self, x):
        return x.view(x.size(0), -1)

__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False)

Initializes the convolutional block attention module.

Parameters:

  • gate_channels
              Number of channels of the input feature map.
    
  • reduction_ratio (int, default: 16 ) –
              Reduction ratio for the channel attention.
    
  • pool_types
              List of pooling operations to apply for channel attention.
    
  • no_spatial
              If True, spatial attention is not applied.
    
Source code in odak/learn/models/components.py
def __init__(
             self, 
             gate_channels, 
             reduction_ratio = 16, 
             pool_types = ['avg', 'max'], 
             no_spatial = False
            ):
    """
    Initializes the convolutional block attention module.

    Parameters
    ----------
    gate_channels   : int
                      Number of channels of the input feature map.
    reduction_ratio : int
                      Reduction ratio for the channel attention.
    pool_types      : list
                      List of pooling operations to apply for channel attention.
    no_spatial      : bool
                      If True, spatial attention is not applied.
    """
    super(convolutional_block_attention, self).__init__()
    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
    self.no_spatial = no_spatial
    if not no_spatial:
        self.spatial_gate = spatial_gate()

forward(x)

Forward pass of the convolutional block attention module.

Parameters:

  • x
           Input tensor to the CBAM module.
    

Returns:

  • x_out ( tensor ) –

    Output tensor after applying channel and spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward pass of the convolutional block attention module.

    Parameters
    ----------
    x            : torch.tensor
                   Input tensor to the CBAM module.

    Returns
    -------
    x_out        : torch.tensor
                   Output tensor after applying channel and spatial attention.
    """
    x_out = self.channel_gate(x)
    if not self.no_spatial:
        x_out = self.spatial_gate(x_out)
    return x_out

double_convolution

Bases: Module

A double convolution layer.

Source code in odak/learn/models/components.py
class double_convolution(torch.nn.Module):
    """
    A double convolution layer.
    """
    def __init__(
                 self,
                 input_channels = 2,
                 mid_channels = None,
                 output_channels = 2,
                 kernel_size = 3, 
                 bias = False,
                 normalization = False,
                 activation = torch.nn.ReLU()
                ):
        """
        Double convolution model.


        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        mid_channels    : int
                          Number of channels in the hidden layer between two convolutions.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Kernel size.
        bias            : bool 
                          Set to True to let convolutional layers have bias term.
        normalization   : bool
                          If True, adds a Batch Normalization layer after the convolutional layer.
        activation      : torch.nn
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        """
        super().__init__()
        if isinstance(mid_channels, type(None)):
            mid_channels = output_channels
        self.activation = activation
        self.model = torch.nn.Sequential(
                                         convolution_layer(
                                                           input_channels = input_channels,
                                                           output_channels = mid_channels,
                                                           kernel_size = kernel_size,
                                                           bias = bias,
                                                           normalization = normalization,
                                                           activation = self.activation
                                                          ),
                                         convolution_layer(
                                                           input_channels = mid_channels,
                                                           output_channels = output_channels,
                                                           kernel_size = kernel_size,
                                                           bias = bias,
                                                           normalization = normalization,
                                                           activation = self.activation
                                                          )
                                        )


    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x             : torch.tensor
                        Input data.


        Returns
        ----------
        result        : torch.tensor
                        Estimated output.      
        """
        result = self.model(x)
        return result
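
A short usage sketch, assuming the module path given in the source note above:

import torch
from odak.learn.models.components import double_convolution

block = double_convolution(input_channels = 3, mid_channels = 8, output_channels = 16)
x = torch.rand(1, 3, 64, 64)
y = block(x)                                      # [1 x 16 x 64 x 64]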

__init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=False, activation=torch.nn.ReLU())

Double convolution model.

Parameters:

  • input_channels
              Number of input channels.
    
  • mid_channels
              Number of channels in the hidden layer between two convolutions.
    
  • output_channels (int, default: 2 ) –
              Number of output channels.
    
  • kernel_size
              Kernel size.
    
  • bias
              Set to True to let convolutional layers have a bias term.
    
  • normalization
              If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels = 2,
             mid_channels = None,
             output_channels = 2,
             kernel_size = 3, 
             bias = False,
             normalization = False,
             activation = torch.nn.ReLU()
            ):
    """
    Double convolution model.


    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    mid_channels    : int
                      Number of channels in the hidden layer between two convolutions.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Kernel size.
    bias            : bool 
                      Set to True to let convolutional layers have bias term.
    normalization   : bool
                      If True, adds a Batch Normalization layer after the convolutional layer.
    activation      : torch.nn
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    """
    super().__init__()
    if isinstance(mid_channels, type(None)):
        mid_channels = output_channels
    self.activation = activation
    self.model = torch.nn.Sequential(
                                     convolution_layer(
                                                       input_channels = input_channels,
                                                       output_channels = mid_channels,
                                                       kernel_size = kernel_size,
                                                       bias = bias,
                                                       normalization = normalization,
                                                       activation = self.activation
                                                      ),
                                     convolution_layer(
                                                       input_channels = mid_channels,
                                                       output_channels = output_channels,
                                                       kernel_size = kernel_size,
                                                       bias = bias,
                                                       normalization = normalization,
                                                       activation = self.activation
                                                      )
                                    )

forward(x)

Forward model.

Parameters:

  • x
            Input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x             : torch.tensor
                    Input data.


    Returns
    ----------
    result        : torch.tensor
                    Estimated output.      
    """
    result = self.model(x)
    return result

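A minimal usage sketch for double_convolution follows, assuming the class is importable from odak.learn.models (its source lives in odak/learn/models/components.py, as noted above) and that the underlying convolution_layer preserves spatial resolution through same-padding.

import torch
from odak.learn.models import double_convolution # assumed import path

layer = double_convolution(
                           input_channels = 2,
                           mid_channels = 8,
                           output_channels = 4,
                           kernel_size = 3,
                           normalization = True
                          )
x = torch.randn(1, 2, 32, 32) # [batch, channels, height, width]
y = layer(x)
print(y.shape)                # expected: torch.Size([1, 4, 32, 32]) if spatial size is preserved
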
downsample_layer

Bases: Module

A downscaling component followed by a double convolution.

Source code in odak/learn/models/components.py
class downsample_layer(torch.nn.Module):
    """
    A downscaling component followed by a double convolution.
    """
    def __init__(
                 self,
                 input_channels,
                 output_channels,
                 kernel_size = 3,
                 bias = False,
                 normalization = False,
                 activation = torch.nn.ReLU()
                ):
        """
        A downscaling component with a double convolution.

        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Kernel size.
        bias            : bool 
                          Set to True to let convolutional layers have bias term.
        normalization   : bool                
                          If True, adds a Batch Normalization layer after the convolutional layer.
        activation      : torch.nn
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        """
        super().__init__()
        self.maxpool_conv = torch.nn.Sequential(
                                                torch.nn.MaxPool2d(2),
                                                double_convolution(
                                                                   input_channels = input_channels,
                                                                   mid_channels = output_channels,
                                                                   output_channels = output_channels,
                                                                   kernel_size = kernel_size,
                                                                   bias = bias,
                                                                   normalization = normalization,
                                                                   activation = activation
                                                                  )
                                               )


    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x              : torch.tensor
                         First input data.



        Returns
        ----------
        result        : torch.tensor
                        Estimated output.      
        """
        result = self.maxpool_conv(x)
        return result

__init__(input_channels, output_channels, kernel_size=3, bias=False, normalization=False, activation=torch.nn.ReLU())

A downscaling component with a double convolution.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int) –
              Number of output channels.
    
  • kernel_size
              Kernel size.
    
  • bias
              Set to True to let convolutional layers have a bias term.
    
  • normalization
              If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels,
             output_channels,
             kernel_size = 3,
             bias = False,
             normalization = False,
             activation = torch.nn.ReLU()
            ):
    """
    A downscaling component with a double convolution.

    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Kernel size.
    bias            : bool 
                      Set to True to let convolutional layers have bias term.
    normalization   : bool                
                      If True, adds a Batch Normalization layer after the convolutional layer.
    activation      : torch.nn
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    """
    super().__init__()
    self.maxpool_conv = torch.nn.Sequential(
                                            torch.nn.MaxPool2d(2),
                                            double_convolution(
                                                               input_channels = input_channels,
                                                               mid_channels = output_channels,
                                                               output_channels = output_channels,
                                                               kernel_size = kernel_size,
                                                               bias = bias,
                                                               normalization = normalization,
                                                               activation = activation
                                                              )
                                           )

forward(x)

Forward model.

Parameters:

  • x
             First input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x              : torch.tensor
                     First input data.



    Returns
    ----------
    result        : torch.tensor
                    Estimated output.      
    """
    result = self.maxpool_conv(x)
    return result

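A minimal usage sketch for downsample_layer, under the same assumption that the class is importable from odak.learn.models; torch.nn.MaxPool2d(2) halves the spatial resolution before the double convolution.

import torch
from odak.learn.models import downsample_layer # assumed import path

layer = downsample_layer(input_channels = 8, output_channels = 16)
x = torch.randn(1, 8, 64, 64) # [batch, channels, height, width]
y = layer(x)
print(y.shape)                # expected: torch.Size([1, 16, 32, 32]) after 2x spatial downsampling
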
focal_surface_light_propagation

Bases: Module

focal_surface_light_propagation model.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December 2024.

Source code in odak/learn/wave/models.py
class focal_surface_light_propagation(torch.nn.Module):
    """
    focal_surface_light_propagation model.

    References
    ----------

    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December 2024.
    """
    def __init__(
                 self,
                 depth = 3,
                 dimensions = 8,
                 input_channels = 6,
                 out_channels = 6,
                 kernel_size = 3,
                 bias = True,
                 device = torch.device('cpu'),
                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
                ):
        """
        Initializes the focal surface light propagation model.

        Parameters
        ----------
        depth             : int
                            Number of downsampling and upsampling layers.
        dimensions        : int
                            Number of dimensions/features in the model.
        input_channels    : int
                            Number of input channels.
        out_channels      : int
                            Number of output channels.
        kernel_size       : int
                            Size of the convolution kernel.
        bias              : bool
                            If True, allows convolutional layers to learn a bias term.
        device            : torch.device
                            Default device is CPU.
        activation        : torch.nn.Module
                            Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
        """
        super().__init__()
        self.depth = depth
        self.device = device
        self.sv_kernel_generation = spatially_varying_kernel_generation_model(
            depth = depth,
            dimensions = dimensions,
            input_channels = input_channels + 1,  # +1 to account for an extra channel
            kernel_size = kernel_size,
            bias = bias,
            activation = activation
        )
        self.light_propagation = spatially_adaptive_unet(
            depth = depth,
            dimensions = dimensions,
            input_channels = input_channels,
            out_channels = out_channels,
            kernel_size = kernel_size,
            bias = bias,
            activation = activation
        )


    def forward(self, focal_surface, phase_only_hologram):
        """
        Forward pass through the model.

        Parameters
        ----------
        focal_surface         : torch.Tensor
                                Input focal surface.
        phase_only_hologram   : torch.Tensor
                                Input phase-only hologram.

        Returns
        ----------
        result                : torch.Tensor
                                Output tensor after light propagation.
        """
        input_field = self.generate_input_field(phase_only_hologram)
        sv_kernel = self.sv_kernel_generation(focal_surface, input_field)
        output_field = self.light_propagation(sv_kernel, input_field)
        final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :])
        result = calculate_amplitude(final) ** 2
        return result


    def generate_input_field(self, phase_only_hologram):
        """
        Generates an input field by combining the real and imaginary parts.

        Parameters
        ----------
        phase_only_hologram   : torch.Tensor
                                Input phase-only hologram.

        Returns
        ----------
        input_field           : torch.Tensor
                                Concatenated real and imaginary parts of the complex field.
        """
        [b, c, h, w] = phase_only_hologram.size()
        input_phase = phase_only_hologram * 2 * np.pi
        hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False).to(self.device)
        field = generate_complex_field(hologram_amplitude, input_phase)
        input_field = torch.cat((field.real, field.imag), dim = 1)
        return input_field


    def load_weights(self, weight_filename, key_mapping_filename):
        """
        Function to load weights for this multi-layer perceptron from a file.

        Parameters
        ----------
        weight_filename      : str
                               Path to the old model's weight file.
        key_mapping_filename : str
                               Path to the JSON file containing the key mappings.
        """
        # Load old model weights
        old_model_weights = torch.load(weight_filename, map_location = self.device,weights_only=True)

        # Load key mappings from JSON file
        with open(key_mapping_filename, 'r') as json_file:
            key_mappings = json.load(json_file)

        # Extract the key mappings for sv_kernel_generation and light_prop
        sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping']
        light_prop_key_mapping = key_mappings['light_prop_key_mapping']

        # Initialize new state dicts
        sv_kernel_generation_new_state_dict = {}
        light_prop_new_state_dict = {}

        # Map and load sv_kernel_generation_model weights
        for old_key, value in old_model_weights.items():
            if old_key in sv_kernel_generation_key_mapping:
                # Map the old key to the new key
                new_key = sv_kernel_generation_key_mapping[old_key]
                sv_kernel_generation_new_state_dict[new_key] = value

        self.sv_kernel_generation.to(self.device)
        self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict)

        # Map and load light_prop model weights
        for old_key, value in old_model_weights.items():
            if old_key in light_prop_key_mapping:
                # Map the old key to the new key
                new_key = light_prop_key_mapping[old_key]
                light_prop_new_state_dict[new_key] = value
        self.light_propagation.to(self.device)
        self.light_propagation.load_state_dict(light_prop_new_state_dict)

__init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, device=torch.device('cpu'), activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes the focal surface light propagation model.

Parameters:

  • depth
                Number of downsampling and upsampling layers.
    
  • dimensions
                Number of dimensions/features in the model.
    
  • input_channels
                Number of input channels.
    
  • out_channels
                Number of output channels.
    
  • kernel_size
                Size of the convolution kernel.
    
  • bias
                If True, allows convolutional layers to learn a bias term.
    
  • device
                Default device is CPU.
    
  • activation
                Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    
Source code in odak/learn/wave/models.py
def __init__(
             self,
             depth = 3,
             dimensions = 8,
             input_channels = 6,
             out_channels = 6,
             kernel_size = 3,
             bias = True,
             device = torch.device('cpu'),
             activation = torch.nn.LeakyReLU(0.2, inplace = True)
            ):
    """
    Initializes the focal surface light propagation model.

    Parameters
    ----------
    depth             : int
                        Number of downsampling and upsampling layers.
    dimensions        : int
                        Number of dimensions/features in the model.
    input_channels    : int
                        Number of input channels.
    out_channels      : int
                        Number of output channels.
    kernel_size       : int
                        Size of the convolution kernel.
    bias              : bool
                        If True, allows convolutional layers to learn a bias term.
    device            : torch.device
                        Default device is CPU.
    activation        : torch.nn.Module
                        Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    """
    super().__init__()
    self.depth = depth
    self.device = device
    self.sv_kernel_generation = spatially_varying_kernel_generation_model(
        depth = depth,
        dimensions = dimensions,
        input_channels = input_channels + 1,  # +1 to account for an extra channel
        kernel_size = kernel_size,
        bias = bias,
        activation = activation
    )
    self.light_propagation = spatially_adaptive_unet(
        depth = depth,
        dimensions = dimensions,
        input_channels = input_channels,
        out_channels = out_channels,
        kernel_size = kernel_size,
        bias = bias,
        activation = activation
    )

forward(focal_surface, phase_only_hologram)

Forward pass through the model.

Parameters:

  • focal_surface
                    Input focal surface.
    
  • phase_only_hologram
                    Input phase-only hologram.
    

Returns:

  • result ( Tensor ) –

    Output tensor after light propagation.

Source code in odak/learn/wave/models.py
def forward(self, focal_surface, phase_only_hologram):
    """
    Forward pass through the model.

    Parameters
    ----------
    focal_surface         : torch.Tensor
                            Input focal surface.
    phase_only_hologram   : torch.Tensor
                            Input phase-only hologram.

    Returns
    ----------
    result                : torch.Tensor
                            Output tensor after light propagation.
    """
    input_field = self.generate_input_field(phase_only_hologram)
    sv_kernel = self.sv_kernel_generation(focal_surface, input_field)
    output_field = self.light_propagation(sv_kernel, input_field)
    final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :])
    result = calculate_amplitude(final) ** 2
    return result

generate_input_field(phase_only_hologram)

Generates an input field by combining the real and imaginary parts.

Parameters:

  • phase_only_hologram
                    Input phase-only hologram.
    

Returns:

  • input_field ( Tensor ) –

    Concatenated real and imaginary parts of the complex field.

Source code in odak/learn/wave/models.py
def generate_input_field(self, phase_only_hologram):
    """
    Generates an input field by combining the real and imaginary parts.

    Parameters
    ----------
    phase_only_hologram   : torch.Tensor
                            Input phase-only hologram.

    Returns
    ----------
    input_field           : torch.Tensor
                            Concatenated real and imaginary parts of the complex field.
    """
    [b, c, h, w] = phase_only_hologram.size()
    input_phase = phase_only_hologram * 2 * np.pi
    hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False).to(self.device)
    field = generate_complex_field(hologram_amplitude, input_phase)
    input_field = torch.cat((field.real, field.imag), dim = 1)
    return input_field

load_weights(weight_filename, key_mapping_filename)

Function to load weights for this model from a file.

Parameters:

  • weight_filename
                   Path to the old model's weight file.
    
  • key_mapping_filename (str) –
                   Path to the JSON file containing the key mappings.
    
Source code in odak/learn/wave/models.py
def load_weights(self, weight_filename, key_mapping_filename):
    """
    Function to load weights for this multi-layer perceptron from a file.

    Parameters
    ----------
    weight_filename      : str
                           Path to the old model's weight file.
    key_mapping_filename : str
                           Path to the JSON file containing the key mappings.
    """
    # Load old model weights
    old_model_weights = torch.load(weight_filename, map_location = self.device,weights_only=True)

    # Load key mappings from JSON file
    with open(key_mapping_filename, 'r') as json_file:
        key_mappings = json.load(json_file)

    # Extract the key mappings for sv_kernel_generation and light_prop
    sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping']
    light_prop_key_mapping = key_mappings['light_prop_key_mapping']

    # Initialize new state dicts
    sv_kernel_generation_new_state_dict = {}
    light_prop_new_state_dict = {}

    # Map and load sv_kernel_generation_model weights
    for old_key, value in old_model_weights.items():
        if old_key in sv_kernel_generation_key_mapping:
            # Map the old key to the new key
            new_key = sv_kernel_generation_key_mapping[old_key]
            sv_kernel_generation_new_state_dict[new_key] = value

    self.sv_kernel_generation.to(self.device)
    self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict)

    # Map and load light_prop model weights
    for old_key, value in old_model_weights.items():
        if old_key in light_prop_key_mapping:
            # Map the old key to the new key
            new_key = light_prop_key_mapping[old_key]
            light_prop_new_state_dict[new_key] = value
    self.light_propagation.to(self.device)
    self.light_propagation.load_state_dict(light_prop_new_state_dict)

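The sketch below is a shape check with a randomly initialized model rather than a meaningful reconstruction. It assumes the class is importable from odak.learn.wave (its source lives in odak/learn/wave/models.py), a single-channel focal surface, a three-channel phase-only hologram normalized to [0, 1], and a resolution divisible by 2 ** depth; in practice, load_weights would first be called with pretrained weights and their key mapping file.

import torch
from odak.learn.wave import focal_surface_light_propagation # assumed import path

device = torch.device('cpu')
model = focal_surface_light_propagation(device = device)
focal_surface = torch.rand(1, 1, 256, 256, device = device) # hypothetical single-channel focal surface
hologram = torch.rand(1, 3, 256, 256, device = device)      # hypothetical phase-only hologram in [0, 1]
with torch.no_grad():
    reconstruction = model(focal_surface, hologram)
print(reconstruction.shape)                                  # expected: torch.Size([1, 3, 256, 256])
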
global_feature_module

Bases: Module

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

Source code in odak/learn/models/components.py
class global_feature_module(torch.nn.Module):
    """
    A global feature layer that processes global features from input channels and
    applies them to another input tensor via learned transformations.
    """
    def __init__(
                 self,
                 input_channels,
                 mid_channels,
                 output_channels,
                 kernel_size,
                 bias = False,
                 normalization = False,
                 activation = torch.nn.ReLU()
                ):
        """
        A global feature layer.

        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        mid_channels  : int
                          Number of mid channels.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Kernel size.
        bias            : bool
                          Set to True to let convolutional layers have bias term.
        normalization   : bool
                          If True, adds a Batch Normalization layer after the convolutional layer.
        activation      : torch.nn
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        """
        super().__init__()
        self.transformations_1 = global_transformations(input_channels, output_channels)
        self.global_features_1 = double_convolution(
                                                    input_channels = input_channels,
                                                    mid_channels = mid_channels,
                                                    output_channels = output_channels,
                                                    kernel_size = kernel_size,
                                                    bias = bias,
                                                    normalization = normalization,
                                                    activation = activation
                                                   )
        self.global_features_2 = double_convolution(
                                                    input_channels = input_channels,
                                                    mid_channels = mid_channels,
                                                    output_channels = output_channels,
                                                    kernel_size = kernel_size,
                                                    bias = bias,
                                                    normalization = normalization,
                                                    activation = activation
                                                   )
        self.transformations_2 = global_transformations(input_channels, output_channels)


    def forward(self, x1, x2):
        """
        Forward model.

        Parameters
        ----------
        x1             : torch.tensor
                         First input data.
        x2             : torch.tensor
                         Second input data.

        Returns
        ----------
        result        : torch.tensor
                        Estimated output.
        """
        global_tensor_1 = self.transformations_1(x1, x2)
        y1 = self.global_features_1(global_tensor_1)
        y2 = self.global_features_2(y1)
        global_tensor_2 = self.transformations_2(y1, y2)
        return global_tensor_2

__init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU())

A global feature layer.

Parameters:

  • input_channels
              Number of input channels.
    
  • mid_channels
              Number of mid channels.
    
  • output_channels (int) –
              Number of output channels.
    
  • kernel_size
              Kernel size.
    
  • bias
              Set to True to let convolutional layers have a bias term.
    
  • normalization
              If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels,
             mid_channels,
             output_channels,
             kernel_size,
             bias = False,
             normalization = False,
             activation = torch.nn.ReLU()
            ):
    """
    A global feature layer.

    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    mid_channels  : int
                      Number of mid channels.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Kernel size.
    bias            : bool
                      Set to True to let convolutional layers have bias term.
    normalization   : bool
                      If True, adds a Batch Normalization layer after the convolutional layer.
    activation      : torch.nn
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    """
    super().__init__()
    self.transformations_1 = global_transformations(input_channels, output_channels)
    self.global_features_1 = double_convolution(
                                                input_channels = input_channels,
                                                mid_channels = mid_channels,
                                                output_channels = output_channels,
                                                kernel_size = kernel_size,
                                                bias = bias,
                                                normalization = normalization,
                                                activation = activation
                                               )
    self.global_features_2 = double_convolution(
                                                input_channels = input_channels,
                                                mid_channels = mid_channels,
                                                output_channels = output_channels,
                                                kernel_size = kernel_size,
                                                bias = bias,
                                                normalization = normalization,
                                                activation = activation
                                               )
    self.transformations_2 = global_transformations(input_channels, output_channels)

forward(x1, x2)

Forward model.

Parameters:

  • x1
             First input data.
    
  • x2
             Second input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
    """
    Forward model.

    Parameters
    ----------
    x1             : torch.tensor
                     First input data.
    x2             : torch.tensor
                     Second input data.

    Returns
    ----------
    result        : torch.tensor
                    Estimated output.
    """
    global_tensor_1 = self.transformations_1(x1, x2)
    y1 = self.global_features_1(global_tensor_1)
    y2 = self.global_features_2(y1)
    global_tensor_2 = self.transformations_2(y1, y2)
    return global_tensor_2

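A minimal usage sketch, assuming the class is importable from odak.learn.models. Since the module modulates x1 with statistics pooled from x2 and feeds the result back through double convolutions declared with input_channels, the sketch keeps the input, mid and output channel counts equal.

import torch
from odak.learn.models import global_feature_module # assumed import path

layer = global_feature_module(
                              input_channels = 8,
                              mid_channels = 8,
                              output_channels = 8,
                              kernel_size = 3
                             )
x1 = torch.randn(1, 8, 32, 32) # tensor to be modulated
x2 = torch.randn(1, 8, 32, 32) # tensor providing the pooled global statistics
y = layer(x1, x2)
print(y.shape)                 # expected: torch.Size([1, 8, 32, 32])
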
global_transformations

Bases: Module

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

This implementation is adapted from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Reference: J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."

Source code in odak/learn/models/components.py
class global_transformations(torch.nn.Module):
    """
    A global feature layer that processes global features from input channels and
    applies learned transformations to another input tensor.

    This implementation is adapted from RSGUnet:
    https://github.com/MTLab/rsgunet_image_enhance.

    Reference:
    J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."
    """
    def __init__(
                 self,
                 input_channels,
                 output_channels
                ):
        """
        A global feature layer.

        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        output_channels : int
                          Number of output channels.
        """
        super().__init__()
        self.global_feature_1 = torch.nn.Sequential(
            torch.nn.Linear(input_channels, output_channels),
            torch.nn.LeakyReLU(0.2, inplace = True),
        )
        self.global_feature_2 = torch.nn.Sequential(
            torch.nn.Linear(output_channels, output_channels),
            torch.nn.LeakyReLU(0.2, inplace = True)
        )


    def forward(self, x1, x2):
        """
        Forward model.

        Parameters
        ----------
        x1             : torch.tensor
                         First input data.
        x2             : torch.tensor
                         Second input data.

        Returns
        ----------
        result        : torch.tensor
                        Estimated output.
        """
        y = torch.mean(x2, dim = (2, 3))
        y1 = self.global_feature_1(y)
        y2 = self.global_feature_2(y1)
        y1 = y1.unsqueeze(2).unsqueeze(3)
        y2 = y2.unsqueeze(2).unsqueeze(3)
        result = x1 * y1 + y2
        return result

__init__(input_channels, output_channels)

A global feature layer.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int) –
              Number of output channels.
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels,
             output_channels
            ):
    """
    A global feature layer.

    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    output_channels : int
                      Number of output channels.
    """
    super().__init__()
    self.global_feature_1 = torch.nn.Sequential(
        torch.nn.Linear(input_channels, output_channels),
        torch.nn.LeakyReLU(0.2, inplace = True),
    )
    self.global_feature_2 = torch.nn.Sequential(
        torch.nn.Linear(output_channels, output_channels),
        torch.nn.LeakyReLU(0.2, inplace = True)
    )

forward(x1, x2)

Forward model.

Parameters:

  • x1
             First input data.
    
  • x2
             Second input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
    """
    Forward model.

    Parameters
    ----------
    x1             : torch.tensor
                     First input data.
    x2             : torch.tensor
                     Second input data.

    Returns
    ----------
    result        : torch.tensor
                    Estimated output.
    """
    y = torch.mean(x2, dim = (2, 3))
    y1 = self.global_feature_1(y)
    y2 = self.global_feature_2(y1)
    y1 = y1.unsqueeze(2).unsqueeze(3)
    y2 = y2.unsqueeze(2).unsqueeze(3)
    result = x1 * y1 + y2
    return result

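A minimal usage sketch, assuming the class is importable from odak.learn.models. x2 is average-pooled over its spatial dimensions and passed through the two linear layers to produce a per-channel scale y1 and shift y2, which are applied to x1 as x1 * y1 + y2; x2 therefore carries input_channels channels and x1 carries output_channels channels.

import torch
from odak.learn.models import global_transformations # assumed import path

layer = global_transformations(input_channels = 16, output_channels = 8)
x1 = torch.randn(1, 8, 32, 32)  # receives the learned per-channel scale and shift
x2 = torch.randn(1, 16, 32, 32) # drives the transformation through global average pooling
y = layer(x1, x2)
print(y.shape)                  # expected: torch.Size([1, 8, 32, 32])
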
holobeam_multiholo

Bases: Module

The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.

Parameters:

  • n_input
                Number of channels in the input.
    
  • n_hidden
                Number of channels in the hidden layers.
    
  • n_output
                Number of channels in the output layer.
    
  • device
                Default device is CPU.
    
  • reduction
                Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
    
Source code in odak/learn/wave/models.py
class holobeam_multiholo(torch.nn.Module):
    """
    The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.


    Parameters
    ----------
    n_input           : int
                        Number of channels in the input.
    n_hidden          : int
                        Number of channels in the hidden layers.
    n_output          : int
                        Number of channels in the output layer.
    device            : torch.device
                        Default device is CPU.
    reduction         : str
                        Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
    """
    def __init__(
                 self,
                 n_input = 1,
                 n_hidden = 16,
                 n_output = 2,
                 device = torch.device('cpu'),
                 reduction = 'sum'
                ):
        super(holobeam_multiholo, self).__init__()
        torch.random.seed()
        self.device = device
        self.reduction = reduction
        self.l2 = torch.nn.MSELoss(reduction = self.reduction)
        self.l1 = torch.nn.L1Loss(reduction = self.reduction)
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_output = n_output
        self.network = unet(
                            dimensions = self.n_hidden,
                            input_channels = self.n_input,
                            output_channels = self.n_output
                           ).to(self.device)


    def forward(self, x, test = False):
        """
        Internal function representing the forward model.
        """
        if test:
            torch.no_grad()
        y = self.network.forward(x) 
        phase_low = y[:, 0].unsqueeze(1)
        phase_high = y[:, 1].unsqueeze(1)
        phase_only = torch.zeros_like(phase_low)
        phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
        phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
        phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
        phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
        return phase_only


    def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
        """
        Internal function for evaluating.
        """
        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
        return loss


    def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
        """
        Function to train the weights of the multi layer perceptron.

        Parameters
        ----------
        dataloader       : torch.utils.data.DataLoader
                           Data loader.
        number_of_epochs : int
                           Number of epochs.
        learning_rate    : float
                           Learning rate of the optimizer.
        directory        : str
                           Output directory.
        save_at_every    : int
                           Save the model at every given epoch count.
        """
        t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        for i in t_epoch:
            epoch_loss = 0.
            t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
            for j, data in enumerate(t_data):
                self.optimizer.zero_grad()
                images, holograms = data
                estimates = self.forward(images)
                loss = self.evaluate(estimates, holograms)
                loss.backward(retain_graph=True)
                self.optimizer.step()
                description = 'Loss:{:.4f}'.format(loss.item())
                t_data.set_description(description)
                epoch_loss += float(loss.item()) / dataloader.__len__()
            description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
            t_epoch.set_description(description)
            if i % save_at_every == 0:
                self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
        self.save_weights(filename='{}/weights.pt'.format(directory))
        print(description)


    def save_weights(self, filename = './weights.pt'):
        """
        Function to save the current weights of the multi layer perceptron to a file.
        Parameters
        ----------
        filename        : str
                          Filename.
        """
        torch.save(self.network.state_dict(), os.path.expanduser(filename))


    def load_weights(self, filename = './weights.pt'):
        """
        Function to load weights for this multi layer perceptron from a file.
        Parameters
        ----------
        filename        : str
                          Filename.
        """
        self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
        self.network.eval()

evaluate(input_data, ground_truth, weights=[1.0, 0.1])

Internal function for evaluating.

Source code in odak/learn/wave/models.py
def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
    """
    Internal function for evaluating.
    """
    loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
    return loss

fit(dataloader, number_of_epochs=100, learning_rate=1e-05, directory='./output', save_at_every=100)

Function to train the weights of the model.

Parameters:

  • dataloader
               Data loader.
    
  • number_of_epochs (int, default: 100 ) –
               Number of epochs.
    
  • learning_rate
               Learning rate of the optimizer.
    
  • directory
               Output directory.
    
  • save_at_every
               Save the model at every given epoch count.
    
Source code in odak/learn/wave/models.py
def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
    """
    Function to train the weights of the multi layer perceptron.

    Parameters
    ----------
    dataloader       : torch.utils.data.DataLoader
                       Data loader.
    number_of_epochs : int
                       Number of epochs.
    learning_rate    : float
                       Learning rate of the optimizer.
    directory        : str
                       Output directory.
    save_at_every    : int
                       Save the model at every given epoch count.
    """
    t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
    for i in t_epoch:
        epoch_loss = 0.
        t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
        for j, data in enumerate(t_data):
            self.optimizer.zero_grad()
            images, holograms = data
            estimates = self.forward(images)
            loss = self.evaluate(estimates, holograms)
            loss.backward(retain_graph=True)
            self.optimizer.step()
            description = 'Loss:{:.4f}'.format(loss.item())
            t_data.set_description(description)
            epoch_loss += float(loss.item()) / dataloader.__len__()
        description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
        t_epoch.set_description(description)
        if i % save_at_every == 0:
            self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
    self.save_weights(filename='{}/weights.pt'.format(directory))
    print(description)

forward(x, test=False)

Internal function representing the forward model.

Source code in odak/learn/wave/models.py
def forward(self, x, test = False):
    """
    Internal function representing the forward model.
    """
    if test:
        torch.no_grad()
    y = self.network.forward(x) 
    phase_low = y[:, 0].unsqueeze(1)
    phase_high = y[:, 1].unsqueeze(1)
    phase_only = torch.zeros_like(phase_low)
    phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
    phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
    phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
    phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
    return phase_only

load_weights(filename='./weights.pt')

Function to load weights for this model from a file.

Parameters:

  • filename
              Filename.
    
Source code in odak/learn/wave/models.py
def load_weights(self, filename = './weights.pt'):
    """
    Function to load weights for this multi layer perceptron from a file.
    Parameters
    ----------
    filename        : str
                      Filename.
    """
    self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
    self.network.eval()

save_weights(filename='./weights.pt')

Function to save the current weights of the model to a file.

Parameters:

  • filename
              Filename.
    
Source code in odak/learn/wave/models.py
def save_weights(self, filename = './weights.pt'):
    """
    Function to save the current weights of the multi layer perceptron to a file.
    Parameters
    ----------
    filename        : str
                      Filename.
    """
    torch.save(self.network.state_dict(), os.path.expanduser(filename))

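A minimal usage sketch with a randomly initialized network, assuming the class is importable from odak.learn.wave and that the input resolution suits the underlying unet (e.g., divisible by its number of downsampling steps). The two output channels of the unet are interleaved into a single phase-only hologram channel in a checkerboard pattern; for training, fit runs Adam with the combined L2 and L1 loss defined in evaluate and periodically saves checkpoints to the given directory.

import torch
from odak.learn.wave import holobeam_multiholo # assumed import path

model = holobeam_multiholo(n_input = 1, n_hidden = 16, n_output = 2)
image = torch.rand(1, 1, 128, 128) # single-channel target image
with torch.no_grad():              # inference only, so gradients are not tracked
    phase_only_hologram = model(image)
print(phase_only_hologram.shape)   # expected: torch.Size([1, 1, 128, 128])
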
multi_layer_perceptron

Bases: Module

A multi-layer perceptron model.

Source code in odak/learn/models/models.py
class multi_layer_perceptron(torch.nn.Module):
    """
    A multi-layer perceptron model.
    """

    def __init__(self,
                 dimensions,
                 activation = torch.nn.ReLU(),
                 bias = False,
                 model_type = 'conventional',
                 siren_multiplier = 1.,
                 input_multiplier = None
                ):
        """
        Parameters
        ----------
        dimensions        : list
                            List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
        activation        : torch.nn
                            Nonlinear activation function.
                            Default is `torch.nn.ReLU()`.
        bias              : bool
                            If set to True, linear layers will include biases.
        siren_multiplier  : float
                            When using `SIREN` model type, this parameter functions as a hyperparameter.
                            The original SIREN work uses 30.
                            You can bypass this parameter by providing inputs that are not normalized and larger than one.
        input_multiplier  : float
                            Initial value of the input multiplier before the very first layer.
        model_type        : str
                            Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
                            `conventional` refers to a standard multi layer perceptron.
                            For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
                            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
                            For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
                            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
        """
        super(multi_layer_perceptron, self).__init__()
        self.activation = activation
        self.bias = bias
        self.model_type = model_type
        self.layers = torch.nn.ModuleList()
        self.siren_multiplier = siren_multiplier
        self.dimensions = dimensions
        for i in range(len(self.dimensions) - 1):
            self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
        if not isinstance(input_multiplier, type(None)):
            self.input_multiplier = torch.nn.ParameterList()
            self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
        if self.model_type == 'FILM SIREN':
            self.alpha = torch.nn.ParameterList()
            for j in self.dimensions[1::]:
                self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
        if self.model_type == 'Gaussian':
            self.alpha = torch.nn.ParameterList()
            for j in self.dimensions[1::]:
                self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))


    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x             : torch.tensor
                        Input data.


        Returns
        ----------
        result        : torch.tensor
                        Estimated output.      
        """
        if hasattr(self, 'input_multiplier'):
            result = x * self.input_multiplier[0]
        else:
            result = x
        for layer_id, layer in enumerate(self.layers):
            result = layer(result)
            if self.model_type == 'conventional' and layer_id != len(self.layers) -1:
                result = self.activation(result)
            elif self.model_type == 'swish' and layer_id != len(self.layers) - 1:
                result = swish(result)
            elif self.model_type == 'SIREN' and layer_id != len(self.layers) - 1:
                result = torch.sin(result * self.siren_multiplier)
            elif self.model_type == 'FILM SIREN' and layer_id != len(self.layers) - 1:
                result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
            elif self.model_type == 'Gaussian' and layer_id != len(self.layers) - 1: 
                result = gaussian(result, self.alpha[layer_id][0])
        return result

__init__(dimensions, activation=torch.nn.ReLU(), bias=False, model_type='conventional', siren_multiplier=1.0, input_multiplier=None)

Parameters:

  • dimensions
                List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
    
  • activation
                Nonlinear activation function.
                Default is `torch.nn.ReLU()`.
    
  • bias
                If set to True, linear layers will include biases.
    
  • siren_multiplier
                When using `SIREN` model type, this parameter functions as a hyperparameter.
                The original SIREN work uses 30.
                You can bypass this parameter by providing inputs that are not normalized and larger than one.
    
  • input_multiplier
                Initial value of the input multiplier before the very first layer.
    
  • model_type
                Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
                `conventional` refers to a standard multi layer perceptron.
                For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
                For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
                For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
                For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
    
Source code in odak/learn/models/models.py
def __init__(self,
             dimensions,
             activation = torch.nn.ReLU(),
             bias = False,
             model_type = 'conventional',
             siren_multiplier = 1.,
             input_multiplier = None
            ):
    """
    Parameters
    ----------
    dimensions        : list
                        List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
    activation        : torch.nn
                        Nonlinear activation function.
                        Default is `torch.nn.ReLU()`.
    bias              : bool
                        If set to True, linear layers will include biases.
    siren_multiplier  : float
                        When using `SIREN` model type, this parameter functions as a hyperparameter.
                        The original SIREN work uses 30.
                        You can bypass this parameter by providing inputs that are not normalized and larger than one.
    input_multiplier  : float
                        Initial value of the input multiplier before the very first layer.
    model_type        : str
                        Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
                        `conventional` refers to a standard multi layer perceptron.
                        For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
                        For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
                        For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
                        For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
    """
    super(multi_layer_perceptron, self).__init__()
    self.activation = activation
    self.bias = bias
    self.model_type = model_type
    self.layers = torch.nn.ModuleList()
    self.siren_multiplier = siren_multiplier
    self.dimensions = dimensions
    for i in range(len(self.dimensions) - 1):
        self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
    if not isinstance(input_multiplier, type(None)):
        self.input_multiplier = torch.nn.ParameterList()
        self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
    if self.model_type == 'FILM SIREN':
        self.alpha = torch.nn.ParameterList()
        for j in self.dimensions[1::]:
            self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
    if self.model_type == 'Gaussian':
        self.alpha = torch.nn.ParameterList()
        for j in self.dimensions[1::]:
            self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))

forward(x)

Forward model.

Parameters:

  • x
            Input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x             : torch.tensor
                    Input data.


    Returns
    ----------
    result        : torch.tensor
                    Estimated output.      
    """
    if hasattr(self, 'input_multiplier'):
        result = x * self.input_multiplier[0]
    else:
        result = x
    for layer_id, layer in enumerate(self.layers):
        result = layer(result)
        if self.model_type == 'conventional' and layer_id != len(self.layers) -1:
            result = self.activation(result)
        elif self.model_type == 'swish' and layer_id != len(self.layers) - 1:
            result = swish(result)
        elif self.model_type == 'SIREN' and layer_id != len(self.layers) - 1:
            result = torch.sin(result * self.siren_multiplier)
        elif self.model_type == 'FILM SIREN' and layer_id != len(self.layers) - 1:
            result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
        elif self.model_type == 'Gaussian' and layer_id != len(self.layers) - 1: 
            result = gaussian(result, self.alpha[layer_id][0])
    return result

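A minimal usage sketch, assuming the class is importable from odak.learn.models. It builds a small SIREN-style coordinate network; dimensions sets the width of every linear layer and model_type selects the nonlinearity applied between layers, while the final layer stays linear.

import torch
from odak.learn.models import multi_layer_perceptron # assumed import path

model = multi_layer_perceptron(
                               dimensions = [2, 64, 64, 1],
                               model_type = 'SIREN',
                               siren_multiplier = 30.,
                               bias = True
                              )
coordinates = torch.rand(1024, 2) # 1024 query points in [0, 1] x [0, 1]
values = model(coordinates)
print(values.shape)               # expected: torch.Size([1024, 1])
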
non_local_layer

Bases: Module

Self-Attention Layer [z_i = W_z y_i + x_i] (non-local block; see https://arxiv.org/abs/1711.07971)

Source code in odak/learn/models/components.py
class non_local_layer(torch.nn.Module):
    """
    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)
    """
    def __init__(
                 self,
                 input_channels = 1024,
                 bottleneck_channels = 512,
                 kernel_size = 1,
                 bias = False,
                ):
        """

        Parameters
        ----------
        input_channels      : int
                              Number of input channels.
        bottleneck_channels : int
                              Number of middle channels.
        kernel_size         : int
                              Kernel size.
        bias                : bool 
                              Set to True to let convolutional layers have bias term.
        """
        super(non_local_layer, self).__init__()
        self.input_channels = input_channels
        self.bottleneck_channels = bottleneck_channels
        self.g = torch.nn.Conv2d(
                                 self.input_channels, 
                                 self.bottleneck_channels,
                                 kernel_size = kernel_size,
                                 padding = kernel_size // 2,
                                 bias = bias
                                )
        self.W_z = torch.nn.Sequential(
                                       torch.nn.Conv2d(
                                                       self.bottleneck_channels,
                                                       self.input_channels, 
                                                       kernel_size = kernel_size,
                                                       bias = bias,
                                                       padding = kernel_size // 2
                                                      ),
                                       torch.nn.BatchNorm2d(self.input_channels)
                                      )
        torch.nn.init.constant_(self.W_z[1].weight, 0)   
        torch.nn.init.constant_(self.W_z[1].bias, 0)


    def forward(self, x):
        """
        Forward model [zi = Wzyi + xi]

        Parameters
        ----------
        x               : torch.tensor
                          Input data.


        Returns
        ----------
        z               : torch.tensor
                          Estimated output.
        """
        batch_size, channels, height, width = x.size()
        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
        attn = torch.nn.functional.softmax(attn, dim=-1)
        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
        W_y = self.W_z(y)
        z = W_y + x
        return z

__init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False)

Parameters:

  • input_channels
                  Number of input channels.
    
  • bottleneck_channels (int, default: 512 ) –
                  Number of middle channels.
    
  • kernel_size
                  Kernel size.
    
  • bias
                  Set to True to let convolutional layers have bias term.
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels = 1024,
             bottleneck_channels = 512,
             kernel_size = 1,
             bias = False,
            ):
    """

    Parameters
    ----------
    input_channels      : int
                          Number of input channels.
    bottleneck_channels : int
                          Number of middle channels.
    kernel_size         : int
                          Kernel size.
    bias                : bool 
                          Set to True to let convolutional layers have bias term.
    """
    super(non_local_layer, self).__init__()
    self.input_channels = input_channels
    self.bottleneck_channels = bottleneck_channels
    self.g = torch.nn.Conv2d(
                             self.input_channels, 
                             self.bottleneck_channels,
                             kernel_size = kernel_size,
                             padding = kernel_size // 2,
                             bias = bias
                            )
    self.W_z = torch.nn.Sequential(
                                   torch.nn.Conv2d(
                                                   self.bottleneck_channels,
                                                   self.input_channels, 
                                                   kernel_size = kernel_size,
                                                   bias = bias,
                                                   padding = kernel_size // 2
                                                  ),
                                   torch.nn.BatchNorm2d(self.input_channels)
                                  )
    torch.nn.init.constant_(self.W_z[1].weight, 0)   
    torch.nn.init.constant_(self.W_z[1].bias, 0)

forward(x)

Forward model [zi = Wzyi + xi]

Parameters:

  • x
              Input data.
    

Returns:

  • z ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward model [zi = Wzyi + xi]

    Parameters
    ----------
    x               : torch.tensor
                      Input data.


    Returns
    ----------
    z               : torch.tensor
                      Estimated output.
    """
    batch_size, channels, height, width = x.size()
    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
    attn = torch.nn.functional.softmax(attn, dim=-1)
    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
    W_y = self.W_z(y)
    z = W_y + x
    return z
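
A short shape check for `non_local_layer`, given as a hedged sketch: it assumes the class can be imported from `odak.learn.models` and uses small illustrative channel counts.

import torch
from odak.learn.models import non_local_layer

layer = non_local_layer(input_channels = 64, bottleneck_channels = 32)
x = torch.rand(1, 64, 16, 16)  # [batch x channels x height x width]
z = layer(x)                   # Residual connection keeps the input shape: [1, 64, 16, 16].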

normalization

Bases: Module

A normalization layer.

Source code in odak/learn/models/components.py
class normalization(torch.nn.Module):
    """
    A normalization layer.
    """
    def __init__(
                 self,
                 dim = 1,
                ):
        """
        Normalization layer.


        Parameters
        ----------
        dim             : int
                          Dimension (axis) to normalize.
        """
        super().__init__()
        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))


    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x             : torch.tensor
                        Input data.


        Returns
        ----------
        result        : torch.tensor
                        Estimated output.      
        """
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        result =  (x - mean) * (var + eps).rsqrt() * self.k
        return result 

__init__(dim=1)

Normalization layer.

Parameters:

  • dim
              Dimension (axis) to normalize.
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             dim = 1,
            ):
    """
    Normalization layer.


    Parameters
    ----------
    dim             : int
                      Dimension (axis) to normalize.
    """
    super().__init__()
    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))

forward(x)

Forward model.

Parameters:

  • x
            Input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x             : torch.tensor
                    Input data.


    Returns
    ----------
    result        : torch.tensor
                    Estimated output.      
    """
    eps = 1e-5 if x.dtype == torch.float32 else 1e-3
    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
    mean = torch.mean(x, dim = 1, keepdim = True)
    result =  (x - mean) * (var + eps).rsqrt() * self.k
    return result 
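
The layer standardizes each spatial location across the channel axis and rescales it with the learnable parameter `k`. A minimal sketch, assuming the class is importable from `odak.learn.models`:

import torch
from odak.learn.models import normalization

norm = normalization(dim = 8)
x = torch.rand(2, 8, 4, 4) * 5. + 3.
y = norm(x)
# With `k` initialized to ones, each pixel is roughly zero-mean and unit-variance across channels.
print(y.mean(dim = 1).abs().max(), y.var(dim = 1, unbiased = False).mean())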

positional_encoder

Bases: Module

A positional encoder module. This implementation follows this specific work: Martin-Brualla, Ricardo, Noha Radwan, Mehdi SM Sajjadi, Jonathan T. Barron, Alexey Dosovitskiy, and Daniel Duckworth. "Nerf in the wild: Neural radiance fields for unconstrained photo collections." In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 7210-7219. 2021.

Source code in odak/learn/models/components.py
class positional_encoder(torch.nn.Module):
    """
    A positional encoder module.
    This implementation follows this specific work: `Martin-Brualla, Ricardo, Noha Radwan, Mehdi SM Sajjadi, Jonathan T. Barron, Alexey Dosovitskiy, and Daniel Duckworth. "Nerf in the wild: Neural radiance fields for unconstrained photo collections." In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 7210-7219. 2021.`.
    """

    def __init__(self, L):
        """
        A positional encoder module.

        Parameters
        ----------
        L                   : int
                              Positional encoding level.
        """
        super(positional_encoder, self).__init__()
        self.L = L



    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x               : torch.tensor
                          Input data [b x n], where `b` is batch size, `n` is the feature size.

        Returns
        ----------
        result          : torch.tensor
                          Result of the forward operation.
        """
        freqs = 2 ** torch.arange(self.L, device = x.device)
        freqs = freqs.view(1, 1, -1)
        results_cos = torch.cos(x.unsqueeze(-1) * freqs).reshape(x.shape[0], -1)
        results_sin = torch.sin(x.unsqueeze(-1) * freqs).reshape(x.shape[0], -1)
        results = torch.cat((x, results_cos, results_sin), dim = 1)
        return results

__init__(L)

A positional encoder module.

Parameters:

  • L
                  Positional encoding level.
    
Source code in odak/learn/models/components.py
def __init__(self, L):
    """
    A positional encoder module.

    Parameters
    ----------
    L                   : int
                          Positional encoding level.
    """
    super(positional_encoder, self).__init__()
    self.L = L

forward(x)

Forward model.

Parameters:

  • x
              Input data [b x n], where `b` is batch size, `n` is the feature size.
    

Returns:

  • result ( tensor ) –

    Result of the forward operation.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x               : torch.tensor
                      Input data [b x n], where `b` is batch size, `n` is the feature size.

    Returns
    ----------
    result          : torch.tensor
                      Result of the forward operation.
    """
    freqs = 2 ** torch.arange(self.L, device = x.device)
    freqs = freqs.view(1, 1, -1)
    results_cos = torch.cos(x.unsqueeze(-1) * freqs).reshape(x.shape[0], -1)
    results_sin = torch.sin(x.unsqueeze(-1) * freqs).reshape(x.shape[0], -1)
    results = torch.cat((x, results_cos, results_sin), dim = 1)
    return results
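
For an `[b x n]` input, the encoder returns `[b x n * (2 * L + 1)]` features: the original values plus `L` cosine and `L` sine bands per feature. A hedged usage sketch, assuming the class is importable from `odak.learn.models`:

import torch
from odak.learn.models import positional_encoder

encoder = positional_encoder(L = 4)
x = torch.rand(16, 3)
features = encoder(x)  # Expected shape: [16, 27], since 3 * (2 * 4 + 1) = 27.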

residual_attention_layer

Bases: Module

A residual block with an attention layer.

Source code in odak/learn/models/components.py
class residual_attention_layer(torch.nn.Module):
    """
    A residual block with an attention layer.
    """
    def __init__(
                 self,
                 input_channels = 2,
                 output_channels = 2,
                 kernel_size = 1,
                 bias = False,
                 activation = torch.nn.ReLU()
                ):
        """
        An attention layer class.


        Parameters
        ----------
        input_channels  : int or optional
                          Number of input channels.
        output_channels : int or optional
                          Number of middle channels.
        kernel_size     : int or optional
                          Kernel size.
        bias            : bool or optional
                          Set to True to let convolutional layers have bias term.
        activation      : torch.nn or optional
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        """
        super().__init__()
        self.activation = activation
        self.convolution0 = torch.nn.Sequential(
                                                torch.nn.Conv2d(
                                                                input_channels,
                                                                output_channels,
                                                                kernel_size = kernel_size,
                                                                padding = kernel_size // 2,
                                                                bias = bias
                                                               ),
                                                torch.nn.BatchNorm2d(output_channels)
                                               )
        self.convolution1 = torch.nn.Sequential(
                                                torch.nn.Conv2d(
                                                                input_channels,
                                                                output_channels,
                                                                kernel_size = kernel_size,
                                                                padding = kernel_size // 2,
                                                                bias = bias
                                                               ),
                                                torch.nn.BatchNorm2d(output_channels)
                                               )
        self.final_layer = torch.nn.Sequential(
                                               self.activation,
                                               torch.nn.Conv2d(
                                                               output_channels,
                                                               output_channels,
                                                               kernel_size = kernel_size,
                                                               padding = kernel_size // 2,
                                                               bias = bias
                                                              )
                                              )


    def forward(self, x0, x1):
        """
        Forward model.

        Parameters
        ----------
        x0             : torch.tensor
                         First input data.

        x1             : torch.tensor
                         Second input data.


        Returns
        ----------
        result        : torch.tensor
                        Estimated output.      
        """
        y0 = self.convolution0(x0)
        y1 = self.convolution1(x1)
        y2 = torch.add(y0, y1)
        result = self.final_layer(y2) * x0
        return result

__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())

An attention layer class.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int or optional, default: 2 ) –
              Number of middle channels.
    
  • kernel_size
              Kernel size.
    
  • bias
              Set to True to let convolutional layers have bias term.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels = 2,
             output_channels = 2,
             kernel_size = 1,
             bias = False,
             activation = torch.nn.ReLU()
            ):
    """
    An attention layer class.


    Parameters
    ----------
    input_channels  : int or optional
                      Number of input channels.
    output_channels : int or optional
                      Number of middle channels.
    kernel_size     : int or optional
                      Kernel size.
    bias            : bool or optional
                      Set to True to let convolutional layers have bias term.
    activation      : torch.nn or optional
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    """
    super().__init__()
    self.activation = activation
    self.convolution0 = torch.nn.Sequential(
                                            torch.nn.Conv2d(
                                                            input_channels,
                                                            output_channels,
                                                            kernel_size = kernel_size,
                                                            padding = kernel_size // 2,
                                                            bias = bias
                                                           ),
                                            torch.nn.BatchNorm2d(output_channels)
                                           )
    self.convolution1 = torch.nn.Sequential(
                                            torch.nn.Conv2d(
                                                            input_channels,
                                                            output_channels,
                                                            kernel_size = kernel_size,
                                                            padding = kernel_size // 2,
                                                            bias = bias
                                                           ),
                                            torch.nn.BatchNorm2d(output_channels)
                                           )
    self.final_layer = torch.nn.Sequential(
                                           self.activation,
                                           torch.nn.Conv2d(
                                                           output_channels,
                                                           output_channels,
                                                           kernel_size = kernel_size,
                                                           padding = kernel_size // 2,
                                                           bias = bias
                                                          )
                                          )

forward(x0, x1)

Forward model.

Parameters:

  • x0
             First input data.
    
  • x1
             Second input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x0, x1):
    """
    Forward model.

    Parameters
    ----------
    x0             : torch.tensor
                     First input data.

    x1             : torch.tensor
                     Second input data.


    Returns
    ----------
    result        : torch.tensor
                    Estimated output.      
    """
    y0 = self.convolution0(x0)
    y1 = self.convolution1(x1)
    y2 = torch.add(y0, y1)
    result = self.final_layer(y2) * x0
    return result
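
A hedged usage sketch for the layer above, assuming it is importable from `odak.learn.models`; note that `x0` is the tensor being gated, while `x1` provides the guidance features.

import torch
from odak.learn.models import residual_attention_layer

layer = residual_attention_layer(input_channels = 2, output_channels = 2)
x0 = torch.rand(1, 2, 32, 32)  # Features to be gated.
x1 = torch.rand(1, 2, 32, 32)  # Guidance features.
result = layer(x0, x1)         # Attention map multiplies x0 elementwise: [1, 2, 32, 32].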

residual_layer

Bases: Module

A residual layer.

Source code in odak/learn/models/components.py
class residual_layer(torch.nn.Module):
    """
    A residual layer.
    """
    def __init__(
                 self,
                 input_channels = 2,
                 mid_channels = 16,
                 kernel_size = 3,
                 bias = False,
                 normalization = True,
                 activation = torch.nn.ReLU()
                ):
        """
        A convolutional layer class.


        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        mid_channels    : int
                          Number of middle channels.
        kernel_size     : int
                          Kernel size.
        bias            : bool 
                          Set to True to let convolutional layers have bias term.
        normalization   : bool                
                          If True, adds a Batch Normalization layer after the convolutional layer.
        activation      : torch.nn
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        """
        super().__init__()
        self.activation = activation
        self.convolution = double_convolution(
                                              input_channels,
                                              mid_channels = mid_channels,
                                              output_channels = input_channels,
                                              kernel_size = kernel_size,
                                              normalization = normalization,
                                              bias = bias,
                                              activation = activation
                                             )


    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x             : torch.tensor
                        Input data.


        Returns
        ----------
        result        : torch.tensor
                        Estimated output.      
        """
        x0 = self.convolution(x)
        return x + x0

__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU())

A convolutional layer class.

Parameters:

  • input_channels
              Number of input channels.
    
  • mid_channels
              Number of middle channels.
    
  • kernel_size
              Kernel size.
    
  • bias
              Set to True to let convolutional layers have bias term.
    
  • normalization
              If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels = 2,
             mid_channels = 16,
             kernel_size = 3,
             bias = False,
             normalization = True,
             activation = torch.nn.ReLU()
            ):
    """
    A convolutional layer class.


    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    mid_channels    : int
                      Number of middle channels.
    kernel_size     : int
                      Kernel size.
    bias            : bool 
                      Set to True to let convolutional layers have bias term.
    normalization   : bool                
                      If True, adds a Batch Normalization layer after the convolutional layer.
    activation      : torch.nn
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    """
    super().__init__()
    self.activation = activation
    self.convolution = double_convolution(
                                          input_channels,
                                          mid_channels = mid_channels,
                                          output_channels = input_channels,
                                          kernel_size = kernel_size,
                                          normalization = normalization,
                                          bias = bias,
                                          activation = activation
                                         )

forward(x)

Forward model.

Parameters:

  • x
            Input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x             : torch.tensor
                    Input data.


    Returns
    ----------
    result        : torch.tensor
                    Estimated output.      
    """
    x0 = self.convolution(x)
    return x + x0
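
A minimal sketch, assuming `residual_layer` is importable from `odak.learn.models`; because of the skip connection, the output keeps the same number of channels as the input.

import torch
from odak.learn.models import residual_layer

layer = residual_layer(input_channels = 4, mid_channels = 16, kernel_size = 3)
x = torch.rand(1, 4, 64, 64)
y = layer(x)  # Skip connection keeps the input shape: [1, 4, 64, 64].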

spatial_gate

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling. This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.

Source code in odak/learn/models/components.py
class spatial_gate(torch.nn.Module):
    """
    Spatial attention module that applies a convolution layer after channel pooling.
    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
    """
    def __init__(self):
        """
        Initializes the spatial gate module.
        """
        super().__init__()
        kernel_size = 7
        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())


    def channel_pool(self, x):
        """
        Applies max and average pooling on the channels.

        Parameters
        ----------
        x             : torch.tensor
                        Input tensor.

        Returns
        -------
        output        : torch.tensor
                        Output tensor.
        """
        max_pool = torch.max(x, 1)[0].unsqueeze(1)
        avg_pool = torch.mean(x, 1).unsqueeze(1)
        output = torch.cat((max_pool, avg_pool), dim=1)
        return output


    def forward(self, x):
        """
        Forward pass of the SpatialGate module.

        Applies spatial attention to the input tensor.

        Parameters
        ----------
        x            : torch.tensor
                       Input tensor to the SpatialGate module.

        Returns
        -------
        scaled_x     : torch.tensor
                       Output tensor after applying spatial attention.
        """
        x_compress = self.channel_pool(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)
        scaled_x = x * scale
        return scaled_x

__init__()

Initializes the spatial gate module.

Source code in odak/learn/models/components.py
def __init__(self):
    """
    Initializes the spatial gate module.
    """
    super().__init__()
    kernel_size = 7
    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())

channel_pool(x)

Applies max and average pooling on the channels.

Parameters:

  • x
            Input tensor.
    

Returns:

  • output ( tensor ) –

    Output tensor.

Source code in odak/learn/models/components.py
def channel_pool(self, x):
    """
    Applies max and average pooling on the channels.

    Parameters
    ----------
    x             : torch.tensor
                    Input tensor.

    Returns
    -------
    output        : torch.tensor
                    Output tensor.
    """
    max_pool = torch.max(x, 1)[0].unsqueeze(1)
    avg_pool = torch.mean(x, 1).unsqueeze(1)
    output = torch.cat((max_pool, avg_pool), dim=1)
    return output

forward(x)

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x
           Input tensor to the SpatialGate module.
    

Returns:

  • scaled_x ( tensor ) –

    Output tensor after applying spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):
    """
    Forward pass of the SpatialGate module.

    Applies spatial attention to the input tensor.

    Parameters
    ----------
    x            : torch.tensor
                   Input tensor to the SpatialGate module.

    Returns
    -------
    scaled_x     : torch.tensor
                   Output tensor after applying spatial attention.
    """
    x_compress = self.channel_pool(x)
    x_out = self.spatial(x_compress)
    scale = torch.sigmoid(x_out)
    scaled_x = x * scale
    return scaled_x
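
A hedged usage sketch for `spatial_gate`, assuming it is importable from `odak.learn.models`; channel pooling yields a two-channel map that a 7x7 convolution turns into a per-pixel attention mask in [0, 1].

import torch
from odak.learn.models import spatial_gate

gate = spatial_gate()
x = torch.rand(1, 8, 32, 32)
scaled_x = gate(x)  # Per-pixel attention scales the input: [1, 8, 32, 32].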

spatially_adaptive_convolution

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."

Source code in odak/learn/models/components.py
class spatially_adaptive_convolution(torch.nn.Module):
    """
    A spatially adaptive convolution layer.

    References
    ----------

    C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
    C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
    C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."
    """
    def __init__(
                 self,
                 input_channels = 2,
                 output_channels = 2,
                 kernel_size = 3,
                 stride = 1,
                 padding = 1,
                 bias = False,
                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
                ):
        """
        Initializes a spatially adaptive convolution layer.

        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Size of the convolution kernel.
        stride          : int
                          Stride of the convolution.
        padding         : int
                          Padding added to both sides of the input.
        bias            : bool
                          If True, includes a bias term in the convolution.
        activation      : torch.nn.Module
                          Activation function to apply. If None, no activation is applied.
        """
        super(spatially_adaptive_convolution, self).__init__()
        self.kernel_size = kernel_size
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        self.padding = padding
        self.standard_convolution = torch.nn.Conv2d(
                                                    in_channels = input_channels,
                                                    out_channels = self.output_channels,
                                                    kernel_size = kernel_size,
                                                    stride = stride,
                                                    padding = padding,
                                                    bias = bias
                                                   )
        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
        self.activation = activation


    def forward(self, x, sv_kernel_feature):
        """
        Forward pass for the spatially adaptive convolution layer.

        Parameters
        ----------
        x                  : torch.tensor
                            Input data tensor.
                            Dimension: (1, C, H, W)
        sv_kernel_feature   : torch.tensor
                            Spatially varying kernel features.
                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)

        Returns
        -------
        sa_output          : torch.tensor
                            Estimated output tensor.
                            Dimension: (1, output_channels, H_out, W_out)
        """
        # Pad input and sv_kernel_feature if necessary
        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
                -2) * self.stride != x.size(-2):
            diffY = sv_kernel_feature.size(-2) % self.stride
            diffX = sv_kernel_feature.size(-1) % self.stride
            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
                                                                            diffY // 2, diffY - diffY // 2))
            diffY = x.size(-2) % self.stride
            diffX = x.size(-1) % self.stride
            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
                                            diffY // 2, diffY - diffY // 2))

        # Unfold the input tensor for matrix multiplication
        input_feature = torch.nn.functional.unfold(
                                                   x,
                                                   kernel_size = (self.kernel_size, self.kernel_size),
                                                   stride = self.stride,
                                                   padding = self.padding
                                                  )

        # Resize sv_kernel_feature to match the input feature
        sv_kernel = sv_kernel_feature.reshape(
                                              1,
                                              self.input_channels * self.kernel_size * self.kernel_size,
                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
                                             )

        # Resize weight to match the input channels and kernel size
        si_kernel = self.weight.reshape(
                                        self.output_channels,
                                        self.input_channels * self.kernel_size * self.kernel_size
                                       )

        # Apply spatially varying kernels
        sv_feature = input_feature * sv_kernel

        # Perform matrix multiplication
        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
                                                                1, self.output_channels,
                                                                (x.size(-2) // self.stride),
                                                                (x.size(-1) // self.stride)
                                                               )
        return sa_output

__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int, default: 2 ) –
              Number of output channels.
    
  • kernel_size
              Size of the convolution kernel.
    
  • stride
              Stride of the convolution.
    
  • padding
              Padding added to both sides of the input.
    
  • bias
              If True, includes a bias term in the convolution.
    
  • activation
              Activation function to apply. If None, no activation is applied.
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels = 2,
             output_channels = 2,
             kernel_size = 3,
             stride = 1,
             padding = 1,
             bias = False,
             activation = torch.nn.LeakyReLU(0.2, inplace = True)
            ):
    """
    Initializes a spatially adaptive convolution layer.

    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Size of the convolution kernel.
    stride          : int
                      Stride of the convolution.
    padding         : int
                      Padding added to both sides of the input.
    bias            : bool
                      If True, includes a bias term in the convolution.
    activation      : torch.nn.Module
                      Activation function to apply. If None, no activation is applied.
    """
    super(spatially_adaptive_convolution, self).__init__()
    self.kernel_size = kernel_size
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.stride = stride
    self.padding = padding
    self.standard_convolution = torch.nn.Conv2d(
                                                in_channels = input_channels,
                                                out_channels = self.output_channels,
                                                kernel_size = kernel_size,
                                                stride = stride,
                                                padding = padding,
                                                bias = bias
                                               )
    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
    self.activation = activation

forward(x, sv_kernel_feature)

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x
                Input data tensor.
                Dimension: (1, C, H, W)
    
  • sv_kernel_feature
                Spatially varying kernel features.
                Dimension: (1, C_i * kernel_size * kernel_size, H, W)
    

Returns:

  • sa_output ( tensor ) –

    Estimated output tensor. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):
    """
    Forward pass for the spatially adaptive convolution layer.

    Parameters
    ----------
    x                  : torch.tensor
                        Input data tensor.
                        Dimension: (1, C, H, W)
    sv_kernel_feature   : torch.tensor
                        Spatially varying kernel features.
                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)

    Returns
    -------
    sa_output          : torch.tensor
                        Estimated output tensor.
                        Dimension: (1, output_channels, H_out, W_out)
    """
    # Pad input and sv_kernel_feature if necessary
    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
            -2) * self.stride != x.size(-2):
        diffY = sv_kernel_feature.size(-2) % self.stride
        diffX = sv_kernel_feature.size(-1) % self.stride
        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
                                                                        diffY // 2, diffY - diffY // 2))
        diffY = x.size(-2) % self.stride
        diffX = x.size(-1) % self.stride
        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
                                        diffY // 2, diffY - diffY // 2))

    # Unfold the input tensor for matrix multiplication
    input_feature = torch.nn.functional.unfold(
                                               x,
                                               kernel_size = (self.kernel_size, self.kernel_size),
                                               stride = self.stride,
                                               padding = self.padding
                                              )

    # Resize sv_kernel_feature to match the input feature
    sv_kernel = sv_kernel_feature.reshape(
                                          1,
                                          self.input_channels * self.kernel_size * self.kernel_size,
                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
                                         )

    # Resize weight to match the input channels and kernel size
    si_kernel = self.weight.reshape(
                                    self.output_channels,
                                    self.input_channels * self.kernel_size * self.kernel_size
                                   )

    # Apply spatially varying kernels
    sv_feature = input_feature * sv_kernel

    # Perform matrix multiplication
    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
                                                            1, self.output_channels,
                                                            (x.size(-2) // self.stride),
                                                            (x.size(-1) // self.stride)
                                                           )
    return sa_output
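
A hedged shape sketch for the layer above, assuming it is importable from `odak.learn.models`. In practice `sv_kernel_feature` is predicted by another network; here a random tensor with the expected `(1, C_i * kernel_size * kernel_size, H, W)` layout stands in for it.

import torch
from odak.learn.models import spatially_adaptive_convolution

layer = spatially_adaptive_convolution(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.rand(1, 2, 32, 32)
sv_kernel_feature = torch.rand(1, 2 * 3 * 3, 32, 32)  # (1, C_i * kernel_size * kernel_size, H, W)
sa_output = layer(x, sv_kernel_feature)                # Expected shape: [1, 2, 32, 32].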

spatially_adaptive_module

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/components.py
class spatially_adaptive_module(torch.nn.Module):
    """
    A spatially adaptive module that combines learned spatially adaptive convolutions.

    References
    ----------

    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
    """
    def __init__(
                 self,
                 input_channels = 2,
                 output_channels = 2,
                 kernel_size = 3,
                 stride = 1,
                 padding = 1,
                 bias = False,
                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
                ):
        """
        Initializes a spatially adaptive module.

        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Size of the convolution kernel.
        stride          : int
                          Stride of the convolution.
        padding         : int
                          Padding added to both sides of the input.
        bias            : bool
                          If True, includes a bias term in the convolution.
        activation      : torch.nn
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        """
        super(spatially_adaptive_module, self).__init__()
        self.kernel_size = kernel_size
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        self.padding = padding
        self.weight_output_channels = self.output_channels - 1
        self.standard_convolution = torch.nn.Conv2d(
                                                    in_channels = input_channels,
                                                    out_channels = self.weight_output_channels,
                                                    kernel_size = kernel_size,
                                                    stride = stride,
                                                    padding = padding,
                                                    bias = bias
                                                   )
        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
        self.activation = activation


    def forward(self, x, sv_kernel_feature):
        """
        Forward pass for the spatially adaptive module.

        Parameters
        ----------
        x                  : torch.tensor
                            Input data tensor.
                            Dimension: (1, C, H, W)
        sv_kernel_feature   : torch.tensor
                            Spatially varying kernel features.
                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)

        Returns
        -------
        output             : torch.tensor
                            Combined output tensor from standard and spatially adaptive convolutions.
                            Dimension: (1, output_channels, H_out, W_out)
        """
        # Pad input and sv_kernel_feature if necessary
        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
                -2) * self.stride != x.size(-2):
            diffY = sv_kernel_feature.size(-2) % self.stride
            diffX = sv_kernel_feature.size(-1) % self.stride
            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
                                                                            diffY // 2, diffY - diffY // 2))
            diffY = x.size(-2) % self.stride
            diffX = x.size(-1) % self.stride
            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
                                            diffY // 2, diffY - diffY // 2))

        # Unfold the input tensor for matrix multiplication
        input_feature = torch.nn.functional.unfold(
                                                   x,
                                                   kernel_size = (self.kernel_size, self.kernel_size),
                                                   stride = self.stride,
                                                   padding = self.padding
                                                  )

        # Resize sv_kernel_feature to match the input feature
        sv_kernel = sv_kernel_feature.reshape(
                                              1,
                                              self.input_channels * self.kernel_size * self.kernel_size,
                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
                                             )

        # Apply sv_kernel to the input_feature
        sv_feature = input_feature * sv_kernel

        # Original spatially varying convolution output
        sv_output = torch.sum(sv_feature, dim = 1).reshape(
                                                           1,
                                                            1,
                                                            (x.size(-2) // self.stride),
                                                            (x.size(-1) // self.stride)
                                                           )

        # Reshape weight for spatially adaptive convolution
        si_kernel = self.weight.reshape(
                                        self.weight_output_channels,
                                        self.input_channels * self.kernel_size * self.kernel_size
                                       )

        # Apply si_kernel on sv convolution output
        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
                                                                1, self.weight_output_channels,
                                                                (x.size(-2) // self.stride),
                                                                (x.size(-1) // self.stride)
                                                               )

        # Combine the outputs and apply activation function
        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
        return output

__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes a spatially adaptive module.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int, default: 2 ) –
              Number of output channels.
    
  • kernel_size
              Size of the convolution kernel.
    
  • stride
              Stride of the convolution.
    
  • padding
              Padding added to both sides of the input.
    
  • bias
              If True, includes a bias term in the convolution.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels = 2,
             output_channels = 2,
             kernel_size = 3,
             stride = 1,
             padding = 1,
             bias = False,
             activation = torch.nn.LeakyReLU(0.2, inplace = True)
            ):
    """
    Initializes a spatially adaptive module.

    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Size of the convolution kernel.
    stride          : int
                      Stride of the convolution.
    padding         : int
                      Padding added to both sides of the input.
    bias            : bool
                      If True, includes a bias term in the convolution.
    activation      : torch.nn
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    """
    super(spatially_adaptive_module, self).__init__()
    self.kernel_size = kernel_size
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.stride = stride
    self.padding = padding
    self.weight_output_channels = self.output_channels - 1
    self.standard_convolution = torch.nn.Conv2d(
                                                in_channels = input_channels,
                                                out_channels = self.weight_output_channels,
                                                kernel_size = kernel_size,
                                                stride = stride,
                                                padding = padding,
                                                bias = bias
                                               )
    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
    self.activation = activation

forward(x, sv_kernel_feature)

Forward pass for the spatially adaptive module.

Parameters:

  • x
                Input data tensor.
                Dimension: (1, C, H, W)
    
  • sv_kernel_feature
                Spatially varying kernel features.
                Dimension: (1, C_i * kernel_size * kernel_size, H, W)
    

Returns:

  • output ( tensor ) –

    Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):
    """
    Forward pass for the spatially adaptive module.

    Parameters
    ----------
    x                  : torch.tensor
                        Input data tensor.
                        Dimension: (1, C, H, W)
    sv_kernel_feature   : torch.tensor
                        Spatially varying kernel features.
                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)

    Returns
    -------
    output             : torch.tensor
                        Combined output tensor from standard and spatially adaptive convolutions.
                        Dimension: (1, output_channels, H_out, W_out)
    """
    # Pad input and sv_kernel_feature if necessary
    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
            -2) * self.stride != x.size(-2):
        diffY = sv_kernel_feature.size(-2) % self.stride
        diffX = sv_kernel_feature.size(-1) % self.stride
        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
                                                                        diffY // 2, diffY - diffY // 2))
        diffY = x.size(-2) % self.stride
        diffX = x.size(-1) % self.stride
        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
                                        diffY // 2, diffY - diffY // 2))

    # Unfold the input tensor for matrix multiplication
    input_feature = torch.nn.functional.unfold(
                                               x,
                                               kernel_size = (self.kernel_size, self.kernel_size),
                                               stride = self.stride,
                                               padding = self.padding
                                              )

    # Resize sv_kernel_feature to match the input feature
    sv_kernel = sv_kernel_feature.reshape(
                                          1,
                                          self.input_channels * self.kernel_size * self.kernel_size,
                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
                                         )

    # Apply sv_kernel to the input_feature
    sv_feature = input_feature * sv_kernel

    # Original spatially varying convolution output
    sv_output = torch.sum(sv_feature, dim = 1).reshape(
                                                       1,
                                                        1,
                                                        (x.size(-2) // self.stride),
                                                        (x.size(-1) // self.stride)
                                                       )

    # Reshape weight for spatially adaptive convolution
    si_kernel = self.weight.reshape(
                                    self.weight_output_channels,
                                    self.input_channels * self.kernel_size * self.kernel_size
                                   )

    # Apply si_kernel on sv convolution output
    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
                                                            1, self.weight_output_channels,
                                                            (x.size(-2) // self.stride),
                                                            (x.size(-1) // self.stride)
                                                           )

    # Combine the outputs and apply activation function
    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
    return output
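
A hedged shape sketch for the module above, assuming it is importable from `odak.learn.models`. The output concatenates one spatially varying channel with `output_channels - 1` spatially adaptive channels before the activation.

import torch
from odak.learn.models import spatially_adaptive_module

module = spatially_adaptive_module(input_channels = 2, output_channels = 4, kernel_size = 3)
x = torch.rand(1, 2, 32, 32)
sv_kernel_feature = torch.rand(1, 2 * 3 * 3, 32, 32)  # (1, C_i * kernel_size * kernel_size, H, W)
output = module(x, sv_kernel_feature)                  # Expected shape: [1, 4, 32, 32].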

spatially_adaptive_unet

Bases: Module

Spatially varying U-Net model based on spatially adaptive convolution.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/models.py
class spatially_adaptive_unet(torch.nn.Module):
    """
    Spatially varying U-Net model based on spatially adaptive convolution.

    References
    ----------

    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
    """
    def __init__(
                 self,
                 depth=3,
                 dimensions=8,
                 input_channels=6,
                 out_channels=6,
                 kernel_size=3,
                 bias=True,
                 normalization=False,
                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
                ):
        """
        U-Net model.

        Parameters
        ----------
        depth          : int
                         Number of upsampling and downsampling layers.
        dimensions     : int
                         Number of dimensions.
        input_channels : int
                         Number of input channels.
        out_channels   : int
                         Number of output channels.
        kernel_size    : int
                         Kernel size.
        bias           : bool
                         Set to True to let convolutional layers learn a bias term.
        normalization  : bool
                         If True, adds a Batch Normalization layer after the convolutional layer.
        activation     : torch.nn
                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
        """
        super().__init__()
        self.depth = depth
        self.out_channels = out_channels
        self.inc = convolution_layer(
                                     input_channels=input_channels,
                                     output_channels=dimensions,
                                     kernel_size=kernel_size,
                                     bias=bias,
                                     normalization=normalization,
                                     activation=activation
                                    )

        self.encoder = torch.nn.ModuleList()
        for i in range(self.depth + 1):  # Downsampling layers
            down_in_channels = dimensions * (2 ** i)
            down_out_channels = 2 * down_in_channels
            pooling_layer = torch.nn.AvgPool2d(2)
            double_convolution_layer = double_convolution(
                                                          input_channels=down_in_channels,
                                                          mid_channels=down_in_channels,
                                                          output_channels=down_in_channels,
                                                          kernel_size=kernel_size,
                                                          bias=bias,
                                                          normalization=normalization,
                                                          activation=activation
                                                         )
            sam = spatially_adaptive_module(
                                            input_channels=down_in_channels,
                                            output_channels=down_out_channels,
                                            kernel_size=kernel_size,
                                            bias=bias,
                                            activation=activation
                                           )
            self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
        self.global_feature_module = torch.nn.ModuleList()
        double_convolution_layer = double_convolution(
                                                      input_channels=dimensions * (2 ** (depth + 1)),
                                                      mid_channels=dimensions * (2 ** (depth + 1)),
                                                      output_channels=dimensions * (2 ** (depth + 1)),
                                                      kernel_size=kernel_size,
                                                      bias=bias,
                                                      normalization=normalization,
                                                      activation=activation
                                                     )
        global_feature_layer = global_feature_module(
                                                     input_channels=dimensions * (2 ** (depth + 1)),
                                                     mid_channels=dimensions * (2 ** (depth + 1)),
                                                     output_channels=dimensions * (2 ** (depth + 1)),
                                                     kernel_size=kernel_size,
                                                     bias=bias,
                                                     activation=torch.nn.LeakyReLU(0.2, inplace=True)
                                                    )
        self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
        self.decoder = torch.nn.ModuleList()
        for i in range(depth, -1, -1):
            up_in_channels = dimensions * (2 ** (i + 1))
            up_mid_channels = up_in_channels // 2
            if i == 0:
                up_out_channels = self.out_channels
                upsample_layer = upsample_convtranspose2d_layer(
                                                                input_channels=up_in_channels,
                                                                output_channels=up_mid_channels,
                                                                kernel_size=2,
                                                                stride=2,
                                                                bias=bias,
                                                               )
                conv_layer = torch.nn.Sequential(
                    convolution_layer(
                                      input_channels=up_mid_channels,
                                      output_channels=up_mid_channels,
                                      kernel_size=kernel_size,
                                      bias=bias,
                                      normalization=normalization,
                                      activation=activation,
                                     ),
                    convolution_layer(
                                      input_channels=up_mid_channels,
                                      output_channels=up_out_channels,
                                      kernel_size=1,
                                      bias=bias,
                                      normalization=normalization,
                                      activation=None,
                                     )
                )
                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
            else:
                up_out_channels = up_in_channels // 2
                upsample_layer = upsample_convtranspose2d_layer(
                                                                input_channels=up_in_channels,
                                                                output_channels=up_mid_channels,
                                                                kernel_size=2,
                                                                stride=2,
                                                                bias=bias,
                                                               )
                conv_layer = double_convolution(
                                                input_channels=up_mid_channels,
                                                mid_channels=up_mid_channels,
                                                output_channels=up_out_channels,
                                                kernel_size=kernel_size,
                                                bias=bias,
                                                normalization=normalization,
                                                activation=activation,
                                               )
                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))


    def forward(self, sv_kernel, field):
        """
        Forward model.

        Parameters
        ----------
        sv_kernel : list of torch.tensor
                    Learned spatially varying kernels.
                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
                    where C_i, H_i, and W_i represent the channel, height, and width
                    of each feature at a certain scale.

        field     : torch.tensor
                    Input field data.
                    Dimension: (1, 6, H, W)

        Returns
        -------
        target_field : torch.tensor
                       Estimated output.
                       Dimension: (1, 6, H, W)
        """
        x = self.inc(field)
        downsampling_outputs = [x]
        for i, down_layer in enumerate(self.encoder):
            x_down = down_layer[0](downsampling_outputs[-1])
            downsampling_outputs.append(x_down)
            sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
            downsampling_outputs.append(sam_output)
        global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
        global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
        downsampling_outputs.append(global_feature)
        x_up = downsampling_outputs[-1]
        for i, up_layer in enumerate(self.decoder):
            x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
            x_up = up_layer[1](x_up)
        result = x_up
        return result

__init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

U-Net model.

Parameters:

  • depth
             Number of upsampling and downsampling layers.
    
  • dimensions
             Number of dimensions.
    
  • input_channels (int, default: 6 ) –
             Number of input channels.
    
  • out_channels
             Number of output channels.
    
  • kernel_size
             Kernel size of the convolutional layers.
    
  • bias
             Set to True to let convolutional layers learn a bias term.
    
  • normalization
             If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    
Source code in odak/learn/models/models.py
def __init__(
             self,
             depth=3,
             dimensions=8,
             input_channels=6,
             out_channels=6,
             kernel_size=3,
             bias=True,
             normalization=False,
             activation=torch.nn.LeakyReLU(0.2, inplace=True)
            ):
    """
    U-Net model.

    Parameters
    ----------
    depth          : int
                     Number of upsampling and downsampling layers.
    dimensions     : int
                     Number of dimensions.
    input_channels : int
                     Number of input channels.
    out_channels   : int
                     Number of output channels.
    bias           : bool
                     Set to True to let convolutional layers learn a bias term.
    normalization  : bool
                     If True, adds a Batch Normalization layer after the convolutional layer.
    activation     : torch.nn
                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    """
    super().__init__()
    self.depth = depth
    self.out_channels = out_channels
    self.inc = convolution_layer(
                                 input_channels=input_channels,
                                 output_channels=dimensions,
                                 kernel_size=kernel_size,
                                 bias=bias,
                                 normalization=normalization,
                                 activation=activation
                                )

    self.encoder = torch.nn.ModuleList()
    for i in range(self.depth + 1):  # Downsampling layers
        down_in_channels = dimensions * (2 ** i)
        down_out_channels = 2 * down_in_channels
        pooling_layer = torch.nn.AvgPool2d(2)
        double_convolution_layer = double_convolution(
                                                      input_channels=down_in_channels,
                                                      mid_channels=down_in_channels,
                                                      output_channels=down_in_channels,
                                                      kernel_size=kernel_size,
                                                      bias=bias,
                                                      normalization=normalization,
                                                      activation=activation
                                                     )
        sam = spatially_adaptive_module(
                                        input_channels=down_in_channels,
                                        output_channels=down_out_channels,
                                        kernel_size=kernel_size,
                                        bias=bias,
                                        activation=activation
                                       )
        self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
    self.global_feature_module = torch.nn.ModuleList()
    double_convolution_layer = double_convolution(
                                                  input_channels=dimensions * (2 ** (depth + 1)),
                                                  mid_channels=dimensions * (2 ** (depth + 1)),
                                                  output_channels=dimensions * (2 ** (depth + 1)),
                                                  kernel_size=kernel_size,
                                                  bias=bias,
                                                  normalization=normalization,
                                                  activation=activation
                                                 )
    global_feature_layer = global_feature_module(
                                                 input_channels=dimensions * (2 ** (depth + 1)),
                                                 mid_channels=dimensions * (2 ** (depth + 1)),
                                                 output_channels=dimensions * (2 ** (depth + 1)),
                                                 kernel_size=kernel_size,
                                                 bias=bias,
                                                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
                                                )
    self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
    self.decoder = torch.nn.ModuleList()
    for i in range(depth, -1, -1):
        up_in_channels = dimensions * (2 ** (i + 1))
        up_mid_channels = up_in_channels // 2
        if i == 0:
            up_out_channels = self.out_channels
            upsample_layer = upsample_convtranspose2d_layer(
                                                            input_channels=up_in_channels,
                                                            output_channels=up_mid_channels,
                                                            kernel_size=2,
                                                            stride=2,
                                                            bias=bias,
                                                           )
            conv_layer = torch.nn.Sequential(
                convolution_layer(
                                  input_channels=up_mid_channels,
                                  output_channels=up_mid_channels,
                                  kernel_size=kernel_size,
                                  bias=bias,
                                  normalization=normalization,
                                  activation=activation,
                                 ),
                convolution_layer(
                                  input_channels=up_mid_channels,
                                  output_channels=up_out_channels,
                                  kernel_size=1,
                                  bias=bias,
                                  normalization=normalization,
                                  activation=None,
                                 )
            )
            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
        else:
            up_out_channels = up_in_channels // 2
            upsample_layer = upsample_convtranspose2d_layer(
                                                            input_channels=up_in_channels,
                                                            output_channels=up_mid_channels,
                                                            kernel_size=2,
                                                            stride=2,
                                                            bias=bias,
                                                           )
            conv_layer = double_convolution(
                                            input_channels=up_mid_channels,
                                            mid_channels=up_mid_channels,
                                            output_channels=up_out_channels,
                                            kernel_size=kernel_size,
                                            bias=bias,
                                            normalization=normalization,
                                            activation=activation,
                                           )
            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))

forward(sv_kernel, field)

Forward model.

Parameters:

  • sv_kernel (list of torch.tensor) –
        Learned spatially varying kernels.
        Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
        where C_i, H_i, and W_i represent the channel, height, and width
        of each feature at a certain scale.
    
  • field
        Input field data.
        Dimension: (1, 6, H, W)
    

Returns:

  • target_field ( tensor ) –

    Estimated output. Dimension: (1, 6, H, W)

Source code in odak/learn/models/models.py
def forward(self, sv_kernel, field):
    """
    Forward model.

    Parameters
    ----------
    sv_kernel : list of torch.tensor
                Learned spatially varying kernels.
                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
                where C_i, H_i, and W_i represent the channel, height, and width
                of each feature at a certain scale.

    field     : torch.tensor
                Input field data.
                Dimension: (1, 6, H, W)

    Returns
    -------
    target_field : torch.tensor
                   Estimated output.
                   Dimension: (1, 6, H, W)
    """
    x = self.inc(field)
    downsampling_outputs = [x]
    for i, down_layer in enumerate(self.encoder):
        x_down = down_layer[0](downsampling_outputs[-1])
        downsampling_outputs.append(x_down)
        sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
        downsampling_outputs.append(sam_output)
    global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
    global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
    downsampling_outputs.append(global_feature)
    x_up = downsampling_outputs[-1]
    for i, up_layer in enumerate(self.decoder):
        x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
        x_up = up_layer[1](x_up)
    result = x_up
    return result
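
Below is a minimal usage sketch for spatially_adaptive_unet with its default configuration (depth = 3, dimensions = 8, kernel_size = 3) and a 256 x 256 field. The dummy sv_kernel shapes follow the dimensions documented above; in practice this list would come from spatially_varying_kernel_generation_model, described in the next section. The import path is an assumption.

import torch
from odak.learn.models import spatially_adaptive_unet  # assumed import path

model = spatially_adaptive_unet(depth = 3, dimensions = 8, input_channels = 6, out_channels = 6)
field = torch.randn(1, 6, 256, 256)
# sv_kernel[depth - i] is consumed at encoder level i; its channel count is C_i * kernel_size ** 2.
sv_kernel = [
             torch.randn(1, 64 * 9, 16, 16),   # coarsest scale, C_3 = 64
             torch.randn(1, 32 * 9, 32, 32),   # C_2 = 32
             torch.randn(1, 16 * 9, 64, 64),   # C_1 = 16
             torch.randn(1, 8 * 9, 128, 128)   # finest scale, C_0 = 8
            ]
target_field = model(sv_kernel, field)         # expected shape: (1, 6, 256, 256)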

spatially_varying_kernel_generation_model

Bases: Module

Spatially varying kernel generation model, revised from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Refer to: J. Huang, P. Zhu, M. Geng et al., "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."

Source code in odak/learn/models/models.py
class spatially_varying_kernel_generation_model(torch.nn.Module):
    """
    Spatially_varying_kernel_generation_model revised from RSGUnet:
    https://github.com/MTLab/rsgunet_image_enhance.

    Refer to:
    J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.
    """

    def __init__(
                 self,
                 depth = 3,
                 dimensions = 8,
                 input_channels = 7,
                 kernel_size = 3,
                 bias = True,
                 normalization = False,
                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
                ):
        """
        U-Net model.

        Parameters
        ----------
        depth          : int
                         Number of upsampling and downsampling layers.
        dimensions     : int
                         Number of dimensions.
        input_channels : int
                         Number of input channels.
        bias           : bool
                         Set to True to let convolutional layers learn a bias term.
        normalization  : bool
                         If True, adds a Batch Normalization layer after the convolutional layer.
        activation     : torch.nn
                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
        """
        super().__init__()
        self.depth = depth
        self.inc = convolution_layer(
                                     input_channels = input_channels,
                                     output_channels = dimensions,
                                     kernel_size = kernel_size,
                                     bias = bias,
                                     normalization = normalization,
                                     activation = activation
                                    )
        self.encoder = torch.nn.ModuleList()
        for i in range(depth + 1):  # downsampling layers
            if i == 0:
                in_channels = dimensions * (2 ** i)
                out_channels = dimensions * (2 ** i)
            elif i == depth:
                in_channels = dimensions * (2 ** (i - 1))
                out_channels = dimensions * (2 ** (i - 1))
            else:
                in_channels = dimensions * (2 ** (i - 1))
                out_channels = 2 * in_channels
            pooling_layer = torch.nn.AvgPool2d(2)
            double_convolution_layer = double_convolution(
                                                          input_channels = in_channels,
                                                          mid_channels = in_channels,
                                                          output_channels = out_channels,
                                                          kernel_size = kernel_size,
                                                          bias = bias,
                                                          normalization = normalization,
                                                          activation = activation
                                                         )
            self.encoder.append(pooling_layer)
            self.encoder.append(double_convolution_layer)
        self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
        for i in range(depth, -1, -1):
            if i == 1:
                svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
            else:
                svf_in_channels = 2 ** (self.depth + i) + 1
            svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
            svf_mid_channels = dimensions * (2 ** (self.depth - 1))
            spatially_varying_kernel_generation = torch.nn.ModuleList()
            for j in range(i, -1, -1):
                pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
                spatially_varying_kernel_generation.append(pooling_layer)
            kernel_generation_block = torch.nn.Sequential(
                torch.nn.Conv2d(
                                in_channels = svf_in_channels,
                                out_channels = svf_mid_channels,
                                kernel_size = kernel_size,
                                padding = kernel_size // 2,
                                bias = bias
                               ),
                activation,
                torch.nn.Conv2d(
                                in_channels = svf_mid_channels,
                                out_channels = svf_mid_channels,
                                kernel_size = kernel_size,
                                padding = kernel_size // 2,
                                bias = bias
                               ),
                activation,
                torch.nn.Conv2d(
                                in_channels = svf_mid_channels,
                                out_channels = svf_out_channels,
                                kernel_size = kernel_size,
                                padding = kernel_size // 2,
                                bias = bias
                               ),
            )
            spatially_varying_kernel_generation.append(kernel_generation_block)
            self.spatially_varying_feature.append(spatially_varying_kernel_generation)
        self.decoder = torch.nn.ModuleList()
        global_feature_layer = global_feature_module(  # global feature layer
                                                     input_channels = dimensions * (2 ** (depth - 1)),
                                                     mid_channels = dimensions * (2 ** (depth - 1)),
                                                     output_channels = dimensions * (2 ** (depth - 1)),
                                                     kernel_size = kernel_size,
                                                     bias = bias,
                                                     activation = torch.nn.LeakyReLU(0.2, inplace = True)
                                                    )
        self.decoder.append(global_feature_layer)
        for i in range(depth, 0, -1):
            if i == 2:
                up_in_channels = (dimensions // 2) * (2 ** i)
                up_out_channels = up_in_channels
                up_mid_channels = up_in_channels
            elif i == 1:
                up_in_channels = dimensions * 2
                up_out_channels = dimensions
                up_mid_channels = up_out_channels
            else:
                up_in_channels = (dimensions // 2) * (2 ** i)
                up_out_channels = up_in_channels // 2
                up_mid_channels = up_in_channels
            upsample_layer = upsample_convtranspose2d_layer(
                                                            input_channels = up_in_channels,
                                                            output_channels = up_mid_channels,
                                                            kernel_size = 2,
                                                            stride = 2,
                                                            bias = bias,
                                                           )
            conv_layer = double_convolution(
                                            input_channels = up_mid_channels,
                                            output_channels = up_out_channels,
                                            kernel_size = kernel_size,
                                            bias = bias,
                                            normalization = normalization,
                                            activation = activation,
                                           )
            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))


    def forward(self, focal_surface, field):
        """
        Forward model.

        Parameters
        ----------
        focal_surface : torch.tensor
                        Input focal surface data.
                        Dimension: (1, 1, H, W)

        field         : torch.tensor
                        Input field data.
                        Dimension: (1, 6, H, W)

        Returns
        -------
        sv_kernel : list of torch.tensor
                    Learned spatially varying kernels.
                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
                    where C_i, H_i, and W_i represent the channel, height, and width
                    of each feature at a certain scale.
        """
        x = self.inc(torch.cat((focal_surface, field), dim = 1))
        downsampling_outputs = [focal_surface]
        downsampling_outputs.append(x)
        for i, down_layer in enumerate(self.encoder):
            x_down = down_layer(downsampling_outputs[-1])
            downsampling_outputs.append(x_down)
        sv_kernels = []
        for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
            if i == 0:
                global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
                downsampling_outputs[-1] = global_feature
                sv_feature = [global_feature, downsampling_outputs[0]]
                for j in range(self.depth - i + 1):
                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
                    if j > 0:
                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
                              sv_feature[3]]
                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
                sv_kernels.append(sv_kernel)
            else:
                x_up = up_layer[0](downsampling_outputs[-1],
                                   downsampling_outputs[2 * (self.depth + 1 - i) + 1])
                x_up = up_layer[1](x_up)
                downsampling_outputs[-1] = x_up
                sv_feature = [x_up, downsampling_outputs[0]]
                for j in range(self.depth - i + 1):
                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
                    if j > 0:
                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
                if i == 1:
                    sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
                sv_kernels.append(sv_kernel)
        return sv_kernels

__init__(depth=3, dimensions=8, input_channels=7, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

U-Net model.

Parameters:

  • depth
             Number of upsampling and downsampling layers.
    
  • dimensions
             Number of dimensions.
    
  • input_channels (int, default: 7 ) –
             Number of input channels.
    
  • kernel_size
             Kernel size of the convolutional layers.
    
  • bias
             Set to True to let convolutional layers learn a bias term.
    
  • normalization
             If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    
Source code in odak/learn/models/models.py
def __init__(
             self,
             depth = 3,
             dimensions = 8,
             input_channels = 7,
             kernel_size = 3,
             bias = True,
             normalization = False,
             activation = torch.nn.LeakyReLU(0.2, inplace = True)
            ):
    """
    U-Net model.

    Parameters
    ----------
    depth          : int
                     Number of upsampling and downsampling layers.
    dimensions     : int
                     Number of dimensions.
    input_channels : int
                     Number of input channels.
    bias           : bool
                     Set to True to let convolutional layers learn a bias term.
    normalization  : bool
                     If True, adds a Batch Normalization layer after the convolutional layer.
    activation     : torch.nn
                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    """
    super().__init__()
    self.depth = depth
    self.inc = convolution_layer(
                                 input_channels = input_channels,
                                 output_channels = dimensions,
                                 kernel_size = kernel_size,
                                 bias = bias,
                                 normalization = normalization,
                                 activation = activation
                                )
    self.encoder = torch.nn.ModuleList()
    for i in range(depth + 1):  # downsampling layers
        if i == 0:
            in_channels = dimensions * (2 ** i)
            out_channels = dimensions * (2 ** i)
        elif i == depth:
            in_channels = dimensions * (2 ** (i - 1))
            out_channels = dimensions * (2 ** (i - 1))
        else:
            in_channels = dimensions * (2 ** (i - 1))
            out_channels = 2 * in_channels
        pooling_layer = torch.nn.AvgPool2d(2)
        double_convolution_layer = double_convolution(
                                                      input_channels = in_channels,
                                                      mid_channels = in_channels,
                                                      output_channels = out_channels,
                                                      kernel_size = kernel_size,
                                                      bias = bias,
                                                      normalization = normalization,
                                                      activation = activation
                                                     )
        self.encoder.append(pooling_layer)
        self.encoder.append(double_convolution_layer)
    self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
    for i in range(depth, -1, -1):
        if i == 1:
            svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
        else:
            svf_in_channels = 2 ** (self.depth + i) + 1
        svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
        svf_mid_channels = dimensions * (2 ** (self.depth - 1))
        spatially_varying_kernel_generation = torch.nn.ModuleList()
        for j in range(i, -1, -1):
            pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
            spatially_varying_kernel_generation.append(pooling_layer)
        kernel_generation_block = torch.nn.Sequential(
            torch.nn.Conv2d(
                            in_channels = svf_in_channels,
                            out_channels = svf_mid_channels,
                            kernel_size = kernel_size,
                            padding = kernel_size // 2,
                            bias = bias
                           ),
            activation,
            torch.nn.Conv2d(
                            in_channels = svf_mid_channels,
                            out_channels = svf_mid_channels,
                            kernel_size = kernel_size,
                            padding = kernel_size // 2,
                            bias = bias
                           ),
            activation,
            torch.nn.Conv2d(
                            in_channels = svf_mid_channels,
                            out_channels = svf_out_channels,
                            kernel_size = kernel_size,
                            padding = kernel_size // 2,
                            bias = bias
                           ),
        )
        spatially_varying_kernel_generation.append(kernel_generation_block)
        self.spatially_varying_feature.append(spatially_varying_kernel_generation)
    self.decoder = torch.nn.ModuleList()
    global_feature_layer = global_feature_module(  # global feature layer
                                                 input_channels = dimensions * (2 ** (depth - 1)),
                                                 mid_channels = dimensions * (2 ** (depth - 1)),
                                                 output_channels = dimensions * (2 ** (depth - 1)),
                                                 kernel_size = kernel_size,
                                                 bias = bias,
                                                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
                                                )
    self.decoder.append(global_feature_layer)
    for i in range(depth, 0, -1):
        if i == 2:
            up_in_channels = (dimensions // 2) * (2 ** i)
            up_out_channels = up_in_channels
            up_mid_channels = up_in_channels
        elif i == 1:
            up_in_channels = dimensions * 2
            up_out_channels = dimensions
            up_mid_channels = up_out_channels
        else:
            up_in_channels = (dimensions // 2) * (2 ** i)
            up_out_channels = up_in_channels // 2
            up_mid_channels = up_in_channels
        upsample_layer = upsample_convtranspose2d_layer(
                                                        input_channels = up_in_channels,
                                                        output_channels = up_mid_channels,
                                                        kernel_size = 2,
                                                        stride = 2,
                                                        bias = bias,
                                                       )
        conv_layer = double_convolution(
                                        input_channels = up_mid_channels,
                                        output_channels = up_out_channels,
                                        kernel_size = kernel_size,
                                        bias = bias,
                                        normalization = normalization,
                                        activation = activation,
                                       )
        self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))

forward(focal_surface, field)

Forward model.

Parameters:

  • focal_surface (tensor) –
            Input focal surface data.
            Dimension: (1, 1, H, W)
    
  • field
            Input field data.
            Dimension: (1, 6, H, W)
    

Returns:

  • sv_kernel ( list of torch.tensor ) –

    Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.

Source code in odak/learn/models/models.py
def forward(self, focal_surface, field):
    """
    Forward model.

    Parameters
    ----------
    focal_surface : torch.tensor
                    Input focal surface data.
                    Dimension: (1, 1, H, W)

    field         : torch.tensor
                    Input field data.
                    Dimension: (1, 6, H, W)

    Returns
    -------
    sv_kernel : list of torch.tensor
                Learned spatially varying kernels.
                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
                where C_i, H_i, and W_i represent the channel, height, and width
                of each feature at a certain scale.
    """
    x = self.inc(torch.cat((focal_surface, field), dim = 1))
    downsampling_outputs = [focal_surface]
    downsampling_outputs.append(x)
    for i, down_layer in enumerate(self.encoder):
        x_down = down_layer(downsampling_outputs[-1])
        downsampling_outputs.append(x_down)
    sv_kernels = []
    for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
        if i == 0:
            global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
            downsampling_outputs[-1] = global_feature
            sv_feature = [global_feature, downsampling_outputs[0]]
            for j in range(self.depth - i + 1):
                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
                if j > 0:
                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
            sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
                          sv_feature[3]]
            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
            sv_kernels.append(sv_kernel)
        else:
            x_up = up_layer[0](downsampling_outputs[-1],
                               downsampling_outputs[2 * (self.depth + 1 - i) + 1])
            x_up = up_layer[1](x_up)
            downsampling_outputs[-1] = x_up
            sv_feature = [x_up, downsampling_outputs[0]]
            for j in range(self.depth - i + 1):
                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
                if j > 0:
                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
            if i == 1:
                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
            sv_kernels.append(sv_kernel)
    return sv_kernels
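
As a hedged end-to-end sketch, the kernel generation model above can be paired with spatially_adaptive_unet from the previous section, following the focal surface light transport pipeline of the referenced paper. The import paths and the random inputs are assumptions; a real field would carry the six-channel representation expected by both models.

import torch
from odak.learn.models import spatially_varying_kernel_generation_model, spatially_adaptive_unet  # assumed import path

kernel_generator = spatially_varying_kernel_generation_model(depth = 3, dimensions = 8, input_channels = 7)
light_transport = spatially_adaptive_unet(depth = 3, dimensions = 8, input_channels = 6, out_channels = 6)

focal_surface = torch.randn(1, 1, 256, 256)          # focal surface data, (1, 1, H, W)
field = torch.randn(1, 6, 256, 256)                  # input field data, (1, 6, H, W)

sv_kernels = kernel_generator(focal_surface, field)  # list of spatially varying kernels, coarse to fine
target_field = light_transport(sv_kernels, field)    # estimated output field, (1, 6, 256, 256)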

unet

Bases: Module

A U-Net model, heavily inspired by https://github.com/milesial/Pytorch-UNet/tree/master/unet; for more, see Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.

Source code in odak/learn/models/models.py
class unet(torch.nn.Module):
    """
    A U-Net model, heavily inspired from `https://github.com/milesial/Pytorch-UNet/tree/master/unet` and more can be read from Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.
    """

    def __init__(
                 self, 
                 depth = 4,
                 dimensions = 64, 
                 input_channels = 2, 
                 output_channels = 1, 
                 bilinear = False,
                 kernel_size = 3,
                 bias = False,
                 activation = torch.nn.ReLU(inplace = True),
                ):
        """
        U-Net model.

        Parameters
        ----------
        depth             : int
                            Number of upsampling and downsampling
        dimensions        : int
                            Number of dimensions.
        input_channels    : int
                            Number of input channels.
        output_channels   : int
                            Number of output channels.
        bilinear          : bool
                            Uses bilinear upsampling in upsampling layers when set True.
        bias              : bool
                            Set True to let convolutional layers learn a bias term.
        activation        : torch.nn
                            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().
        """
        super(unet, self).__init__()
        self.inc = double_convolution(
                                      input_channels = input_channels,
                                      mid_channels = dimensions,
                                      output_channels = dimensions,
                                      kernel_size = kernel_size,
                                      bias = bias,
                                      activation = activation
                                     )      

        self.downsampling_layers = torch.nn.ModuleList()
        self.upsampling_layers = torch.nn.ModuleList()
        for i in range(depth): # downsampling layers
            in_channels = dimensions * (2 ** i)
            out_channels = dimensions * (2 ** (i + 1))
            down_layer = downsample_layer(in_channels,
                                            out_channels,
                                            kernel_size=kernel_size,
                                            bias=bias,
                                            activation=activation
                                            )
            self.downsampling_layers.append(down_layer)      

        for i in range(depth - 1, -1, -1):  # upsampling layers
            up_in_channels = dimensions * (2 ** (i + 1))  
            up_out_channels = dimensions * (2 ** i) 
            up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
            self.upsampling_layers.append(up_layer)
        self.outc = torch.nn.Conv2d(
                                    dimensions, 
                                    output_channels,
                                    kernel_size = kernel_size,
                                    padding = kernel_size // 2,
                                    bias = bias
                                   )


    def forward(self, x):
        """
        Forward model.

        Parameters
        ----------
        x             : torch.tensor
                        Input data.


        Returns
        ----------
        result        : torch.tensor
                        Estimated output.      
        """
        downsampling_outputs = [self.inc(x)]
        for down_layer in self.downsampling_layers:
            x_down = down_layer(downsampling_outputs[-1])
            downsampling_outputs.append(x_down)
        x_up = downsampling_outputs[-1]
        for i, up_layer in enumerate((self.upsampling_layers)):
            x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
        result = self.outc(x_up)
        return result

__init__(depth=4, dimensions=64, input_channels=2, output_channels=1, bilinear=False, kernel_size=3, bias=False, activation=torch.nn.ReLU(inplace=True))

U-Net model.

Parameters:

  • depth
                Number of upsampling and downsampling layers.
    
  • dimensions
                Number of dimensions.
    
  • input_channels
                Number of input channels.
    
  • output_channels
                Number of output channels.
    
  • bilinear
                Uses bilinear upsampling in upsampling layers when set True.
    
  • kernel_size
                Kernel size of the convolutional layers.
    
  • bias
                Set True to let convolutional layers learn a bias term.
    
  • activation
                Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    
Source code in odak/learn/models/models.py
def __init__(
             self, 
             depth = 4,
             dimensions = 64, 
             input_channels = 2, 
             output_channels = 1, 
             bilinear = False,
             kernel_size = 3,
             bias = False,
             activation = torch.nn.ReLU(inplace = True),
            ):
    """
    U-Net model.

    Parameters
    ----------
    depth             : int
                        Number of upsampling and downsampling
    dimensions        : int
                        Number of dimensions.
    input_channels    : int
                        Number of input channels.
    output_channels   : int
                        Number of output channels.
    bilinear          : bool
                        Uses bilinear upsampling in upsampling layers when set True.
    bias              : bool
                        Set True to let convolutional layers learn a bias term.
    activation        : torch.nn
                        Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().
    """
    super(unet, self).__init__()
    self.inc = double_convolution(
                                  input_channels = input_channels,
                                  mid_channels = dimensions,
                                  output_channels = dimensions,
                                  kernel_size = kernel_size,
                                  bias = bias,
                                  activation = activation
                                 )      

    self.downsampling_layers = torch.nn.ModuleList()
    self.upsampling_layers = torch.nn.ModuleList()
    for i in range(depth): # downsampling layers
        in_channels = dimensions * (2 ** i)
        out_channels = dimensions * (2 ** (i + 1))
        down_layer = downsample_layer(in_channels,
                                        out_channels,
                                        kernel_size=kernel_size,
                                        bias=bias,
                                        activation=activation
                                        )
        self.downsampling_layers.append(down_layer)      

    for i in range(depth - 1, -1, -1):  # upsampling layers
        up_in_channels = dimensions * (2 ** (i + 1))  
        up_out_channels = dimensions * (2 ** i) 
        up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
        self.upsampling_layers.append(up_layer)
    self.outc = torch.nn.Conv2d(
                                dimensions, 
                                output_channels,
                                kernel_size = kernel_size,
                                padding = kernel_size // 2,
                                bias = bias
                               )

forward(x)

Forward model.

Parameters:

  • x
            Input data.
    

Returns:

  • result ( tensor ) –

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):
    """
    Forward model.

    Parameters
    ----------
    x             : torch.tensor
                    Input data.


    Returns
    ----------
    result        : torch.tensor
                    Estimated output.      
    """
    downsampling_outputs = [self.inc(x)]
    for down_layer in self.downsampling_layers:
        x_down = down_layer(downsampling_outputs[-1])
        downsampling_outputs.append(x_down)
    x_up = downsampling_outputs[-1]
    for i, up_layer in enumerate((self.upsampling_layers)):
        x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
    result = self.outc(x_up)
    return result
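
A minimal usage sketch with the default configuration follows; the import path is an assumption, and the input resolution should be divisible by 2 ** depth so that the skip connections line up.

import torch
from odak.learn.models import unet  # assumed import path

model = unet(depth = 4, dimensions = 64, input_channels = 2, output_channels = 1)
x = torch.randn(1, 2, 128, 128)  # 128 is divisible by 2 ** 4
y = model(x)                     # estimated output, shape (1, 1, 128, 128)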

upsample_convtranspose2d_layer

Bases: Module

An upsampling convtranspose2d layer.

Source code in odak/learn/models/components.py
class upsample_convtranspose2d_layer(torch.nn.Module):
    """
    An upsampling convtranspose2d layer.
    """
    def __init__(
                 self,
                 input_channels,
                 output_channels,
                 kernel_size = 2,
                 stride = 2,
                 bias = False,
                ):
        """
        A downscaling component with a double convolution.

        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Kernel size.
        bias            : bool
                          Set to True to let convolutional layers have bias term.
        """
        super().__init__()
        self.up = torch.nn.ConvTranspose2d(
                                           in_channels = input_channels,
                                           out_channels = output_channels,
                                           bias = bias,
                                           kernel_size = kernel_size,
                                           stride = stride
                                          )

    def forward(self, x1, x2):
        """
        Forward model.

        Parameters
        ----------
        x1             : torch.tensor
                         First input data.
        x2             : torch.tensor
                         Second input data.


        Returns
        ----------
        result        : torch.tensor
                        Result of the forward operation
        """
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                          diffY // 2, diffY - diffY // 2])
        result = x1 + x2
        return result

__init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False)

An upsampling component using a transposed convolution.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int) –
              Number of output channels.
    
  • kernel_size
              Kernel size.
    
  • stride
              Stride of the transposed convolution.
    
  • bias
              Set to True to let convolutional layers have bias term.
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels,
             output_channels,
             kernel_size = 2,
             stride = 2,
             bias = False,
            ):
    """
    A downscaling component with a double convolution.

    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Kernel size.
    bias            : bool
                      Set to True to let convolutional layers have bias term.
    """
    super().__init__()
    self.up = torch.nn.ConvTranspose2d(
                                       in_channels = input_channels,
                                       out_channels = output_channels,
                                       bias = bias,
                                       kernel_size = kernel_size,
                                       stride = stride
                                      )

forward(x1, x2)

Forward model.

Parameters:

  • x1
             First input data.
    
  • x2
             Second input data.
    

Returns:

  • result ( tensor ) –

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
    """
    Forward model.

    Parameters
    ----------
    x1             : torch.tensor
                     First input data.
    x2             : torch.tensor
                     Second input data.


    Returns
    ----------
    result        : torch.tensor
                    Result of the forward operation
    """
    x1 = self.up(x1)
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]
    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                      diffY // 2, diffY - diffY // 2])
    result = x1 + x2
    return result
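
A short sketch of how this layer upsamples a coarse feature map and fuses it with a skip connection by addition; the import path and tensor sizes are assumptions.

import torch
from odak.learn.models.components import upsample_convtranspose2d_layer  # assumed import path

up = upsample_convtranspose2d_layer(input_channels = 64, output_channels = 32)
x1 = torch.randn(1, 64, 16, 16)  # coarse feature to be upsampled
x2 = torch.randn(1, 32, 32, 32)  # skip connection at the target resolution
y = up(x1, x2)                   # x1 is upsampled, padded if needed, then added to x2; shape (1, 32, 32, 32)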

upsample_layer

Bases: Module

An upsampling convolutional layer.

Source code in odak/learn/models/components.py
class upsample_layer(torch.nn.Module):
    """
    An upsampling convolutional layer.
    """
    def __init__(
                 self,
                 input_channels,
                 output_channels,
                 kernel_size = 3,
                 bias = False,
                 normalization = False,
                 activation = torch.nn.ReLU(),
                 bilinear = True
                ):
        """
        A downscaling component with a double convolution.

        Parameters
        ----------
        input_channels  : int
                          Number of input channels.
        output_channels : int
                          Number of output channels.
        kernel_size     : int
                          Kernel size.
        bias            : bool 
                          Set to True to let convolutional layers have bias term.
        normalization   : bool                
                          If True, adds a Batch Normalization layer after the convolutional layer.
        activation      : torch.nn
                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
        bilinear        : bool
                          If set to True, bilinear sampling is used.
        """
        super(upsample_layer, self).__init__()
        if bilinear:
            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
            self.conv = double_convolution(
                                           input_channels = input_channels + output_channels,
                                           mid_channels = input_channels // 2,
                                           output_channels = output_channels,
                                           kernel_size = kernel_size,
                                           normalization = normalization,
                                           bias = bias,
                                           activation = activation
                                          )
        else:
            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
            self.conv = double_convolution(
                                           input_channels = input_channels,
                                           mid_channels = output_channels,
                                           output_channels = output_channels,
                                           kernel_size = kernel_size,
                                           normalization = normalization,
                                           bias = bias,
                                           activation = activation
                                          )


    def forward(self, x1, x2):
        """
        Forward model.

        Parameters
        ----------
        x1             : torch.tensor
                         First input data.
        x2             : torch.tensor
                         Second input data.


        Returns
        ----------
        result        : torch.tensor
                        Result of the forward operation
        """ 
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                          diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim = 1)
        result = self.conv(x)
        return result

__init__(input_channels, output_channels, kernel_size=3, bias=False, normalization=False, activation=torch.nn.ReLU(), bilinear=True)

An upsampling component with a double convolution.

Parameters:

  • input_channels
              Number of input channels.
    
  • output_channels (int) –
              Number of output channels.
    
  • kernel_size
              Kernel size.
    
  • bias
              Set to True to let convolutional layers have bias term.
    
  • normalization
              If True, adds a Batch Normalization layer after the convolutional layer.
    
  • activation
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    
  • bilinear
              If set to True, bilinear sampling is used.
    
Source code in odak/learn/models/components.py
def __init__(
             self,
             input_channels,
             output_channels,
             kernel_size = 3,
             bias = False,
             normalization = False,
             activation = torch.nn.ReLU(),
             bilinear = True
            ):
    """
    An upsampling component with a double convolution.

    Parameters
    ----------
    input_channels  : int
                      Number of input channels.
    output_channels : int
                      Number of output channels.
    kernel_size     : int
                      Kernel size.
    bias            : bool 
                      Set to True to let convolutional layers have bias term.
    normalization   : bool                
                      If True, adds a Batch Normalization layer after the convolutional layer.
    activation      : torch.nn
                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    bilinear        : bool
                      If set to True, bilinear sampling is used.
    """
    super(upsample_layer, self).__init__()
    if bilinear:
        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
        self.conv = double_convolution(
                                       input_channels = input_channels + output_channels,
                                       mid_channels = input_channels // 2,
                                       output_channels = output_channels,
                                       kernel_size = kernel_size,
                                       normalization = normalization,
                                       bias = bias,
                                       activation = activation
                                      )
    else:
        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
        self.conv = double_convolution(
                                       input_channels = input_channels,
                                       mid_channels = output_channels,
                                       output_channels = output_channels,
                                       kernel_size = kernel_size,
                                       normalization = normalization,
                                       bias = bias,
                                       activation = activation
                                      )

forward(x1, x2)

Forward model.

Parameters:

  • x1
             First input data.
    
  • x2
             Second input data.
    

Returns:

  • result ( tensor ) –

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
    """
    Forward model.

    Parameters
    ----------
    x1             : torch.tensor
                     First input data.
    x2             : torch.tensor
                     Second input data.


    Returns
    ----------
    result        : torch.tensor
                    Result of the forward operation
    """ 
    x1 = self.up(x1)
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]
    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                      diffY // 2, diffY - diffY // 2])
    x = torch.cat([x2, x1], dim = 1)
    result = self.conv(x)
    return result
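
For orientation, here is a minimal usage sketch of upsample_layer. This is an illustration rather than part of the odak documentation: the channel counts and tensor shapes are made up, and the import path simply follows the source file listed above.

import torch
from odak.learn.models.components import upsample_layer

# Hypothetical decoder step: x1 comes from a deeper stage (more channels,
# half the spatial resolution), x2 is the matching skip connection.
layer = upsample_layer(input_channels = 64, output_channels = 32)
x1 = torch.randn(1, 64, 32, 32)  # low-resolution feature map
x2 = torch.randn(1, 32, 64, 64)  # skip connection at the target resolution
y = layer(x1, x2)
print(y.shape)  # expected: torch.Size([1, 32, 64, 64])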

gaussian(x, multiplier=1.0)

A Gaussian non-linear activation. For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

Parameters:

  • x
           Input data.
    
  • multiplier
           Multiplier.
    

Returns:

  • result ( float or tensor ) –

    Output data.

Source code in odak/learn/models/components.py
def gaussian(x, multiplier = 1.):
    """
    A Gaussian non-linear activation.
    For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

    Parameters
    ----------
    x            : float or torch.tensor
                   Input data.
    multiplier   : float or torch.tensor
                   Multiplier.

    Returns
    -------
    result       : float or torch.tensor
                   Output data.
    """
    result = torch.exp(- (multiplier * x) ** 2)
    return result
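
As a quick illustration (a standalone sketch with made-up values), gaussian maps zero to one and decays symmetrically, since it computes exp(-(multiplier * x) ** 2):

import torch
from odak.learn.models.components import gaussian

x = torch.tensor([-2., -1., 0., 1., 2.])
y = gaussian(x, multiplier = 1.)
print(y)  # roughly tensor([0.0183, 0.3679, 1.0000, 0.3679, 0.0183])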

swish(x)

A swish non-linear activation. For more details: https://en.wikipedia.org/wiki/Swish_function

Parameters:

  • x
             Input.
    

Returns:

  • out ( float or tensor ) –

    Output.

Source code in odak/learn/models/components.py
def swish(x):
    """
    A swish non-linear activation.
    For more details: https://en.wikipedia.org/wiki/Swish_function

    Parameters
    -----------
    x              : float or torch.tensor
                     Input.

    Returns
    -------
    out            : float or torch.tensor
                     Output.
    """
    out = x * torch.sigmoid(x)
    return out
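
Similarly, a standalone sketch for swish, which computes x * torch.sigmoid(x) (the values below are illustrative):

import torch
from odak.learn.models.components import swish

x = torch.tensor([-2., 0., 2.])
y = swish(x)
print(y)  # roughly tensor([-0.2384, 0.0000, 1.7616])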

multi_color_hologram_optimizer

A class for optimizing single or multi-color holograms. For more details, see Kavaklı et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in Holographic Displays.

Source code in odak/learn/wave/optimizers.py
class multi_color_hologram_optimizer():
    """
    A class for optimizing single or multi-color holograms.
    For more details, see Kavaklı et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in Holographic Displays.
    """
    def __init__(self,
                 wavelengths,
                 resolution,
                 targets,
                 propagator,
                 number_of_frames = 3,
                 number_of_depth_layers = 1,
                 learning_rate = 2e-2,
                 learning_rate_floor = 5e-3,
                 double_phase = True,
                 scale_factor = 1,
                 method = 'multi-color',
                 channel_power_filename = '',
                 device = None,
                 loss_function = None,
                 peak_amplitude = 1.0,
                 optimize_peak_amplitude = False,
                 img_loss_thres = 2e-3,
                 reduction = 'sum'
                ):
        self.device = device
        if isinstance(self.device, type(None)):
            self.device = torch.device("cpu")
        torch.cuda.empty_cache()
        torch.random.seed()
        self.wavelengths = wavelengths
        self.resolution = resolution
        self.targets = targets
        if propagator.propagation_type != 'Impulse Response Fresnel':
            scale_factor = 1
        self.scale_factor = scale_factor
        self.propagator = propagator
        self.learning_rate = learning_rate
        self.learning_rate_floor = learning_rate_floor
        self.number_of_channels = len(self.wavelengths)
        self.number_of_frames = number_of_frames
        self.number_of_depth_layers = number_of_depth_layers
        self.double_phase = double_phase
        self.channel_power_filename = channel_power_filename
        self.method = method
        if self.method != 'conventional' and self.method != 'multi-color':
           logging.warning('Unknown optimization method. Options are conventional or multi-color.')
           import sys
           sys.exit()
        self.peak_amplitude = peak_amplitude
        self.optimize_peak_amplitude = optimize_peak_amplitude
        if self.optimize_peak_amplitude:
            self.init_peak_amplitude_scale()
        self.img_loss_thres = img_loss_thres
        self.kernels = []
        self.init_phase()
        self.init_channel_power()
        self.init_loss_function(loss_function, reduction = reduction)
        self.init_amplitude()
        self.init_phase_scale()


    def init_peak_amplitude_scale(self):
        """
        Internal function to set the phase scale.
        """
        self.peak_amplitude = torch.tensor(
                                           self.peak_amplitude,
                                           requires_grad = True,
                                           device=self.device
                                          )


    def init_phase_scale(self):
        """
        Internal function to set the phase scale.
        """
        if self.method == 'conventional':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )
        if self.method == 'multi-color':
            self.phase_scale = torch.tensor(
                                            [
                                             1.,
                                             1.,
                                             1.
                                            ],
                                            requires_grad = False,
                                            device = self.device
                                           )


    def init_amplitude(self):
        """
        Internal function to set the amplitude of the illumination source.
        """
        self.amplitude = torch.zeros(
                                     self.resolution[0] * self.scale_factor,
                                     self.resolution[1] * self.scale_factor,
                                     requires_grad = False,
                                     device = self.device
                                    )
        self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.


    def init_phase(self):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        self.phase = torch.zeros(
                                 self.number_of_frames,
                                 self.resolution[0],
                                 self.resolution[1],
                                 device = self.device,
                                 requires_grad = True
                                )
        self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)


    def init_channel_power(self):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        if self.method == 'conventional':
            logging.warning('Scheme: Conventional')
            self.channel_power = torch.eye(
                                           self.number_of_frames,
                                           self.number_of_channels,
                                           device = self.device,
                                           requires_grad = False
                                          )

        elif self.method == 'multi-color':
            logging.warning('Scheme: Multi-color')
            self.channel_power = torch.ones(
                                            self.number_of_frames,
                                            self.number_of_channels,
                                            device = self.device,
                                            requires_grad = True
                                           )
        if self.channel_power_filename != '':
            self.channel_power = torch_load(self.channel_power_filename).to(self.device)
            self.channel_power.requires_grad = False
            self.channel_power[self.channel_power < 0.] = 0.
            self.channel_power[self.channel_power > 1.] = 1.
            if self.method == 'multi-color':
                self.channel_power.requires_grad = True
            if self.method == 'conventional':
                self.channel_power = torch.abs(torch.cos(self.channel_power))
            logging.warning('Channel powers:')
            logging.warning(self.channel_power)
            logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
        self.propagator.set_laser_powers(self.channel_power)



    def init_optimizer(self):
        """
        Internal function to set the optimizer.
        """
        optimization_variables = [self.phase, self.offset]
        if self.optimize_peak_amplitude:
            optimization_variables.append(self.peak_amplitude)
        if self.method == 'multi-color':
            optimization_variables.append(self.propagator.channel_power)
        self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)


    def init_loss_function(self, loss_function, reduction = 'sum'):
        """
        Internal function to set the loss function.
        """
        self.l2_loss = torch.nn.MSELoss(reduction = reduction)
        self.loss_type = 'custom'
        self.loss_function = loss_function
        if isinstance(self.loss_function, type(None)):
            self.loss_type = 'conventional'
            self.loss_function = torch.nn.MSELoss(reduction = reduction)



    def evaluate(self, input_image, target_image, plane_id = 0, noise_ratio = 1e-3, inject_noise = False):
        """
        Internal function to evaluate the loss.
        """
        if self.loss_type == 'conventional':
            loss = self.loss_function(input_image, target_image)
        elif self.loss_type == 'custom':
            loss = 0
            for i in range(len(self.wavelengths)):
                loss += self.loss_function(
                                           input_image[i],
                                           target_image[i],
                                           plane_id = plane_id,
                                           noise_ratio = noise_ratio,
                                           inject_noise = inject_noise
                                          )
        return loss


    def double_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase similarly to double phase encoding.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
        phase_only                 : torch.tensor
                                     Constrained output phase.
        """
        phase_zero_mean = phase - torch.mean(phase)
        phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * torch.pi)
        phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * torch.pi)
        loss = multi_scale_total_variation_loss(phase_low, levels = 6)
        loss += multi_scale_total_variation_loss(phase_high, levels = 6)
        loss += torch.std(phase_low)
        loss += torch.std(phase_high)
        phase_only = torch.zeros_like(phase)
        phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
        phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
        phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
        phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
        return phase_only, loss


    def direct_phase_constrain(self, phase, phase_offset):
        """
        Internal function to constrain a given phase.

        Parameters
        ----------
        phase                      : torch.tensor
                                     Input phase values to be constrained.
        phase_offset               : torch.tensor
                                     Input phase offset value.

        Returns
        -------
        phase_only                 : torch.tensor
                                     Constrained output phase.
        """
        phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * torch.pi)
        loss = multi_scale_total_variation_loss(phase, levels = 6)
        loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
        return phase_only, loss


    def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.], inject_noise = False, noise_ratio  = 1e-3):
        """
        Function to optimize multiplane phase-only holograms using stochastic gradient descent.

        Parameters
        ----------
        number_of_iterations       : float
                                     Number of iterations.
        weights                    : list
                                     Weights used in the loss function.
        inject_noise               : bool
                                     When set True, this will inject noise with the given `noise_ratio` to the target images.
        noise_ratio                : float
                                     Noise ratio, a multiplier (1e-3 is 0.1 percent).

        Returns
        -------
        hologram                   : torch.tensor
                                     Optimised hologram.
        """
        hologram_phases = torch.zeros(
                                      self.number_of_frames,
                                      self.resolution[0],
                                      self.resolution[1],
                                      device = self.device
                                     )
        t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
        if self.optimize_peak_amplitude:
            peak_amp_cache = self.peak_amplitude.item()
        for step in t:
            for g in self.optimizer.param_groups:
                g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
                if g['lr'] < self.learning_rate_floor:
                    g['lr'] = self.learning_rate_floor
                learning_rate = g['lr']
            total_loss = 0
            t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
            for depth_id in t_depth:
                self.optimizer.zero_grad()
                depth_target = self.targets[depth_id]
                reconstruction_intensities = torch.zeros(
                                                         self.number_of_frames,
                                                         self.number_of_channels,
                                                         self.resolution[0] * self.scale_factor,
                                                         self.resolution[1] * self.scale_factor,
                                                         device = self.device
                                                        )
                loss_variation_hologram = 0
                laser_powers = self.propagator.get_laser_powers()
                for frame_id in range(self.number_of_frames):
                    if self.double_phase:
                        phase, loss_phase = self.double_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    else:
                        phase, loss_phase = self.direct_phase_constrain(
                                                                        self.phase[frame_id],
                                                                        self.offset[frame_id]
                                                                       )
                    loss_variation_hologram += loss_phase
                    for channel_id in range(self.number_of_channels):
                        phase_scaled = torch.zeros_like(self.amplitude)
                        phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
                        laser_power = laser_powers[frame_id][channel_id]
                        hologram = generate_complex_field(
                                                          laser_power * self.amplitude,
                                                          phase_scaled * self.phase_scale[channel_id]
                                                         )
                        reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                        intensity = calculate_amplitude(reconstruction_field) ** 2
                        reconstruction_intensities[frame_id, channel_id] += intensity
                    hologram_phases[frame_id] = phase.detach().clone()
                loss_laser = self.l2_loss(
                                          torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
                                          torch.sum(laser_powers, dim = 0)
                                         )
                loss_laser += self.l2_loss(
                                           torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
                                           torch.sum(laser_powers).view(1,)
                                          )
                loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
                reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
                loss_image = self.evaluate(
                                           reconstruction_intensity,
                                           depth_target * self.peak_amplitude,
                                           noise_ratio = noise_ratio,
                                           inject_noise = inject_noise,
                                           plane_id = depth_id
                                          )
                loss = weights[0] * loss_image
                loss += weights[1] * loss_laser
                loss += weights[2] * loss_variation_hologram
                include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
                if include_pa_loss_flag:
                    loss -= self.peak_amplitude * 1.
                if self.method == 'conventional':
                    loss.backward()
                else:
                    loss.backward(retain_graph = True)
                self.optimizer.step()
                if include_pa_loss_flag:
                    peak_amp_cache = self.peak_amplitude.item()
                else:
                    with torch.no_grad():
                        if self.optimize_peak_amplitude:
                            self.peak_amplitude.view([1])[0] = peak_amp_cache
                total_loss += loss.detach().item()
                loss_image = loss_image.detach()
                del loss_laser
                del loss_variation_hologram
                del loss
            description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
            t.set_description(description)
            del total_loss
            del loss_image
            del reconstruction_field
            del reconstruction_intensities
            del intensity
            del phase
            del hologram
        logging.warning(description)
        return hologram_phases.detach()


    def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8, inject_noise = False, noise_ratio = 1e-3):
        """
        Function to optimize multiplane phase-only holograms.

        Parameters
        ----------
        number_of_iterations       : int
                                     Number of iterations.
        weights                    : list
                                     Loss weights.
        bits                       : int
                                     Quantizes the hologram using the given bits and reconstructs.
        inject_noise               : bool
                                     When set True, this will inject noise with the given `noise_ratio` to the target images.
        noise_ratio                : float
                                     Noise ratio, a multiplier (1e-3 is 0.1 percent).


        Returns
        -------
        hologram_phases            : torch.tensor
                                     Phases of the optimized phase-only hologram.
        reconstruction_intensities : torch.tensor
                                     Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
        """
        self.init_optimizer()
        hologram_phases = self.gradient_descent(
                                                number_of_iterations=number_of_iterations,
                                                noise_ratio = noise_ratio,
                                                inject_noise = inject_noise,
                                                weights=weights
                                               )
        hologram_phases = quantize(hologram_phases % (2 * torch.pi), bits = bits, limits = [0., 2 * torch.pi]) / 2 ** bits * 2 * torch.pi
        torch.no_grad()
        reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
        laser_powers = self.propagator.get_laser_powers()
        channel_powers = self.propagator.channel_power
        logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
        logging.warning('Laser powers: {}'.format(laser_powers))
        return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)
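
To show how this class is typically wired together, here is a hedged construction sketch. It is not taken from the odak documentation: the wavelengths, resolution, pixel pitch, and target stack are placeholders, and it assumes that propagator (documented further below on this page) and multi_color_hologram_optimizer are both importable from odak.learn.wave. The target stack is shaped [number of depth layers x number of channels x height x width], which is the layout gradient_descent indexes into. The optimize() call shown later on this page runs the actual optimization.

import torch
from odak.learn.wave import propagator, multi_color_hologram_optimizer

device = torch.device('cpu')
wavelengths = [639e-9, 515e-9, 473e-9]
resolution = [1080, 1920]
# Placeholder targets: one depth layer, one image per color channel.
targets = torch.rand(1, len(wavelengths), resolution[0], resolution[1], device = device)
forward_model = propagator(
                           resolution = resolution,
                           wavelengths = wavelengths,
                           pixel_pitch = 8e-6,
                           number_of_frames = 3,
                           number_of_depth_layers = 1,
                           device = device
                          )
optimizer = multi_color_hologram_optimizer(
                                           wavelengths = wavelengths,
                                           resolution = resolution,
                                           targets = targets,
                                           propagator = forward_model,
                                           number_of_frames = 3,
                                           method = 'multi-color',
                                           device = device
                                          )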

direct_phase_constrain(phase, phase_offset)

Internal function to constrain a given phase.

Parameters:

  • phase
                         Input phase values to be constrained.
    
  • phase_offset
                         Input phase offset value.
    

Returns:

  • phase_only ( tensor ) –

    Constrained output phase.

Source code in odak/learn/wave/optimizers.py
def direct_phase_constrain(self, phase, phase_offset):
    """
    Internal function to constrain a given phase.

    Parameters
    ----------
    phase                      : torch.tensor
                                 Input phase values to be constrained.
    phase_offset               : torch.tensor
                                 Input phase offset value.

    Returns
    -------
    phase_only                 : torch.tensor
                                 Constrained output phase.
    """
    phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * torch.pi)
    loss = multi_scale_total_variation_loss(phase, levels = 6)
    loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
    return phase_only, loss

double_phase_constrain(phase, phase_offset)

Internal function to constrain a given phase similarly to double phase encoding.

Parameters:

  • phase
                         Input phase values to be constrained.
    
  • phase_offset
                         Input phase offset value.
    

Returns:

  • phase_only ( tensor ) –

    Constrained output phase.

Source code in odak/learn/wave/optimizers.py
def double_phase_constrain(self, phase, phase_offset):
    """
    Internal function to constrain a given phase similarly to double phase encoding.

    Parameters
    ----------
    phase                      : torch.tensor
                                 Input phase values to be constrained.
    phase_offset               : torch.tensor
                                 Input phase offset value.

    Returns
    -------
    phase_only                 : torch.tensor
                                 Constrained output phase.
    """
    phase_zero_mean = phase - torch.mean(phase)
    phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * torch.pi)
    phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * torch.pi)
    loss = multi_scale_total_variation_loss(phase_low, levels = 6)
    loss += multi_scale_total_variation_loss(phase_high, levels = 6)
    loss += torch.std(phase_low)
    loss += torch.std(phase_high)
    phase_only = torch.zeros_like(phase)
    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
    return phase_only, loss
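
To make the checkerboard interleaving concrete, the following standalone sketch (plain PyTorch, not odak code) reproduces the same indexing pattern on a small tensor: alternating pixel positions carry phase - phase_offset and phase + phase_offset.

import torch

phase = torch.zeros(4, 4)          # stand-in for a zero-mean phase
offset = torch.full((4, 4), 0.5)   # stand-in for the learned offset
low, high = phase - offset, phase + offset
out = torch.zeros_like(phase)
out[0::2, 0::2] = low[0::2, 0::2]   # even rows, even columns
out[0::2, 1::2] = high[0::2, 1::2]  # even rows, odd columns
out[1::2, 0::2] = high[1::2, 0::2]  # odd rows, even columns
out[1::2, 1::2] = low[1::2, 1::2]   # odd rows, odd columns
print(out)  # alternating -0.5 / +0.5 checkerboard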

evaluate(input_image, target_image, plane_id=0, noise_ratio=0.001, inject_noise=False)

Internal function to evaluate the loss.

Source code in odak/learn/wave/optimizers.py
def evaluate(self, input_image, target_image, plane_id = 0, noise_ratio = 1e-3, inject_noise = False):
    """
    Internal function to evaluate the loss.
    """
    if self.loss_type == 'conventional':
        loss = self.loss_function(input_image, target_image)
    elif self.loss_type == 'custom':
        loss = 0
        for i in range(len(self.wavelengths)):
            loss += self.loss_function(
                                       input_image[i],
                                       target_image[i],
                                       plane_id = plane_id,
                                       noise_ratio = noise_ratio,
                                       inject_noise = inject_noise
                                      )
    return loss

gradient_descent(number_of_iterations=100, weights=[1.0, 1.0, 0.0, 0.0], inject_noise=False, noise_ratio=0.001)

Function to optimize multiplane phase-only holograms using stochastic gradient descent.

Parameters:

  • number_of_iterations
                         Number of iterations.
    
  • weights
                         Weights used in the loss function.
    
  • inject_noise
                         When set True, this will inject noise with the given `noise_ratio` to the target images.
    
  • noise_ratio
                         Noise ratio, a multiplier (1e-3 is 0.1 percent).
    

Returns:

  • hologram ( tensor ) –

    Optimised hologram.

Source code in odak/learn/wave/optimizers.py
def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.], inject_noise = False, noise_ratio  = 1e-3):
    """
    Function to optimize multiplane phase-only holograms using stochastic gradient descent.

    Parameters
    ----------
    number_of_iterations       : float
                                 Number of iterations.
    weights                    : list
                                 Weights used in the loss function.
    inject_noise               : bool
                                 When set True, this will inject noise with the given `noise_ratio` to the target images.
    noise_ratio                : float
                                 Noise ratio, a multiplier (1e-3 is 0.1 percent).

    Returns
    -------
    hologram                   : torch.tensor
                                 Optimised hologram.
    """
    hologram_phases = torch.zeros(
                                  self.number_of_frames,
                                  self.resolution[0],
                                  self.resolution[1],
                                  device = self.device
                                 )
    t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
    if self.optimize_peak_amplitude:
        peak_amp_cache = self.peak_amplitude.item()
    for step in t:
        for g in self.optimizer.param_groups:
            g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
            if g['lr'] < self.learning_rate_floor:
                g['lr'] = self.learning_rate_floor
            learning_rate = g['lr']
        total_loss = 0
        t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
        for depth_id in t_depth:
            self.optimizer.zero_grad()
            depth_target = self.targets[depth_id]
            reconstruction_intensities = torch.zeros(
                                                     self.number_of_frames,
                                                     self.number_of_channels,
                                                     self.resolution[0] * self.scale_factor,
                                                     self.resolution[1] * self.scale_factor,
                                                     device = self.device
                                                    )
            loss_variation_hologram = 0
            laser_powers = self.propagator.get_laser_powers()
            for frame_id in range(self.number_of_frames):
                if self.double_phase:
                    phase, loss_phase = self.double_phase_constrain(
                                                                    self.phase[frame_id],
                                                                    self.offset[frame_id]
                                                                   )
                else:
                    phase, loss_phase = self.direct_phase_constrain(
                                                                    self.phase[frame_id],
                                                                    self.offset[frame_id]
                                                                   )
                loss_variation_hologram += loss_phase
                for channel_id in range(self.number_of_channels):
                    phase_scaled = torch.zeros_like(self.amplitude)
                    phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
                    laser_power = laser_powers[frame_id][channel_id]
                    hologram = generate_complex_field(
                                                      laser_power * self.amplitude,
                                                      phase_scaled * self.phase_scale[channel_id]
                                                     )
                    reconstruction_field = self.propagator(hologram, channel_id, depth_id)
                    intensity = calculate_amplitude(reconstruction_field) ** 2
                    reconstruction_intensities[frame_id, channel_id] += intensity
                hologram_phases[frame_id] = phase.detach().clone()
            loss_laser = self.l2_loss(
                                      torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
                                      torch.sum(laser_powers, dim = 0)
                                     )
            loss_laser += self.l2_loss(
                                       torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
                                       torch.sum(laser_powers).view(1,)
                                      )
            loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
            reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
            loss_image = self.evaluate(
                                       reconstruction_intensity,
                                       depth_target * self.peak_amplitude,
                                       noise_ratio = noise_ratio,
                                       inject_noise = inject_noise,
                                       plane_id = depth_id
                                      )
            loss = weights[0] * loss_image
            loss += weights[1] * loss_laser
            loss += weights[2] * loss_variation_hologram
            include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
            if include_pa_loss_flag:
                loss -= self.peak_amplitude * 1.
            if self.method == 'conventional':
                loss.backward()
            else:
                loss.backward(retain_graph = True)
            self.optimizer.step()
            if include_pa_loss_flag:
                peak_amp_cache = self.peak_amplitude.item()
            else:
                with torch.no_grad():
                    if self.optimize_peak_amplitude:
                        self.peak_amplitude.view([1])[0] = peak_amp_cache
            total_loss += loss.detach().item()
            loss_image = loss_image.detach()
            del loss_laser
            del loss_variation_hologram
            del loss
        description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
        t.set_description(description)
        del total_loss
        del loss_image
        del reconstruction_field
        del reconstruction_intensities
        del intensity
        del phase
        del hologram
    logging.warning(description)
    return hologram_phases.detach()

init_amplitude()

Internal function to set the amplitude of the illumination source.

Source code in odak/learn/wave/optimizers.py
def init_amplitude(self):
    """
    Internal function to set the amplitude of the illumination source.
    """
    self.amplitude = torch.zeros(
                                 self.resolution[0] * self.scale_factor,
                                 self.resolution[1] * self.scale_factor,
                                 requires_grad = False,
                                 device = self.device
                                )
    self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.

init_channel_power()

Internal function to set the starting phase of the phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def init_channel_power(self):
    """
    Internal function to set the starting phase of the phase-only hologram.
    """
    if self.method == 'conventional':
        logging.warning('Scheme: Conventional')
        self.channel_power = torch.eye(
                                       self.number_of_frames,
                                       self.number_of_channels,
                                       device = self.device,
                                       requires_grad = False
                                      )

    elif self.method == 'multi-color':
        logging.warning('Scheme: Multi-color')
        self.channel_power = torch.ones(
                                        self.number_of_frames,
                                        self.number_of_channels,
                                        device = self.device,
                                        requires_grad = True
                                       )
    if self.channel_power_filename != '':
        self.channel_power = torch_load(self.channel_power_filename).to(self.device)
        self.channel_power.requires_grad = False
        self.channel_power[self.channel_power < 0.] = 0.
        self.channel_power[self.channel_power > 1.] = 1.
        if self.method == 'multi-color':
            self.channel_power.requires_grad = True
        if self.method == 'conventional':
            self.channel_power = torch.abs(torch.cos(self.channel_power))
        logging.warning('Channel powers:')
        logging.warning(self.channel_power)
        logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
    self.propagator.set_laser_powers(self.channel_power)

init_loss_function(loss_function, reduction='sum')

Internal function to set the loss function.

Source code in odak/learn/wave/optimizers.py
def init_loss_function(self, loss_function, reduction = 'sum'):
    """
    Internal function to set the loss function.
    """
    self.l2_loss = torch.nn.MSELoss(reduction = reduction)
    self.loss_type = 'custom'
    self.loss_function = loss_function
    if isinstance(self.loss_function, type(None)):
        self.loss_type = 'conventional'
        self.loss_function = torch.nn.MSELoss(reduction = reduction)

init_optimizer()

Internal function to set the optimizer.

Source code in odak/learn/wave/optimizers.py
def init_optimizer(self):
    """
    Internal function to set the optimizer.
    """
    optimization_variables = [self.phase, self.offset]
    if self.optimize_peak_amplitude:
        optimization_variables.append(self.peak_amplitude)
    if self.method == 'multi-color':
        optimization_variables.append(self.propagator.channel_power)
    self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)

init_peak_amplitude_scale()

Internal function to set the phase scale.

Source code in odak/learn/wave/optimizers.py
def init_peak_amplitude_scale(self):
    """
    Internal function to set the phase scale.
    """
    self.peak_amplitude = torch.tensor(
                                       self.peak_amplitude,
                                       requires_grad = True,
                                       device=self.device
                                      )

init_phase()

Internal function to set the starting phase of the phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def init_phase(self):
    """
    Internal function to set the starting phase of the phase-only hologram.
    """
    self.phase = torch.zeros(
                             self.number_of_frames,
                             self.resolution[0],
                             self.resolution[1],
                             device = self.device,
                             requires_grad = True
                            )
    self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)

init_phase_scale()

Internal function to set the phase scale.

Source code in odak/learn/wave/optimizers.py
def init_phase_scale(self):
    """
    Internal function to set the phase scale.
    """
    if self.method == 'conventional':
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )
    if self.method == 'multi-color':
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )

optimize(number_of_iterations=100, weights=[1.0, 1.0, 1.0], bits=8, inject_noise=False, noise_ratio=0.001)

Function to optimize multiplane phase-only holograms.

Parameters:

  • number_of_iterations
                         Number of iterations.
    
  • weights
                         Loss weights.
    
  • bits
                         Quantizes the hologram using the given bits and reconstructs.
    
  • inject_noise
                         When set True, this will inject noise with the given `noise_ratio` to the target images.
    
  • noise_ratio
                         Noise ratio, a multiplier (1e-3 is 0.1 percent).
    

Returns:

  • hologram_phases ( tensor ) –

    Phases of the optimized phase-only hologram.

  • reconstruction_intensities ( tensor ) –

    Intensities of the images reconstructed at each plane with the optimized phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8, inject_noise = False, noise_ratio = 1e-3):
    """
    Function to optimize multiplane phase-only holograms.

    Parameters
    ----------
    number_of_iterations       : int
                                 Number of iterations.
    weights                    : list
                                 Loss weights.
    bits                       : int
                                 Quantizes the hologram using the given bits and reconstructs.
    inject_noise               : bool
                                 When set True, this will inject noise with the given `noise_ratio` to the target images.
    noise_ratio                : float
                                 Noise ratio, a multiplier (1e-3 is 0.1 percent).


    Returns
    -------
    hologram_phases            : torch.tensor
                                 Phases of the optimized phase-only hologram.
    reconstruction_intensities : torch.tensor
                                 Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
    """
    self.init_optimizer()
    hologram_phases = self.gradient_descent(
                                            number_of_iterations=number_of_iterations,
                                            noise_ratio = noise_ratio,
                                            inject_noise = inject_noise,
                                            weights=weights
                                           )
    hologram_phases = quantize(hologram_phases % (2 * torch.pi), bits = bits, limits = [0., 2 * torch.pi]) / 2 ** bits * 2 * torch.pi
    torch.no_grad()
    reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
    laser_powers = self.propagator.get_laser_powers()
    channel_powers = self.propagator.channel_power
    logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
    logging.warning('Laser powers: {}'.format(laser_powers))
    return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)
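
Continuing the construction sketch shown after the class source above (so optimizer here refers to that hypothetical instance, and the iteration count is a placeholder), a typical call unpacks all five returned values:

results = optimizer.optimize(
                             number_of_iterations = 200,
                             weights = [1., 1., 1.],
                             bits = 8
                            )
hologram_phases, reconstruction_intensities, laser_powers, channel_powers, peak_amplitude = results
print(hologram_phases.shape)  # [number of frames x height x width]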

propagator

A light propagation model that propagates light to a desired image plane with two separate propagations. We use this class in our various works, including Kavaklı et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography.

Source code in odak/learn/wave/propagators.py
class propagator():
    """
    A light propagation model that propagates light to desired image plane with two separate propagations. 
    We use this class in our various works including `Kavaklı et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography`.
    """
    def __init__(
                 self,
                 resolution = [1920, 1080],
                 wavelengths = [515e-9,],
                 pixel_pitch = 8e-6,
                 resolution_factor = 1,
                 number_of_frames = 1,
                 number_of_depth_layers = 1,
                 volume_depth = 1e-2,
                 image_location_offset = 5e-3,
                 propagation_type = 'Bandlimited Angular Spectrum',
                 propagator_type = 'back and forth',
                 back_and_forth_distance = 0.3,
                 laser_channel_power = None,
                 aperture = None,
                 aperture_size = None,
                 distances = None,
                 aperture_samples = [20, 20, 5, 5],
                 method = 'conventional',
                 device = torch.device('cpu')
                ):
        """
        Parameters
        ----------
        resolution              : list
                                  Resolution.
        wavelengths             : float
                                  Wavelength of light in meters.
        pixel_pitch             : float
                                  Pixel pitch in meters.
        resolution_factor       : int
                                  Resolution factor for scaled simulations.
        number_of_frames        : int
                                  Number of hologram frames.
                                  Typically, there are three frames, each one for a single color primary.
        number_of_depth_layers  : int
                                  Number of equidistant depth layers within the desired volume. If the `distances` parameter is passed, this value will automatically be set to the length of the provided `distances`.
        volume_depth            : float
                                  Width of the volume along the propagation direction.
        image_location_offset   : float
                                  Center of the volume along the propagation direction.
        propagation_type        : str
                                  Propagation type. 
                                  See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
        propagator_type         : str
                                  Propagator type.
                                  The options are `back and forth` and `forward` propagators.
        back_and_forth_distance : float
                                  Zero mode distance for `back and forth` propagator type.
        laser_channel_power     : torch.tensor
                                  Laser channel powers for given number of frames and number of wavelengths.
        aperture                : torch.tensor
                                  Aperture at the Fourier plane.
        aperture_size           : float
                                  Aperture width for a circular aperture.
        aperture_samples        : list
                                  When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for the hologram plane pixel and the last two are for the image plane pixel.
        distances               : torch.tensor
                                  Propagation distances in meters.
        method                  : str
                                  Hologram type conventional or multi-color.
        device                  : torch.device
                                  Device to be used for computation. For more see torch.device().
        """
        self.device = device
        self.pixel_pitch = pixel_pitch
        self.wavelengths = wavelengths
        self.resolution = resolution
        self.propagation_type = propagation_type
        if self.propagation_type != 'Impulse Response Fresnel':
            resolution_factor = 1
        self.resolution_factor = resolution_factor
        self.number_of_frames = number_of_frames
        self.number_of_depth_layers = number_of_depth_layers
        self.number_of_channels = len(self.wavelengths)
        self.volume_depth = volume_depth
        self.image_location_offset = image_location_offset
        self.propagator_type = propagator_type
        self.aperture_samples = aperture_samples
        self.zero_mode_distance = torch.tensor(back_and_forth_distance, device = device)
        self.method = method
        self.aperture = aperture
        self.init_distances(distances)
        self.init_kernels()
        self.init_channel_power(laser_channel_power)
        self.init_phase_scale()
        self.set_aperture(aperture, aperture_size)


    def init_distances(self, distances):
        """
        Internal function to initialize distances.

        Parameters
        ----------
        distances               : torch.tensor
                                  Propagation distances.
        """
        if isinstance(distances, type(None)):
            self.distances = torch.linspace(-self.volume_depth / 2., self.volume_depth / 2., self.number_of_depth_layers) + self.image_location_offset
        else:
            self.distances = torch.as_tensor(distances)
            self.number_of_depth_layers = self.distances.shape[0]
        logging.warning('Distances: {}'.format(self.distances))


    def init_kernels(self):
        """
        Internal function to initialize kernels.
        """
        self.generated_kernels = torch.zeros(
                                             self.number_of_depth_layers,
                                             self.number_of_channels,
                                             device = self.device
                                            )
        self.kernels = torch.zeros(
                                   self.number_of_depth_layers,
                                   self.number_of_channels,
                                   self.resolution[0] * self.resolution_factor * 2,
                                   self.resolution[1] * self.resolution_factor * 2,
                                   dtype = torch.complex64,
                                   device = self.device
                                  )


    def init_channel_power(self, channel_power):
        """
        Internal function to set the starting phase of the phase-only hologram.
        """
        self.channel_power = channel_power
        if isinstance(self.channel_power, type(None)):
            self.channel_power = torch.eye(
                                           self.number_of_frames,
                                           self.number_of_channels,
                                           device = self.device,
                                           requires_grad = False
                                          )


    def init_phase_scale(self):
        """
        Internal function to set the phase scale.
        In some cases, you may want to modify this init to ratio phases for different color primaries as an SLM is configured for a specific central wavelength.
        """
        self.phase_scale = torch.tensor(
                                        [
                                         1.,
                                         1.,
                                         1.
                                        ],
                                        requires_grad = False,
                                        device = self.device
                                       )


    def set_aperture(self, aperture = None, aperture_size = None):
        """
        Set aperture in the Fourier plane.


        Parameters
        ----------
        aperture        : torch.tensor
                          Aperture at the original resolution of a hologram.
                          If aperture is provided as None, it will assign a circular aperture at the size of the short edge (width or height).
        aperture_size   : int
                          If no aperture is provided, this will determine the size of the circular aperture.
        """
        if isinstance(aperture, type(None)):
            if isinstance(aperture_size, type(None)):
                aperture_size = torch.max(
                                          torch.tensor([
                                                        self.resolution[0] * self.resolution_factor, 
                                                        self.resolution[1] * self.resolution_factor
                                                       ])
                                         )
            self.aperture = circular_binary_mask(
                                                 self.resolution[0] * self.resolution_factor * 2,
                                                 self.resolution[1] * self.resolution_factor * 2,
                                                 aperture_size,
                                                ).to(self.device) * 1.
        else:
            self.aperture = zero_pad(aperture).to(self.device) * 1.


    def get_laser_powers(self):
        """
        Internal function to get the laser powers.

        Returns
        -------
        laser_power      : torch.tensor
                           Laser powers.
        """
        if self.method == 'conventional':
            laser_power = self.channel_power
        if self.method == 'multi-color':
            laser_power = torch.abs(torch.cos(self.channel_power))
        return laser_power


    def set_laser_powers(self, laser_power):
        """
        Internal function to set the laser powers.

        Parameters
        -------
        laser_power      : torch.tensor
                           Laser powers.
        """
        self.channel_power = laser_power



    def get_kernels(self):
        """
        Function to return the kernels used in the light transport.

        Returns
        -------
        kernels           : torch.tensor
                            Kernel amplitudes.
        """
        h = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(self.kernels)))
        kernels_amplitude = calculate_amplitude(h)
        kernels_phase = calculate_phase(h)
        return kernels_amplitude, kernels_phase


    def __call__(self, input_field, channel_id, depth_id):
        """
        Function that represents the forward model in hologram optimization.

        Parameters
        ----------
        input_field         : torch.tensor
                              Complex input field.
        channel_id          : int
                              Identifying the color primary to be used.
        depth_id            : int
                              Identifying the depth layer to be used.

        Returns
        -------
        output_field        : torch.tensor
                              Propagated output complex field.
        """
        distance = self.distances[depth_id]
        if not self.generated_kernels[depth_id, channel_id]:
            if self.propagator_type == 'forward':
                H = get_propagation_kernel(
                                           nu = self.resolution[0] * 2,
                                           nv = self.resolution[1] * 2,
                                           dx = self.pixel_pitch,
                                           wavelength = self.wavelengths[channel_id],
                                           distance = distance,
                                           device = self.device,
                                           propagation_type = self.propagation_type,
                                           samples = self.aperture_samples,
                                           scale = self.resolution_factor
                                          )
            elif self.propagator_type == 'back and forth':
                H_forward = get_propagation_kernel(
                                                   nu = self.resolution[0] * 2,
                                                   nv = self.resolution[1] * 2,
                                                   dx = self.pixel_pitch,
                                                   wavelength = self.wavelengths[channel_id],
                                                   distance = self.zero_mode_distance,
                                                   device = self.device,
                                                   propagation_type = self.propagation_type,
                                                   samples = self.aperture_samples,
                                                   scale = self.resolution_factor
                                                  )
                distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)
                H_back = get_propagation_kernel(
                                                nu = self.resolution[0] * 2,
                                                nv = self.resolution[1] * 2,
                                                dx = self.pixel_pitch,
                                                wavelength = self.wavelengths[channel_id],
                                                distance = distance_back,
                                                device = self.device,
                                                propagation_type = self.propagation_type,
                                                samples = self.aperture_samples,
                                                scale = self.resolution_factor
                                               )
                H = H_forward * H_back
            self.kernels[depth_id, channel_id] = H
            self.generated_kernels[depth_id, channel_id] = True
        else:
            H = self.kernels[depth_id, channel_id].detach().clone()
        field_scale = input_field
        field_scale_padded = zero_pad(field_scale)
        output_field_padded = custom(field_scale_padded, H, aperture = self.aperture)
        output_field = crop_center(output_field_padded)
        return output_field


    def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):
        """
        Internal function to reconstruct a given hologram.


        Parameters
        ----------
        hologram_phases            : torch.tensor
                                     Hologram phases [ch x m x n].
        amplitude                  : torch.tensor
                                     Amplitude profiles for each color primary [ch x m x n]
        no_grad                    : bool
                                     If set True, uses torch.no_grad in reconstruction.
        get_complex                : bool
                                     If set True, the reconstructor returns the complex field rather than the intensities.

        Returns
        -------
        reconstructions            : torch.tensor
                                     Reconstructed frames.
        """
        if no_grad:
            torch.no_grad()
        if len(hologram_phases.shape) > 3:
            hologram_phases = hologram_phases.squeeze(0)
        if get_complex == True:
            reconstruction_type = torch.complex64
        else:
            reconstruction_type = torch.float32
        if hologram_phases.shape[0] != self.number_of_frames:
            logging.warning('Provided hologram frame count is {} but the configured number of frames is {}.'.format(hologram_phases.shape[0], self.number_of_frames))
        reconstructions = torch.zeros(
                                      self.number_of_frames,
                                      self.number_of_depth_layers,
                                      self.number_of_channels,
                                      self.resolution[0] * self.resolution_factor,
                                      self.resolution[1] * self.resolution_factor,
                                      dtype = reconstruction_type,
                                      device = self.device
                                     )
        if isinstance(amplitude, type(None)):
            amplitude = torch.zeros(
                                    self.number_of_channels,
                                    self.resolution[0] * self.resolution_factor,
                                    self.resolution[1] * self.resolution_factor,
                                    device = self.device
                                   )
            amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.
        if self.resolution_factor != 1:
            hologram_phases_scaled = torch.zeros_like(amplitude)
            hologram_phases_scaled[
                                   :,
                                   ::self.resolution_factor,
                                   ::self.resolution_factor
                                  ] = hologram_phases
        else:
            hologram_phases_scaled = hologram_phases
        for frame_id in range(self.number_of_frames):
            for depth_id in range(self.number_of_depth_layers):
                for channel_id in range(self.number_of_channels):
                    laser_power = self.get_laser_powers()[frame_id][channel_id]
                    phase = hologram_phases_scaled[frame_id]
                    hologram = generate_complex_field(
                                                      laser_power * amplitude[channel_id],
                                                      phase * self.phase_scale[channel_id]
                                                     )
                    reconstruction_field = self.__call__(hologram, channel_id, depth_id)
                    if get_complex == True:
                        result = reconstruction_field
                    else:
                        result = calculate_amplitude(reconstruction_field) ** 2

                    if no_grad: 
                        result = result.detach().clone()

                    reconstructions[
                                    frame_id,
                                    depth_id,
                                    channel_id
                                   ] = result

        return reconstructions
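
As a quick orientation, the sketch below shows one way to drive this propagator end to end. It assumes the class is exposed as `propagator` under `odak.learn.wave`, and the resolution, wavelength and pixel pitch are made-up settings chosen purely for illustration.

import torch
from odak.learn.wave import propagator  # assumed import path for this class

device = torch.device('cpu')
my_propagator = propagator(
                           resolution = [256, 256],
                           wavelengths = [515e-9,],
                           pixel_pitch = 8e-6,
                           number_of_frames = 1,
                           number_of_depth_layers = 1,
                           device = device
                          )
# A random phase-only hologram with shape [number of frames x m x n].
hologram_phases = torch.rand(1, 256, 256, device = device) * 2. * torch.pi
reconstructions = my_propagator.reconstruct(hologram_phases)
print(reconstructions.shape)  # [frames x depth layers x channels x m x n]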

__call__(input_field, channel_id, depth_id)

Function that represents the forward model in hologram optimization.

Parameters:

  • input_field
                  Complex input field.
    
  • channel_id
                  Identifying the color primary to be used.
    
  • depth_id
                  Identifying the depth layer to be used.
    

Returns:

  • output_field ( tensor ) –

    Propagated output complex field.

Source code in odak/learn/wave/propagators.py
def __call__(self, input_field, channel_id, depth_id):
    """
    Function that represents the forward model in hologram optimization.

    Parameters
    ----------
    input_field         : torch.tensor
                          Complex input field.
    channel_id          : int
                          Identifying the color primary to be used.
    depth_id            : int
                          Identifying the depth layer to be used.

    Returns
    -------
    output_field        : torch.tensor
                          Propagated output complex field.
    """
    distance = self.distances[depth_id]
    if not self.generated_kernels[depth_id, channel_id]:
        if self.propagator_type == 'forward':
            H = get_propagation_kernel(
                                       nu = self.resolution[0] * 2,
                                       nv = self.resolution[1] * 2,
                                       dx = self.pixel_pitch,
                                       wavelength = self.wavelengths[channel_id],
                                       distance = distance,
                                       device = self.device,
                                       propagation_type = self.propagation_type,
                                       samples = self.aperture_samples,
                                       scale = self.resolution_factor
                                      )
        elif self.propagator_type == 'back and forth':
            H_forward = get_propagation_kernel(
                                               nu = self.resolution[0] * 2,
                                               nv = self.resolution[1] * 2,
                                               dx = self.pixel_pitch,
                                               wavelength = self.wavelengths[channel_id],
                                               distance = self.zero_mode_distance,
                                               device = self.device,
                                               propagation_type = self.propagation_type,
                                               samples = self.aperture_samples,
                                               scale = self.resolution_factor
                                              )
            distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)
            H_back = get_propagation_kernel(
                                            nu = self.resolution[0] * 2,
                                            nv = self.resolution[1] * 2,
                                            dx = self.pixel_pitch,
                                            wavelength = self.wavelengths[channel_id],
                                            distance = distance_back,
                                            device = self.device,
                                            propagation_type = self.propagation_type,
                                            samples = self.aperture_samples,
                                            scale = self.resolution_factor
                                           )
            H = H_forward * H_back
        self.kernels[depth_id, channel_id] = H
        self.generated_kernels[depth_id, channel_id] = True
    else:
        H = self.kernels[depth_id, channel_id].detach().clone()
    field_scale = input_field
    field_scale_padded = zero_pad(field_scale)
    output_field_padded = custom(field_scale_padded, H, aperture = self.aperture)
    output_field = crop_center(output_field_padded)
    return output_field
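
A minimal sketch of a single forward pass through this model; the settings and the `generate_complex_field` import from `odak.learn.wave` are assumptions made for illustration.

import torch
from odak.learn.wave import propagator, generate_complex_field  # assumed import paths

my_propagator = propagator(resolution = [256, 256], wavelengths = [515e-9,], pixel_pitch = 8e-6)
# Unit-amplitude field with a random phase as the complex input.
input_field = generate_complex_field(
                                     torch.ones(256, 256),
                                     torch.rand(256, 256) * 2. * torch.pi
                                    )
# Propagate for the first color primary and the first depth layer.
output_field = my_propagator(input_field, channel_id = 0, depth_id = 0)
print(output_field.shape)  # [m x n], complex valued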

__init__(resolution=[1920, 1080], wavelengths=[5.15e-07], pixel_pitch=8e-06, resolution_factor=1, number_of_frames=1, number_of_depth_layers=1, volume_depth=0.01, image_location_offset=0.005, propagation_type='Bandlimited Angular Spectrum', propagator_type='back and forth', back_and_forth_distance=0.3, laser_channel_power=None, aperture=None, aperture_size=None, distances=None, aperture_samples=[20, 20, 5, 5], method='conventional', device=torch.device('cpu'))

Parameters:

  • resolution
                      Resolution.
    
  • wavelengths
                      Wavelength of light in meters.
    
  • pixel_pitch
                      Pixel pitch in meters.
    
  • resolution_factor
                      Resolution factor for scaled simulations.
    
  • number_of_frames
                      Number of hologram frames.
                      Typically, there are three frames, each one for a single color primary.
    
  • number_of_depth_layers
                      Equidistant number of depth layers within the desired volume. If the `distances` parameter is passed, this value is automatically set to the length of the provided `distances` tensor.
    
  • volume_depth
                      Width of the volume along the propagation direction.
    
  • image_location_offset
                      Center of the volume along the propagation direction.
    
  • propagation_type
                      Propagation type.
                      See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
    
  • propagator_type
                      Propagator type.
                      The options are `back and forth` and `forward` propagators.
    
  • back_and_forth_distance (float, default: 0.3 ) –
                      Zero mode distance for `back and forth` propagator type.
    
  • laser_channel_power
                      Laser channel powers for given number of frames and number of wavelengths.
    
  • aperture
                      Aperture at the Fourier plane.
    
  • aperture_size
                      Aperture width for a circular aperture.
    
  • aperture_samples
                      When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
    
  • distances
                      Propagation distances in meters.
    
  • method
                      Hologram type: `conventional` or `multi-color`.
    
  • device
                      Device to be used for computation. For more see torch.device().
    
Source code in odak/learn/wave/propagators.py
def __init__(
             self,
             resolution = [1920, 1080],
             wavelengths = [515e-9,],
             pixel_pitch = 8e-6,
             resolution_factor = 1,
             number_of_frames = 1,
             number_of_depth_layers = 1,
             volume_depth = 1e-2,
             image_location_offset = 5e-3,
             propagation_type = 'Bandlimited Angular Spectrum',
             propagator_type = 'back and forth',
             back_and_forth_distance = 0.3,
             laser_channel_power = None,
             aperture = None,
             aperture_size = None,
             distances = None,
             aperture_samples = [20, 20, 5, 5],
             method = 'conventional',
             device = torch.device('cpu')
            ):
    """
    Parameters
    ----------
    resolution              : list
                              Resolution.
    wavelengths             : float
                              Wavelength of light in meters.
    pixel_pitch             : float
                              Pixel pitch in meters.
    resolution_factor       : int
                              Resolution factor for scaled simulations.
    number_of_frames        : int
                              Number of hologram frames.
                              Typically, there are three frames, each one for a single color primary.
    number_of_depth_layers  : int
                              Equidistant number of depth layers within the desired volume. If the `distances` parameter is passed, this value is automatically set to the length of the provided `distances` tensor.
    volume_depth            : float
                              Width of the volume along the propagation direction.
    image_location_offset   : float
                              Center of the volume along the propagation direction.
    propagation_type        : str
                              Propagation type.
                              See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
    propagator_type         : str
                              Propagator type.
                              The options are `back and forth` and `forward` propagators.
    back_and_forth_distance : float
                              Zero mode distance for `back and forth` propagator type.
    laser_channel_power     : torch.tensor
                              Laser channel powers for given number of frames and number of wavelengths.
    aperture                : torch.tensor
                              Aperture at the Fourier plane.
    aperture_size           : float
                              Aperture width for a circular aperture.
    aperture_samples        : list
                              When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
    distances               : torch.tensor
                              Propagation distances in meters.
    method                  : str
                              Hologram type: `conventional` or `multi-color`.
    device                  : torch.device
                              Device to be used for computation. For more see torch.device().
    """
    self.device = device
    self.pixel_pitch = pixel_pitch
    self.wavelengths = wavelengths
    self.resolution = resolution
    self.propagation_type = propagation_type
    if self.propagation_type != 'Impulse Response Fresnel':
        resolution_factor = 1
    self.resolution_factor = resolution_factor
    self.number_of_frames = number_of_frames
    self.number_of_depth_layers = number_of_depth_layers
    self.number_of_channels = len(self.wavelengths)
    self.volume_depth = volume_depth
    self.image_location_offset = image_location_offset
    self.propagator_type = propagator_type
    self.aperture_samples = aperture_samples
    self.zero_mode_distance = torch.tensor(back_and_forth_distance, device = device)
    self.method = method
    self.aperture = aperture
    self.init_distances(distances)
    self.init_kernels()
    self.init_channel_power(laser_channel_power)
    self.init_phase_scale()
    self.set_aperture(aperture, aperture_size)
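
Note that `resolution_factor` only survives when `propagation_type` is `Impulse Response Fresnel`; with any other propagation type, `__init__` silently resets it to one. A hedged construction sketch with assumed, illustrative values:

import torch
from odak.learn.wave import propagator  # assumed import path

my_propagator = propagator(
                           resolution = [128, 128],
                           wavelengths = [515e-9,],
                           pixel_pitch = 8e-6,
                           resolution_factor = 2,
                           propagation_type = 'Impulse Response Fresnel',
                           aperture_samples = [20, 20, 5, 5],
                           device = torch.device('cpu')
                          )
print(my_propagator.resolution_factor)  # 2, kept because of the propagation type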

get_kernels()

Function to return the kernels used in the light transport.

Returns:

  • kernels ( tensor ) –

    Kernel amplitudes.

Source code in odak/learn/wave/propagators.py
def get_kernels(self):
    """
    Function to return the kernels used in the light transport.

    Returns
    -------
    kernels           : torch.tensor
                        Kernel amplitudes.
    """
    h = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(self.kernels)))
    kernels_amplitude = calculate_amplitude(h)
    kernels_phase = calculate_phase(h)
    return kernels_amplitude, kernels_phase
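
Kernels are generated lazily, so `get_kernels()` only returns meaningful amplitudes and phases after at least one forward call. The snippet below is a sketch under the same assumed import paths as above.

import torch
from odak.learn.wave import propagator, generate_complex_field  # assumed import paths

my_propagator = propagator(resolution = [128, 128])
field = generate_complex_field(torch.ones(128, 128), torch.zeros(128, 128))
_ = my_propagator(field, channel_id = 0, depth_id = 0)  # populates the cached kernel
kernels_amplitude, kernels_phase = my_propagator.get_kernels()
print(kernels_amplitude.shape)  # [depth layers x channels x 2m x 2n]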

get_laser_powers()

Internal function to get the laser powers.

Returns:

  • laser_power ( tensor ) –

    Laser powers.

Source code in odak/learn/wave/propagators.py
def get_laser_powers(self):
    """
    Internal function to get the laser powers.

    Returns
    -------
    laser_power      : torch.tensor
                       Laser powers.
    """
    if self.method == 'conventional':
        laser_power = self.channel_power
    if self.method == 'multi-color':
        laser_power = torch.abs(torch.cos(self.channel_power))
    return laser_power

init_channel_power(channel_power)

Internal function to initialize the laser channel powers.

Source code in odak/learn/wave/propagators.py
def init_channel_power(self, channel_power):
    """
    Internal function to initialize the laser channel powers.
    """
    self.channel_power = channel_power
    if isinstance(self.channel_power, type(None)):
        self.channel_power = torch.eye(
                                       self.number_of_frames,
                                       self.number_of_channels,
                                       device = self.device,
                                       requires_grad = False
                                      )

init_distances(distances)

Internal function to initialize distances.

Parameters:

  • distances
                      Propagation distances.
    
Source code in odak/learn/wave/propagators.py
def init_distances(self, distances):
    """
    Internal function to initialize distances.

    Parameters
    ----------
    distances               : torch.tensor
                              Propagation distances.
    """
    if isinstance(distances, type(None)):
        self.distances = torch.linspace(-self.volume_depth / 2., self.volume_depth / 2., self.number_of_depth_layers) + self.image_location_offset
    else:
        self.distances = torch.as_tensor(distances)
        self.number_of_depth_layers = self.distances.shape[0]
    logging.warning('Distances: {}'.format(self.distances))
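
For reference, the default layout mirrors the line above: layers are spread evenly across `volume_depth` and shifted by `image_location_offset`. A plain PyTorch illustration with example numbers:

import torch

volume_depth = 1e-2
image_location_offset = 5e-3
number_of_depth_layers = 3
distances = torch.linspace(-volume_depth / 2., volume_depth / 2., number_of_depth_layers) + image_location_offset
print(distances)  # tensor([0.0000, 0.0050, 0.0100])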

init_kernels()

Internal function to initialize kernels.

Source code in odak/learn/wave/propagators.py
def init_kernels(self):
    """
    Internal function to initialize kernels.
    """
    self.generated_kernels = torch.zeros(
                                         self.number_of_depth_layers,
                                         self.number_of_channels,
                                         device = self.device
                                        )
    self.kernels = torch.zeros(
                               self.number_of_depth_layers,
                               self.number_of_channels,
                               self.resolution[0] * self.resolution_factor * 2,
                               self.resolution[1] * self.resolution_factor * 2,
                               dtype = torch.complex64,
                               device = self.device
                              )

init_phase_scale()

Internal function to set the phase scale. In some cases, you may want to modify this init to ratio phases for different color primaries as an SLM is configured for a specific central wavelength.

Source code in odak/learn/wave/propagators.py
def init_phase_scale(self):
    """
    Internal function to set the phase scale.
    In some cases, you may want to modify this init to ratio phases for different color primaries as an SLM is configured for a specific central wavelength.
    """
    self.phase_scale = torch.tensor(
                                    [
                                     1.,
                                     1.,
                                     1.
                                    ],
                                    requires_grad = False,
                                    device = self.device
                                   )

reconstruct(hologram_phases, amplitude=None, no_grad=True, get_complex=False)

Internal function to reconstruct a given hologram.

Parameters:

  • hologram_phases
                         Hologram phases [ch x m x n].
    
  • amplitude
                         Amplitude profiles for each color primary [ch x m x n]
    
  • no_grad
                         If set True, uses torch.no_grad in reconstruction.
    
  • get_complex
                         If set True, the reconstructor returns the complex field rather than the intensities.
    

Returns:

  • reconstructions ( tensor ) –

    Reconstructed frames.

Source code in odak/learn/wave/propagators.py
def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):
    """
    Internal function to reconstruct a given hologram.


    Parameters
    ----------
    hologram_phases            : torch.tensor
                                 Hologram phases [ch x m x n].
    amplitude                  : torch.tensor
                                 Amplitude profiles for each color primary [ch x m x n]
    no_grad                    : bool
                                 If set True, uses torch.no_grad in reconstruction.
    get_complex                : bool
                                 If set True, the reconstructor returns the complex field rather than the intensities.

    Returns
    -------
    reconstructions            : torch.tensor
                                 Reconstructed frames.
    """
    if no_grad:
        torch.no_grad()
    if len(hologram_phases.shape) > 3:
        hologram_phases = hologram_phases.squeeze(0)
    if get_complex == True:
        reconstruction_type = torch.complex64
    else:
        reconstruction_type = torch.float32
    if hologram_phases.shape[0] != self.number_of_frames:
        logging.warning('Provided hologram frame count is {} but the configured number of frames is {}.'.format(hologram_phases.shape[0], self.number_of_frames))
    reconstructions = torch.zeros(
                                  self.number_of_frames,
                                  self.number_of_depth_layers,
                                  self.number_of_channels,
                                  self.resolution[0] * self.resolution_factor,
                                  self.resolution[1] * self.resolution_factor,
                                  dtype = reconstruction_type,
                                  device = self.device
                                 )
    if isinstance(amplitude, type(None)):
        amplitude = torch.zeros(
                                self.number_of_channels,
                                self.resolution[0] * self.resolution_factor,
                                self.resolution[1] * self.resolution_factor,
                                device = self.device
                               )
        amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.
    if self.resolution_factor != 1:
        hologram_phases_scaled = torch.zeros_like(amplitude)
        hologram_phases_scaled[
                               :,
                               ::self.resolution_factor,
                               ::self.resolution_factor
                              ] = hologram_phases
    else:
        hologram_phases_scaled = hologram_phases
    for frame_id in range(self.number_of_frames):
        for depth_id in range(self.number_of_depth_layers):
            for channel_id in range(self.number_of_channels):
                laser_power = self.get_laser_powers()[frame_id][channel_id]
                phase = hologram_phases_scaled[frame_id]
                hologram = generate_complex_field(
                                                  laser_power * amplitude[channel_id],
                                                  phase * self.phase_scale[channel_id]
                                                 )
                reconstruction_field = self.__call__(hologram, channel_id, depth_id)
                if get_complex == True:
                    result = reconstruction_field
                else:
                    result = calculate_amplitude(reconstruction_field) ** 2

                if no_grad: 
                    result = result.detach().clone()

                reconstructions[
                                frame_id,
                                depth_id,
                                channel_id
                               ] = result

    return reconstructions
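
A sketch of a field sequential color reconstruction under assumed settings: with three frames, three wavelengths and the default identity channel power, each frame illuminates a single color primary.

import torch
from odak.learn.wave import propagator  # assumed import path

my_propagator = propagator(
                           resolution = [192, 192],
                           wavelengths = [639e-9, 515e-9, 473e-9],
                           pixel_pitch = 8e-6,
                           number_of_frames = 3,
                           number_of_depth_layers = 2,
                           volume_depth = 1e-2
                          )
hologram_phases = torch.rand(3, 192, 192) * 2. * torch.pi
reconstructions = my_propagator.reconstruct(hologram_phases)
print(reconstructions.shape)  # [3 frames x 2 depth layers x 3 channels x 192 x 192]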

set_aperture(aperture=None, aperture_size=None)

Set aperture in the Fourier plane.

Parameters:

  • aperture
              Aperture at the original resolution of a hologram.
              If aperture is provided as None, it will assign a circular aperture at the size of the short edge (width or height).
    
  • aperture_size
              If no aperture is provided, this will determine the size of the circular aperture.
    
Source code in odak/learn/wave/propagators.py
def set_aperture(self, aperture = None, aperture_size = None):
    """
    Set aperture in the Fourier plane.


    Parameters
    ----------
    aperture        : torch.tensor
                      Aperture at the original resolution of a hologram.
                      If aperture is provided as None, it will assign a circular aperture at the size of the short edge (width or height).
    aperture_size   : int
                      If no aperture is provided, this will determine the size of the circular aperture.
    """
    if isinstance(aperture, type(None)):
        if isinstance(aperture_size, type(None)):
            aperture_size = torch.max(
                                      torch.tensor([
                                                    self.resolution[0] * self.resolution_factor, 
                                                    self.resolution[1] * self.resolution_factor
                                                   ])
                                     )
        self.aperture = circular_binary_mask(
                                             self.resolution[0] * self.resolution_factor * 2,
                                             self.resolution[1] * self.resolution_factor * 2,
                                             aperture_size,
                                            ).to(self.device) * 1.
    else:
        self.aperture = zero_pad(aperture).to(self.device) * 1.
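
When a custom aperture is supplied, it is expected at the hologram resolution [m x n] and is zero padded internally to the doubled grid. Below is a hand-rolled circular mask built with plain PyTorch; the radius and resolution are illustrative assumptions.

import torch
from odak.learn.wave import propagator  # assumed import path

m, n = 256, 256
my_propagator = propagator(resolution = [m, n])
y, x = torch.meshgrid(torch.arange(m), torch.arange(n), indexing = 'ij')
radius = min(m, n) // 4
custom_aperture = (((x - n // 2) ** 2 + (y - m // 2) ** 2) <= radius ** 2).float()
my_propagator.set_aperture(aperture = custom_aperture)
print(my_propagator.aperture.shape)  # zero padded to [2m x 2n]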

set_laser_powers(laser_power)

Internal function to set the laser powers.

Parameters:

  • laser_power
               Laser powers.
    
Source code in odak/learn/wave/propagators.py
def set_laser_powers(self, laser_power):
    """
    Internal function to set the laser powers.

    Parameters
    -------
    laser_power      : torch.tensor
                       Laser powers.
    """
    self.channel_power = laser_power
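
A brief sketch of overriding and reading back the laser powers; the values are illustrative. With the default `conventional` method, `get_laser_powers()` returns the stored tensor unchanged, while `multi-color` maps it through `abs(cos(.))`.

import torch
from odak.learn.wave import propagator  # assumed import path

my_propagator = propagator(
                           wavelengths = [639e-9, 515e-9, 473e-9],
                           number_of_frames = 3
                          )
my_propagator.set_laser_powers(torch.full((3, 3), 0.5))
print(my_propagator.get_laser_powers())  # the same 3 x 3 tensor for 'conventional'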

calculate_amplitude(field)

Definition to calculate amplitude of a single or multiple given electric field(s).

Parameters:

  • field
           Electric fields or an electric field.
    

Returns:

  • amplitude ( float ) –

    Amplitude or amplitudes of electric field(s).

Source code in odak/learn/wave/util.py
def calculate_amplitude(field):
    """ 
    Definition to calculate amplitude of a single or multiple given electric field(s).

    Parameters
    ----------
    field        : torch.cfloat
                   Electric fields or an electric field.

    Returns
    -------
    amplitude    : torch.float
                   Amplitude or amplitudes of electric field(s).
    """
    amplitude = torch.abs(field)
    return amplitude

calculate_phase(field, deg=False)

Definition to calculate phase of a single or multiple given electric field(s).

Parameters:

  • field
           Electric fields or an electric field.
    
  • deg
           If set True, the angles will be returned in degrees.
    

Returns:

  • phase ( float ) –

    Phase or phases of electric field(s) in radians (or in degrees if deg is set True).

Source code in odak/learn/wave/util.py
def calculate_phase(field, deg = False):
    """ 
    Definition to calculate phase of a single or multiple given electric field(s).

    Parameters
    ----------
    field        : torch.cfloat
                   Electric fields or an electric field.
    deg          : bool
                   If set True, the angles will be returned in degrees.

    Returns
    -------
    phase        : torch.float
                   Phase or phases of electric field(s) in radians (or in degrees if deg is set True).
    """
    phase = field.imag.atan2(field.real)
    if deg:
        phase *= 180. / torch.pi
    return phase

generate_complex_field(amplitude, phase)

Definition to generate a complex field with a given amplitude and phase.

Parameters:

  • amplitude
                Amplitude of the field.
                The expected size is [m x n] or [1 x m x n].
    
  • phase
                Phase of the field.
                The expected size is [m x n] or [1 x m x n].
    

Returns:

  • field ( tensor ) –

    Complex field. Depending on the input, the expected size is [m x n] or [1 x m x n].

Source code in odak/learn/wave/util.py
def generate_complex_field(amplitude, phase):
    """
    Definition to generate a complex field with a given amplitude and phase.

    Parameters
    ----------
    amplitude         : torch.tensor
                        Amplitude of the field.
                        The expected size is [m x n] or [1 x m x n].
    phase             : torch.tensor
                        Phase of the field.
                        The expected size is [m x n] or [1 x m x n].

    Returns
    -------
    field             : torch.tensor
                        Complex field.
                        Depending on the input, the expected size is [m x n] or [1 x m x n].
    """
    field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)
    return field
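
A small round trip tying the three utilities above together; the import path from `odak.learn.wave` is assumed.

import torch
from odak.learn.wave import generate_complex_field, calculate_amplitude, calculate_phase  # assumed import paths

amplitude = torch.full((4, 4), 2.)
phase = torch.full((4, 4), 0.5)
field = generate_complex_field(amplitude, phase)
print(calculate_amplitude(field))  # ~2.0 everywhere
print(calculate_phase(field))      # ~0.5 radians everywhere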

set_amplitude(field, amplitude)

Definition to keep phase as is and change the amplitude of a given field.

Parameters:

  • field
           Complex field.
    
  • amplitude
           Amplitudes.
    

Returns:

  • new_field ( cfloat ) –

    Complex field.

Source code in odak/learn/wave/util.py
def set_amplitude(field, amplitude):
    """
    Definition to keep phase as is and change the amplitude of a given field.

    Parameters
    ----------
    field        : torch.cfloat
                   Complex field.
    amplitude    : torch.cfloat or torch.float
                   Amplitudes.

    Returns
    -------
    new_field    : torch.cfloat
                   Complex field.
    """
    amplitude = calculate_amplitude(amplitude)
    phase = calculate_phase(field)
    new_field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)
    return new_field
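
A sketch showing that only the magnitude changes while the phase is preserved; the imports from `odak.learn.wave` are assumed.

import torch
from odak.learn.wave import generate_complex_field, set_amplitude, calculate_phase  # assumed import paths

field = generate_complex_field(torch.ones(4, 4), torch.rand(4, 4))
new_field = set_amplitude(field, torch.full((4, 4), 3.))
# The phase is untouched, only the magnitude changes.
print(torch.allclose(calculate_phase(new_field), calculate_phase(field), atol = 1e-6))  # True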

wavenumber(wavelength)

Definition for calculating the wavenumber of a plane wave.

Parameters:

  • wavelength
           Wavelength of a wave in mm.
    

Returns:

  • k ( float ) –

    Wave number for a given wavelength.

Source code in odak/learn/wave/util.py
def wavenumber(wavelength):
    """
    Definition for calculating the wavenumber of a plane wave.

    Parameters
    ----------
    wavelength   : float
                   Wavelength of a wave in mm.

    Returns
    -------
    k            : float
                   Wave number for a given wavelength.
    """
    k = 2 * torch.pi / wavelength
    return k
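
For a quick check, passing a wavelength expressed in meters yields a wavenumber in radians per meter; the formula is unit agnostic, so the result follows whatever unit the wavelength is given in. The import path is assumed.

from odak.learn.wave import wavenumber  # assumed import path

k = wavenumber(515e-9)  # 515 nm expressed in meters
print(k)  # ~1.22e7 radians per meter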