Skip to content

odak.learn.lensless

spec_track

Bases: Module

The learned multi-rotation tracking model used in the paper, Ziyang Chen and Mustafa Dogan and Josef Spjut and Kaan Akşit. "SpecTrack: Learned Multi-Rotation Tracking via Speckle Imaging." In SIGGRAPH Asia 2024 Posters (SA Posters '24).

This model performs multi-rotation tracking via speckle imaging using a deep convolutional neural network architecture.

Parameters:

  • reduction (str, default: 'sum' ) –

    Reduction method for torch.nn.MSELoss and torch.nn.L1Loss. Default is 'sum'.

  • device (device, default: device('cpu') ) –

    Device to run the model on. Default is CPU.

Source code in odak/learn/lensless/models.py
class spec_track(nn.Module):
    """
    The learned multi-rotation tracking model used in the paper, Ziyang Chen and Mustafa Dogan and Josef Spjut and Kaan Akşit. "SpecTrack: Learned Multi-Rotation Tracking via Speckle Imaging." In SIGGRAPH Asia 2024 Posters (SA Posters '24).

    This model performs multi-rotation tracking via speckle imaging using a deep convolutional neural network architecture.

    Parameters
    ----------
    reduction : str, optional
        Reduction method for torch.nn.MSELoss and torch.nn.L1Loss. Default is 'sum'.
    device : torch.device, optional
        Device to run the model on. Default is CPU.
    """

    def __init__(self, reduction="sum", device=torch.device("cpu")):
        super(spec_track, self).__init__()
        self.device = device
        # init_layers reads self.device, so the device must be set first.
        self.init_layers()
        self.reduction = reduction
        self.l2 = torch.nn.MSELoss(reduction=self.reduction)
        self.l1 = torch.nn.L1Loss(reduction=self.reduction)
        # Per-epoch average losses; fit() appends one entry per epoch to each.
        self.train_history = []
        self.validation_history = []

    def init_layers(self):
        """
        Initialize the layers of the network.

        The network architecture consists of:
        - Three convolutional layers with batch normalization and ReLU activation
        - Three max pooling layers
        - Five fully connected layers ending with a 3-dimensional output

        The resulting `torch.nn.Sequential` is stored in `self.network` and moved
        to `self.device`.
        """
        # Convolutional layers with batch normalization and pooling
        self.network = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(5, 32, kernel_size=3, padding=1)),
                    ("bn1", nn.BatchNorm2d(32)),
                    ("relu1", nn.ReLU()),
                    ("pool1", nn.MaxPool2d(kernel_size=3)),
                    ("conv2", nn.Conv2d(32, 64, kernel_size=5, padding=1)),
                    ("bn2", nn.BatchNorm2d(64)),
                    ("relu2", nn.ReLU()),
                    ("pool2", nn.MaxPool2d(kernel_size=3)),
                    ("conv3", nn.Conv2d(64, 128, kernel_size=7, padding=1)),
                    ("bn3", nn.BatchNorm2d(128)),
                    ("relu3", nn.ReLU()),
                    ("pool3", nn.MaxPool2d(kernel_size=3)),
                    ("flatten", nn.Flatten()),
                    # The flattened feature vector must hold exactly 6400 values,
                    # which constrains the spatial size of the input images.
                    ("fc1", nn.Linear(6400, 2048)),
                    ("fc_bn1", nn.BatchNorm1d(2048)),
                    ("relu_fc1", nn.ReLU()),
                    ("fc2", nn.Linear(2048, 1024)),
                    ("fc_bn2", nn.BatchNorm1d(1024)),
                    ("relu_fc2", nn.ReLU()),
                    ("fc3", nn.Linear(1024, 512)),
                    ("fc_bn3", nn.BatchNorm1d(512)),
                    ("relu_fc3", nn.ReLU()),
                    ("fc4", nn.Linear(512, 128)),
                    ("fc_bn4", nn.BatchNorm1d(128)),
                    ("relu_fc4", nn.ReLU()),
                    ("fc5", nn.Linear(128, 3)),
                ]
            )
        ).to(self.device)

    def forward(self, x):
        """
        Forward pass of the network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, 5, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, 3) representing the predicted rotation angles.
        """
        return self.network(x)

    def evaluate(self, input_data, ground_truth, weights=[100.0, 1.0]):
        """
        Evaluate the model's performance using weighted L1 and L2 losses.

        Parameters
        ----------
        input_data : torch.Tensor
            Predicted data from the model.
        ground_truth : torch.Tensor
            Ground truth data.
        weights : list, optional
            Weights for L2 and L1 losses. Default is [100.0, 1.0].

        Returns
        -------
        torch.Tensor
            Combined weighted loss value.
        """
        # weights[0] scales the MSE term, weights[1] scales the L1 term.
        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(
            input_data, ground_truth
        )
        return loss

    def fit(
        self,
        trainloader,
        testloader,
        number_of_epochs=100,
        learning_rate=1e-5,
        weight_decay=1e-5,
        directory="./output",
    ):
        """
        Train the model using the provided data loaders.

        Each epoch runs one pass over `trainloader` with an Adam optimizer and
        one gradient-free pass over `testloader`. The per-batch losses are
        averaged over the number of batches and appended to
        `self.train_history` and `self.validation_history`. Whenever the
        average validation loss improves, the weights are written to
        `best_model_epoch_<epoch>.pt` inside the output directory. After the
        last epoch both histories are saved under the `log` subdirectory.

        Parameters
        ----------
        trainloader : torch.utils.data.DataLoader
            Training data loader.
        testloader : torch.utils.data.DataLoader
            Testing data loader.
        number_of_epochs : int, optional
            Number of epochs to train for. Default is 100.
        learning_rate : float, optional
            Learning rate for the optimizer. Default is 1e-5.
        weight_decay : float, optional
            Weight decay for the optimizer. Default is 1e-5.
        directory : str, optional
            Directory to save the model weights and logs. Default is './output'.

        Raises
        ------
        ValueError    : If directory path contains dangerous patterns (traversal, null bytes, etc.),
                        or if a data loader yields no batches.
        TypeError     : If directory is not a string.
        """
        safe_directory = validate_path(directory)
        check_directory(safe_directory, validate=True)
        check_directory(join(safe_directory, "log"), validate=True)

        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=learning_rate, weight_decay=weight_decay
        )
        best_val_loss = float("inf")

        for epoch in range(number_of_epochs):
            # Training phase
            self.train()
            train_loss = 0.0
            train_batches = 0
            train_pbar = tqdm(
                trainloader,
                desc=f"Epoch {epoch + 1}/{number_of_epochs} [Train]",
                leave=False,
                dynamic_ncols=True,
            )

            for batch, labels in train_pbar:
                self.optimizer.zero_grad()
                batch = batch.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                # Same squeeze as the validation pass; an unrestricted squeeze
                # could also collapse the batch dimension.
                predicts = torch.squeeze(self.forward(batch), dim=1)
                loss = self.evaluate(predicts, labels)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                train_batches += 1
                train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

            if train_batches == 0:
                raise ValueError(
                    "trainloader yielded no batches; cannot compute the training loss."
                )
            avg_train_loss = train_loss / train_batches
            self.train_history.append(avg_train_loss)

            # Validation phase
            self.eval()
            val_loss = 0.0
            val_batches = 0
            val_pbar = tqdm(
                testloader,
                desc=f"Epoch {epoch + 1}/{number_of_epochs} [Val]",
                leave=False,
                dynamic_ncols=True,
            )

            with torch.no_grad():
                for batch, labels in val_pbar:
                    batch = batch.to(self.device, non_blocking=True)
                    labels = labels.to(self.device, non_blocking=True)
                    predicts = torch.squeeze(self.forward(batch), dim=1)
                    loss = self.evaluate(predicts, labels)

                    val_loss += loss.item()
                    val_batches += 1
                    val_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

            if val_batches == 0:
                raise ValueError(
                    "testloader yielded no batches; cannot compute the validation loss."
                )
            avg_val_loss = val_loss / val_batches
            self.validation_history.append(avg_val_loss)

            # Print epoch summary
            print(
                f"Epoch {epoch + 1}/{number_of_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
            )

            # Save best model under the validated directory, i.e. the one that
            # check_directory was called on above.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                self.save_weights(
                    join(safe_directory, f"best_model_epoch_{epoch + 1}.pt")
                )
                print(f"Best model saved at epoch {epoch + 1}")

        # Save training history with path validation
        train_log_path = validate_path(
            join(safe_directory, "log", "train_log.pt"), allowed_extensions=[".pt"]
        )
        val_log_path = validate_path(
            join(safe_directory, "log", "validation_log.pt"),
            allowed_extensions=[".pt"],
        )
        torch.save(self.train_history, train_log_path)
        torch.save(self.validation_history, val_log_path)
        print("Training completed. History saved.")

    def save_weights(self, filename="./weights.pt"):
        """
        Save the current weights of the network to a file.

        Parameters
        ----------
        filename : str, optional
            Path to save the weights. Default is './weights.pt'.

        Raises
        ------
        ValueError    : If path validation fails or extension is not allowed.
        TypeError     : If filename is not a string.
        """
        safe_path = validate_path(filename, allowed_extensions=[".pt", ".pth"])
        # Only the state dict of self.network is stored, matching load_weights.
        torch.save(self.network.state_dict(), safe_path)

    def load_weights(self, filename="./weights.pt"):
        """
        Load weights for the network from a file.

        The tensors are mapped onto `self.device` while loading, and the
        network is switched to evaluation mode afterwards.

        Parameters
        ----------
        filename : str, optional
            Path to load the weights from. Default is './weights.pt'.

        Raises
        ------
        ValueError    : If path validation fails or extension is not allowed.
        TypeError     : If filename is not a string.
        """
        safe_path = validate_path(filename, allowed_extensions=[".pt", ".pth"])
        # map_location lets a checkpoint saved on another device be restored
        # on the device this model was created for.
        state_dict = torch.load(
            safe_path, map_location=self.device, weights_only=True
        )
        self.network.load_state_dict(state_dict)
        self.network.eval()

evaluate(input_data, ground_truth, weights=[100.0, 1.0])

Evaluate the model's performance using weighted L1 and L2 losses.

Parameters:

  • input_data (Tensor) –

    Predicted data from the model.

  • ground_truth (Tensor) –

    Ground truth data.

  • weights (list, default: [100.0, 1.0] ) –

    Weights for L2 and L1 losses. Default is [100.0, 1.0].

Returns:

  • Tensor

    Combined weighted loss value.

Source code in odak/learn/lensless/models.py
def evaluate(self, input_data, ground_truth, weights=[100.0, 1.0]):
    """
    Compute the weighted combination of the L2 and L1 losses.

    Parameters
    ----------
    input_data : torch.Tensor
        Predicted data from the model.
    ground_truth : torch.Tensor
        Ground truth data.
    weights : list, optional
        Weights for L2 and L1 losses, in that order. Default is [100.0, 1.0].

    Returns
    -------
    torch.Tensor
        Combined weighted loss value.
    """
    # First entry scales the L2 (MSE) term.
    l2_term = weights[0] * self.l2(input_data, ground_truth)
    # Second entry scales the L1 term.
    l1_term = weights[1] * self.l1(input_data, ground_truth)
    return l2_term + l1_term

fit(trainloader, testloader, number_of_epochs=100, learning_rate=1e-05, weight_decay=1e-05, directory='./output')

Train the model using the provided data loaders.

Parameters:

  • trainloader (DataLoader) –

    Training data loader.

  • testloader (DataLoader) –

    Testing data loader.

  • number_of_epochs (int, default: 100 ) –

    Number of epochs to train for. Default is 100.

  • learning_rate (float, default: 1e-05 ) –

    Learning rate for the optimizer. Default is 1e-5.

  • weight_decay (float, default: 1e-05 ) –

    Weight decay for the optimizer. Default is 1e-5.

  • directory (str, default: './output' ) –

    Directory to save the model weights and logs. Default is './output'.

Raises:

  • ValueError : If directory path contains dangerous patterns (traversal, null bytes, etc.).
  • TypeError : If directory is not a string.
Source code in odak/learn/lensless/models.py
def fit(
    self,
    trainloader,
    testloader,
    number_of_epochs=100,
    learning_rate=1e-5,
    weight_decay=1e-5,
    directory="./output",
):
    """
    Train the model using the provided data loaders.

    Each epoch runs one pass over `trainloader` with an Adam optimizer and one
    gradient-free pass over `testloader`. The per-batch losses are averaged
    over the number of batches and appended to `self.train_history` and
    `self.validation_history`. Whenever the average validation loss improves,
    the weights are written to `best_model_epoch_<epoch>.pt` inside the output
    directory. After the last epoch both histories are saved under the `log`
    subdirectory.

    Parameters
    ----------
    trainloader : torch.utils.data.DataLoader
        Training data loader.
    testloader : torch.utils.data.DataLoader
        Testing data loader.
    number_of_epochs : int, optional
        Number of epochs to train for. Default is 100.
    learning_rate : float, optional
        Learning rate for the optimizer. Default is 1e-5.
    weight_decay : float, optional
        Weight decay for the optimizer. Default is 1e-5.
    directory : str, optional
        Directory to save the model weights and logs. Default is './output'.

    Raises
    ------
    ValueError    : If directory path contains dangerous patterns (traversal, null bytes, etc.),
                    or if a data loader yields no batches.
    TypeError     : If directory is not a string.
    """
    safe_directory = validate_path(directory)
    check_directory(safe_directory, validate=True)
    check_directory(join(safe_directory, "log"), validate=True)

    self.optimizer = torch.optim.Adam(
        self.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    best_val_loss = float("inf")

    for epoch in range(number_of_epochs):
        # Training phase
        self.train()
        train_loss = 0.0
        train_batches = 0
        train_pbar = tqdm(
            trainloader,
            desc=f"Epoch {epoch + 1}/{number_of_epochs} [Train]",
            leave=False,
            dynamic_ncols=True,
        )

        for batch, labels in train_pbar:
            self.optimizer.zero_grad()
            batch = batch.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            # Same squeeze as the validation pass; an unrestricted squeeze
            # could also collapse the batch dimension.
            predicts = torch.squeeze(self.forward(batch), dim=1)
            loss = self.evaluate(predicts, labels)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            train_batches += 1
            train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

        if train_batches == 0:
            raise ValueError(
                "trainloader yielded no batches; cannot compute the training loss."
            )
        avg_train_loss = train_loss / train_batches
        self.train_history.append(avg_train_loss)

        # Validation phase
        self.eval()
        val_loss = 0.0
        val_batches = 0
        val_pbar = tqdm(
            testloader,
            desc=f"Epoch {epoch + 1}/{number_of_epochs} [Val]",
            leave=False,
            dynamic_ncols=True,
        )

        with torch.no_grad():
            for batch, labels in val_pbar:
                batch = batch.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                predicts = torch.squeeze(self.forward(batch), dim=1)
                loss = self.evaluate(predicts, labels)

                val_loss += loss.item()
                val_batches += 1
                val_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

        if val_batches == 0:
            raise ValueError(
                "testloader yielded no batches; cannot compute the validation loss."
            )
        avg_val_loss = val_loss / val_batches
        self.validation_history.append(avg_val_loss)

        # Print epoch summary
        print(
            f"Epoch {epoch + 1}/{number_of_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
        )

        # Save best model under the validated directory, i.e. the one that
        # check_directory was called on above.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            self.save_weights(
                join(safe_directory, f"best_model_epoch_{epoch + 1}.pt")
            )
            print(f"Best model saved at epoch {epoch + 1}")

    # Save training history with path validation
    train_log_path = validate_path(
        join(safe_directory, "log", "train_log.pt"), allowed_extensions=[".pt"]
    )
    val_log_path = validate_path(
        join(safe_directory, "log", "validation_log.pt"), allowed_extensions=[".pt"]
    )
    torch.save(self.train_history, train_log_path)
    torch.save(self.validation_history, val_log_path)
    print("Training completed. History saved.")

forward(x)

Forward pass of the network.

Parameters:

  • x (Tensor) –

    Input tensor of shape (batch_size, 5, height, width).

Returns:

  • Tensor

    Output tensor of shape (batch_size, 3) representing the predicted rotation angles.

Source code in odak/learn/lensless/models.py
def forward(self, x):
    """
    Run the input through the sequential network and return its output.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (batch_size, 5, height, width).

    Returns
    -------
    torch.Tensor
        Output tensor of shape (batch_size, 3) representing the predicted rotation angles.
    """
    predicted_angles = self.network(x)
    return predicted_angles

init_layers()

Initialize the layers of the network.

The network architecture consists of:

  • Three convolutional layers with batch normalization and ReLU activation
  • Three max pooling layers
  • Five fully connected layers ending with a 3-dimensional output

Source code in odak/learn/lensless/models.py
def init_layers(self):
    """
    Build the sequential network and store it in `self.network`.

    The network architecture consists of:
    - Three convolutional layers with batch normalization and ReLU activation
    - Three max pooling layers
    - Five fully connected layers ending with a 3-dimensional output

    The assembled network is moved to `self.device`.
    """
    layers = OrderedDict()

    # Convolutional stages: (input channels, output channels, kernel size).
    conv_stages = ((5, 32, 3), (32, 64, 5), (64, 128, 7))
    for stage, (channels_in, channels_out, kernel) in enumerate(conv_stages, start=1):
        layers[f"conv{stage}"] = nn.Conv2d(
            channels_in, channels_out, kernel_size=kernel, padding=1
        )
        layers[f"bn{stage}"] = nn.BatchNorm2d(channels_out)
        layers[f"relu{stage}"] = nn.ReLU()
        layers[f"pool{stage}"] = nn.MaxPool2d(kernel_size=3)

    layers["flatten"] = nn.Flatten()

    # Hidden fully connected stages, each followed by batch norm and ReLU.
    widths = (6400, 2048, 1024, 512, 128)
    for stage, (features_in, features_out) in enumerate(
        zip(widths, widths[1:]), start=1
    ):
        layers[f"fc{stage}"] = nn.Linear(features_in, features_out)
        layers[f"fc_bn{stage}"] = nn.BatchNorm1d(features_out)
        layers[f"relu_fc{stage}"] = nn.ReLU()

    # Output layer: no normalization or activation after it.
    layers["fc5"] = nn.Linear(128, 3)

    self.network = nn.Sequential(layers).to(self.device)

load_weights(filename='./weights.pt')

Load weights for the network from a file.

Parameters:

  • filename (str, default: './weights.pt' ) –

    Path to load the weights from. Default is './weights.pt'.

Raises:

  • ValueError : If path validation fails or extension is not allowed.
  • TypeError : If filename is not a string.
Source code in odak/learn/lensless/models.py
def load_weights(self, filename="./weights.pt"):
    """
    Load weights for the network from a file.

    The tensors are mapped onto `self.device` while loading, and the network
    is switched to evaluation mode afterwards.

    Parameters
    ----------
    filename : str, optional
        Path to load the weights from. Default is './weights.pt'.

    Raises
    ------
    ValueError    : If path validation fails or extension is not allowed.
    TypeError     : If filename is not a string.
    """
    safe_path = validate_path(filename, allowed_extensions=[".pt", ".pth"])
    # map_location lets a checkpoint saved on another device be restored on
    # the device this model was created for.
    state_dict = torch.load(safe_path, map_location=self.device, weights_only=True)
    self.network.load_state_dict(state_dict)
    self.network.eval()

save_weights(filename='./weights.pt')

Save the current weights of the network to a file.

Parameters:

  • filename (str, default: './weights.pt' ) –

    Path to save the weights. Default is './weights.pt'.

Raises:

  • ValueError : If path validation fails or extension is not allowed.
  • TypeError : If filename is not a string.
Source code in odak/learn/lensless/models.py
def save_weights(self, filename="./weights.pt"):
    """
    Write the state dict of `self.network` to a file.

    Parameters
    ----------
    filename : str, optional
        Path to save the weights. Default is './weights.pt'.

    Raises
    ------
    ValueError    : If path validation fails or extension is not allowed.
    TypeError     : If filename is not a string.
    """
    permitted_suffixes = [".pt", ".pth"]
    # Validate the destination before anything is serialized.
    destination = validate_path(filename, allowed_extensions=permitted_suffixes)
    network_state = self.network.state_dict()
    torch.save(network_state, destination)