Source code for equiadapt.pointcloud.canonicalization_networks.vector_neuron_layers

"""
Layers for vector neuron networks

Taken from Vector Neurons: A General Framework for SO(3)-Equivariant Networks (https://arxiv.org/abs/2104.12229) paper and
their codebase https://github.com/FlyingGiraffe/vnn
"""

from typing import Tuple

import torch
import torch.nn as nn

# Small positive constant added to norms and denominators in the layers below
# so that divisions stay finite when a vector has (near-)zero length.
EPS = 1e-6


class VNLinear(nn.Module):
    """
    Vector Neuron Linear layer.

    Mixes the channel axis (dim 1) of a vector-valued feature tensor with a
    bias-free linear map; every other axis is left untouched.

    Methods:
        __init__: Initializes the VNLinear layer.
        forward: Performs forward pass of the VNLinear layer.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Build the bias-free channel-mixing map.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
        """
        super().__init__()
        self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear map along the channel axis.

        Args:
            x (torch.Tensor): Input tensor of shape [B, in_channels, 3, N_samples, ...].

        Returns:
            torch.Tensor: Output tensor of shape [B, out_channels, 3, N_samples, ...].
        """
        # nn.Linear acts on the last axis, so move channels there and back.
        channels_last = torch.transpose(x, 1, -1)
        mixed = self.map_to_feat(channels_last)
        return torch.transpose(mixed, 1, -1)
class VNBilinear(nn.Module):
    """
    Vector Neuron Bilinear layer.

    Combines vector features with a second (label) input through a bias-free
    ``nn.Bilinear`` map applied along the channel axis.

    Methods:
        __init__: Initializes the VNBilinear layer.
        forward: Performs forward pass of the VNBilinear layer.
    """

    def __init__(self, in_channels1: int, in_channels2: int, out_channels: int):
        """
        Build the bias-free bilinear map.

        Args:
            in_channels1 (int): Number of input channels for the first input.
            in_channels2 (int): Number of input channels for the second input.
            out_channels (int): Number of output channels.
        """
        super().__init__()
        self.map_to_feat = nn.Bilinear(
            in_channels1,
            in_channels2,
            out_channels,
            bias=False,
        )

    def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Apply the bilinear map to ``x`` and ``labels``.

        Args:
            x (torch.Tensor): Input features with channels on dim 1,
                e.g. [B, in_channels1, 3, ...].
            labels (torch.Tensor): Second input; it is tiled ``x.shape[2]``
                times along dim 1 and cast to float before use.
                NOTE(review): after tiling it has to match ``x.transpose(1, -1)``
                in every axis but the last -- presumably [B, 1, in_channels2]
                for a 3-D ``x``; confirm against callers.

        Returns:
            torch.Tensor: Output features with ``out_channels`` on dim 1.
        """
        n_components = x.shape[2]
        tiled_labels = labels.repeat(1, n_components, 1).float()
        # nn.Bilinear acts on the last axis, so move channels there and back.
        channels_last = torch.transpose(x, 1, -1)
        fused = self.map_to_feat(channels_last, tiled_labels)
        return torch.transpose(fused, 1, -1)
class VNSoftplus(nn.Module):
    """
    Vector Neuron Softplus layer.

    A smooth counterpart of the vector-neuron leaky ReLU: each feature vector
    is blended between itself and its projection away from a learned
    direction, with a blend weight that varies smoothly with the angle
    between the feature and that direction.

    Methods:
        __init__: Initializes the VNSoftplus layer.
        forward: Performs forward pass of the VNSoftplus layer.
    """

    def __init__(
        self,
        in_channels: int,
        share_nonlinearity: bool = False,
        negative_slope: float = 0.0,
    ):
        """
        Initializes a VNSoftplus layer.

        Args:
            in_channels (int): Number of input channels.
            share_nonlinearity (bool): If True a single direction is learned
                and shared by all channels; otherwise one direction per channel.
            negative_slope (float): Weight of the untouched input in the output
                (the remaining ``1 - negative_slope`` goes to the activated part).
        """
        super().__init__()
        if share_nonlinearity:
            self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
        else:
            self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass of the VNSoftplus layer.

        Args:
            x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...].

        Returns:
            torch.Tensor: Output tensor of shape [B, N_feat, 3, N_samples, ...].
        """
        # Learned direction(s); nn.Linear acts on the last axis.
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)

        cos_angle = dotprod / (
            torch.norm(x, dim=2, keepdim=True) * torch.norm(d, dim=2, keepdim=True)
            + EPS
        )
        # cos^2(theta / 2) == (1 + cos(theta)) / 2.  Using the identity avoids
        # torch.acos, which yields NaN when rounding pushes the ratio outside
        # [-1, 1] and has an unbounded gradient at +/-1 (x parallel to d).
        mask = 0.5 * (1.0 + torch.clamp(cos_angle, -1.0, 1.0))

        d_norm_sq = (d * d).sum(2, keepdim=True)
        # Component of x orthogonal to d (used where x points "against" d).
        x_projected = x - (dotprod / (d_norm_sq + EPS)) * d
        x_out = self.negative_slope * x + (1 - self.negative_slope) * (
            mask * x + (1 - mask) * x_projected
        )
        return x_out
class VNLeakyReLU(nn.Module):
    """
    Vector Neuron Leaky ReLU layer.

    VNLeakyReLU keeps feature vectors that have a non-negative dot product
    with a learned direction and, for the others, removes their component
    along that direction; the result is mixed with the raw input according
    to ``negative_slope``.

    Methods:
        __init__: Initializes the VNLeakyReLU layer.
        forward: Performs forward pass of the VNLeakyReLU layer.
    """

    def __init__(
        self,
        in_channels: int,
        share_nonlinearity: bool = False,
        negative_slope: float = 0.2,
    ):
        """
        Initializes the VNLeakyReLU layer.

        Args:
            in_channels (int): Number of input channels.
            share_nonlinearity (bool, optional): If True, one direction is
                learned and shared by all channels. If False, one direction
                is learned per channel. Defaults to False.
            negative_slope (float, optional): Weight given to the unmodified
                input in the output. Defaults to 0.2.
        """
        super().__init__()
        n_directions = 1 if share_nonlinearity else in_channels
        self.map_to_dir = nn.Linear(in_channels, n_directions, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the VNLeakyReLU module.

        Args:
            x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...].

        Returns:
            torch.Tensor: Output tensor of the same shape as ``x``.
        """
        # Learned direction(s); nn.Linear acts on the last axis.
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)
        d_norm_sq = (d * d).sum(2, keepdim=True)

        # 1 where x agrees with d (kept as is), 0 where it must be projected.
        keep = (dotprod >= 0).float()
        projected = x - (dotprod / (d_norm_sq + EPS)) * d
        activated = keep * x + (1 - keep) * projected

        return self.negative_slope * x + (1 - self.negative_slope) * activated
class VNLinearLeakyReLU(nn.Module):
    """
    Vector Neuron Linear Leaky ReLU layer.

    Chains three steps: a bias-free linear map over channels, vector-neuron
    batch normalization, and a vector-neuron leaky ReLU whose direction is
    predicted from the layer input.

    Methods:
        __init__: Initializes the VNLinearLeakyReLU layer.
        forward: Performs forward pass of the VNLinearLeakyReLU layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dim: int = 5,
        share_nonlinearity: bool = False,
        negative_slope: float = 0.2,
    ):
        """
        Initializes the VNLinearLeakyReLU layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            dim (int, optional): Passed to ``VNBatchNorm`` to select the
                batch-norm variant. Defaults to 5.
            share_nonlinearity (bool, optional): If True, a single direction is
                shared by all output channels. Defaults to False.
            negative_slope (float, optional): Weight given to the
                un-activated features in the output. Defaults to 0.2.
        """
        super().__init__()
        self.dim = dim
        self.negative_slope = negative_slope

        self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
        self.batchnorm = VNBatchNorm(out_channels, dim=dim)

        n_directions = 1 if share_nonlinearity else out_channels
        self.map_to_dir = nn.Linear(in_channels, n_directions, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the VNLinearLeakyReLU layer.

        Args:
            x (torch.Tensor): Input tensor of shape [B, in_channels, 3, N_samples, ...]

        Returns:
            torch.Tensor: Output tensor of shape [B, out_channels, 3, N_samples, ...]
        """
        channels_last = x.transpose(1, -1)

        # Linear map followed by batch normalization of the features.
        p = self.batchnorm(self.map_to_feat(channels_last).transpose(1, -1))

        # Direction(s) for the non-linearity, predicted from the raw input.
        d = self.map_to_dir(channels_last).transpose(1, -1)
        dotprod = (p * d).sum(2, keepdim=True)
        d_norm_sq = (d * d).sum(2, keepdim=True)

        # Leaky ReLU: keep p where it agrees with d, otherwise drop its
        # component along d.
        keep = (dotprod >= 0).float()
        projected = p - (dotprod / (d_norm_sq + EPS)) * d
        activated = keep * p + (1 - keep) * projected

        return self.negative_slope * p + (1 - self.negative_slope) * activated
class VNBatchNorm(nn.Module):
    """
    Vector Neuron Batch Normalization layer.

    VNBatchNorm batch-normalizes the lengths of the feature vectors (norm
    over dim 2) and rescales each vector to its normalized length, leaving
    its direction unchanged.

    Methods:
        __init__: Initializes the VNBatchNorm layer.
        forward: Performs forward pass of the VNBatchNorm layer.
    """

    def __init__(self, num_features: int, dim: int):
        """
        Initializes the VNBatchNorm layer.

        Args:
            num_features (int): Number of feature channels.
            dim (int): Number of dimensions of the input tensor. 3 or 4 uses
                ``nn.BatchNorm1d`` on the norms, 5 uses ``nn.BatchNorm2d``.

        Raises:
            ValueError: If ``dim`` is not 3, 4 or 5.
        """
        super().__init__()
        self.dim = dim
        # Attribute names are kept as-is so existing state dicts still load.
        if dim in (3, 4):
            self.bn1d = nn.BatchNorm1d(num_features)
        elif dim == 5:
            self.bn2d = nn.BatchNorm2d(num_features)
        else:
            # Fail at construction instead of with an UnboundLocalError in forward.
            raise ValueError(f"VNBatchNorm supports dim in (3, 4, 5), got {dim!r}.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Vector Neuron Batch Normalization layer.

        Args:
            x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...].

        Returns:
            torch.Tensor: Tensor of the same shape as ``x`` whose vectors
            have batch-normalized lengths.

        Raises:
            ValueError: If ``self.dim`` was changed to an unsupported value.
        """
        # Vector lengths; EPS keeps the division below finite for zero vectors.
        norm = torch.norm(x, dim=2) + EPS
        if self.dim in (3, 4):
            norm_bn = self.bn1d(norm)
        elif self.dim == 5:
            norm_bn = self.bn2d(norm)
        else:
            raise ValueError(
                f"VNBatchNorm supports dim in (3, 4, 5), got {self.dim!r}."
            )
        # Re-insert the vector axis so the scale broadcasts over it.
        norm = norm.unsqueeze(2)
        norm_bn = norm_bn.unsqueeze(2)
        return x / norm * norm_bn
class VNMaxPool(nn.Module):
    """
    Vector Neuron Max Pooling layer.

    VNMaxPool pools over the last axis of the input: for every channel it
    selects the feature vector whose dot product with a learned direction
    is largest.

    Methods:
        __init__: Initializes the VNMaxPool layer.
        forward: Performs forward pass of the VNMaxPool layer.
    """

    def __init__(self, in_channels: int):
        """
        Initializes a VNMaxPool layer.

        Args:
            in_channels (int): The number of input channels.
        """
        super().__init__()
        self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs vector neuron max pooling on the input tensor.

        Args:
            x (torch.Tensor): Point features of shape [B, N_feat, 3, ..., N_pool];
                the last axis is the one being pooled.

        Returns:
            torch.Tensor: Pooled tensor of shape [B, N_feat, 3, ...]
            (the last axis of ``x`` is removed).
        """
        # Learned per-channel directions; nn.Linear acts on the last axis.
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)

        # Position along the pooled axis with the largest alignment; the
        # vector axis (dim 2) has size 1 here.
        idx = dotprod.max(dim=-1, keepdim=True)[1]
        # Use the same position for all three vector components and gather.
        # This replaces a torch.meshgrid-based advanced index, which warns
        # about the missing `indexing` argument and builds full index grids.
        idx = idx.expand(*x.shape[:-1], 1)
        x_max = torch.gather(x, -1, idx).squeeze(-1)
        return x_max
def mean_pool(x: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
    """
    Average a tensor along one dimension.

    Args:
        x (torch.Tensor): The input tensor.
        dim (int, optional): The dimension to average over. Default is -1.
        keepdim (bool, optional): Whether to keep the averaged dimension
            (with size 1) in the output. Default is False.

    Returns:
        torch.Tensor: The mean pooled tensor.
    """
    pooled = torch.mean(x, dim=dim, keepdim=keepdim)
    return pooled
class VNStdFeature(nn.Module):
    """
    Vector Neuron Standard Feature module.

    Predicts a 3x3 frame from the input point features with two
    VNLinearLeakyReLU layers followed by a bias-free linear layer, and
    expresses the input features in that frame.

    Attributes:
        dim (int): Number of dimensions of the input tensor (3, 4 or 5).
        normalize_frame (bool): Whether the frame is orthonormalized.

    Methods:
        __init__: Initializes the VNStdFeature module.
        forward: Performs forward pass of the VNStdFeature module.

    Shape:
        - Input: (B, N_feat, 3, N_samples, ...)
        - Output:
            - x_std: (B, N_feat, 3, N_samples, ...)
            - z0: (B, 3, 3, N_samples, ...)

    Example:
        >>> model = VNStdFeature(in_channels=64, dim=4, normalize_frame=True)
        >>> input = torch.randn(2, 64, 3, 100)
        >>> output, frame_vectors = model(input)
    """

    def __init__(
        self,
        in_channels: int,
        dim: int = 4,
        normalize_frame: bool = False,
        share_nonlinearity: bool = False,
        negative_slope: float = 0.2,
    ):
        """
        Initializes the VNStdFeature layer.

        Args:
            in_channels (int): Number of input channels.
            dim (int, optional): Number of dimensions of the input tensor
                (3, 4 or 5). Defaults to 4.
            normalize_frame (bool, optional): If True, two vectors are predicted
                and completed to an orthonormal frame; if False, three vectors
                are predicted and used as they are. Defaults to False.
            share_nonlinearity (bool, optional): Forwarded to the
                VNLinearLeakyReLU layers. Defaults to False.
            negative_slope (float, optional): Negative slope of the LeakyReLU
                activation function. Defaults to 0.2.
        """
        super().__init__()
        self.dim = dim
        self.normalize_frame = normalize_frame

        self.vn1 = VNLinearLeakyReLU(
            in_channels,
            in_channels // 2,
            dim=dim,
            share_nonlinearity=share_nonlinearity,
            negative_slope=negative_slope,
        )
        self.vn2 = VNLinearLeakyReLU(
            in_channels // 2,
            in_channels // 4,
            dim=dim,
            share_nonlinearity=share_nonlinearity,
            negative_slope=negative_slope,
        )
        # Two vectors suffice when the frame is orthonormalized: the third
        # one is obtained as their cross product in forward().
        if normalize_frame:
            self.vn_lin = nn.Linear(in_channels // 4, 2, bias=False)
        else:
            self.vn_lin = nn.Linear(in_channels // 4, 3, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the VNStdFeature module.

        Args:
            x (torch.Tensor): Input point features of shape (B, N_feat, 3, N_samples, ...).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(x_std, z0)`` where ``x_std``
            holds the input features expressed in the predicted frame and
            ``z0`` is that frame.

        Raises:
            ValueError: If ``self.dim`` is not 3, 4 or 5.

        Note:
            - The frame is always returned; it is orthonormalized
              (Gram-Schmidt plus cross product) only if `normalize_frame` is True.
        """
        z0 = self.vn2(self.vn1(x))
        z0 = self.vn_lin(z0.transpose(1, -1)).transpose(1, -1)

        if self.normalize_frame:
            # Gram-Schmidt on the two predicted vectors; after selecting a
            # channel the vector components live on dim 1.
            v1 = z0[:, 0, :]
            v1_norm = torch.sqrt((v1 * v1).sum(1, keepdim=True))
            u1 = v1 / (v1_norm + EPS)
            v2 = z0[:, 1, :]
            v2 = v2 - (v2 * u1).sum(1, keepdim=True) * u1
            v2_norm = torch.sqrt((v2 * v2).sum(1, keepdim=True))
            u2 = v2 / (v2_norm + EPS)

            # dim must be explicit: without it torch.cross uses the first
            # axis of size 3, which is the batch axis whenever B == 3.
            u3 = torch.cross(u1, u2, dim=1)
            z0 = torch.stack([u1, u2, u3], dim=1).transpose(1, 2)
        else:
            z0 = z0.transpose(1, 2)

        # Express every feature vector in the frame (contract over dim 2 of x).
        if self.dim == 4:
            x_std = torch.einsum("bijm,bjkm->bikm", x, z0)
        elif self.dim == 3:
            x_std = torch.einsum("bij,bjk->bik", x, z0)
        elif self.dim == 5:
            x_std = torch.einsum("bijmn,bjkmn->bikmn", x, z0)
        else:
            raise ValueError(
                f"VNStdFeature supports dim in (3, 4, 5), got {self.dim!r}."
            )

        return x_std, z0