equiadapt package
Subpackages
- equiadapt.common package
- Submodules
- equiadapt.common.basecanonicalization module
BaseCanonicalization
ContinuousGroupCanonicalization
ContinuousGroupCanonicalization.canonicalization_network
ContinuousGroupCanonicalization.beta
ContinuousGroupCanonicalization.__init__()
ContinuousGroupCanonicalization.canonicalizationnetworkout_to_groupelement()
ContinuousGroupCanonicalization.canonicalize()
ContinuousGroupCanonicalization.invert_canonicalization()
ContinuousGroupCanonicalization.get_prior_regularization_loss()
ContinuousGroupCanonicalization.get_identity_metric()
DiscreteGroupCanonicalization
DiscreteGroupCanonicalization.canonicalization_network
DiscreteGroupCanonicalization.beta
DiscreteGroupCanonicalization.gradient_trick
DiscreteGroupCanonicalization.__init__()
DiscreteGroupCanonicalization.groupactivations_to_groupelementonehot()
DiscreteGroupCanonicalization.canonicalize()
DiscreteGroupCanonicalization.invert_canonicalization()
DiscreteGroupCanonicalization.get_prior_regularization_loss()
DiscreteGroupCanonicalization.get_identity_metric()
IdentityCanonicalization
IdentityCanonicalization.canonicalization_network
IdentityCanonicalization.__init__()
IdentityCanonicalization.canonicalize()
IdentityCanonicalization.get_identity_metric()
IdentityCanonicalization.get_prior_regularization_loss()
IdentityCanonicalization.invert_canonicalization()
- equiadapt.common.utils module
- Module contents
BaseCanonicalization
ContinuousGroupCanonicalization
ContinuousGroupCanonicalization.canonicalization_network
ContinuousGroupCanonicalization.beta
ContinuousGroupCanonicalization.__init__()
ContinuousGroupCanonicalization.canonicalizationnetworkout_to_groupelement()
ContinuousGroupCanonicalization.canonicalize()
ContinuousGroupCanonicalization.invert_canonicalization()
ContinuousGroupCanonicalization.get_prior_regularization_loss()
ContinuousGroupCanonicalization.get_identity_metric()
DiscreteGroupCanonicalization
DiscreteGroupCanonicalization.canonicalization_network
DiscreteGroupCanonicalization.beta
DiscreteGroupCanonicalization.gradient_trick
DiscreteGroupCanonicalization.__init__()
DiscreteGroupCanonicalization.groupactivations_to_groupelementonehot()
DiscreteGroupCanonicalization.canonicalize()
DiscreteGroupCanonicalization.invert_canonicalization()
DiscreteGroupCanonicalization.get_prior_regularization_loss()
DiscreteGroupCanonicalization.get_identity_metric()
IdentityCanonicalization
IdentityCanonicalization.canonicalization_network
IdentityCanonicalization.__init__()
IdentityCanonicalization.canonicalize()
IdentityCanonicalization.get_identity_metric()
IdentityCanonicalization.get_prior_regularization_loss()
IdentityCanonicalization.invert_canonicalization()
LieParameterization
gram_schmidt()
- equiadapt.images package
- Subpackages
- equiadapt.images.canonicalization package
- equiadapt.images.canonicalization_networks package
- Submodules
- equiadapt.images.canonicalization_networks.custom_equivariant_networks module
- equiadapt.images.canonicalization_networks.custom_group_equivariant_layers module
- equiadapt.images.canonicalization_networks.custom_nonequivariant_networks module
- equiadapt.images.canonicalization_networks.escnn_networks module
- Module contents
- Submodules
- equiadapt.images.utils module
- Module contents
ContinuousGroupImageCanonicalization
ContinuousGroupImageCanonicalization.__init__()
ContinuousGroupImageCanonicalization.get_rotation_matrix_from_vector()
ContinuousGroupImageCanonicalization.get_groupelement()
ContinuousGroupImageCanonicalization.transformations_before_canonicalization_network_forward()
ContinuousGroupImageCanonicalization.get_group_from_out_vectors()
ContinuousGroupImageCanonicalization.canonicalize()
ContinuousGroupImageCanonicalization.invert_canonicalization()
ConvNetwork
CustomEquivariantNetwork
DiscreteGroupImageCanonicalization
DiscreteGroupImageCanonicalization.__init__()
DiscreteGroupImageCanonicalization.groupactivations_to_groupelement()
DiscreteGroupImageCanonicalization.get_groupelement()
DiscreteGroupImageCanonicalization.transformations_before_canonicalization_network_forward()
DiscreteGroupImageCanonicalization.canonicalize()
DiscreteGroupImageCanonicalization.invert_canonicalization()
DiscreteGroupImageCanonicalization.get_group_activations()
ESCNNEquivariantNetwork
ESCNNSteerableNetwork
ESCNNWRNEquivariantNetwork
ESCNNWideBasic
ESCNNWideBottleneck
GroupEquivariantImageCanonicalization
OptimizedGroupEquivariantImageCanonicalization
OptimizedGroupEquivariantImageCanonicalization.__init__()
OptimizedGroupEquivariantImageCanonicalization.rotate_and_maybe_reflect()
OptimizedGroupEquivariantImageCanonicalization.group_augment()
OptimizedGroupEquivariantImageCanonicalization.get_group_activations()
OptimizedGroupEquivariantImageCanonicalization.get_optimization_specific_loss()
OptimizedSteerableImageCanonicalization
OptimizedSteerableImageCanonicalization.__init__()
OptimizedSteerableImageCanonicalization.get_rotation_matrix_from_vector()
OptimizedSteerableImageCanonicalization.group_augment()
OptimizedSteerableImageCanonicalization.get_groupelement()
OptimizedSteerableImageCanonicalization.get_optimization_specific_loss()
ResNet18Network
RotationEquivariantConv
RotationEquivariantConvLift
RotoReflectionEquivariantConv
RotoReflectionEquivariantConvLift
SteerableImageCanonicalization
flip_boxes()
flip_masks()
get_action_on_image_features()
roll_by_gather()
rotate_boxes()
rotate_masks()
rotate_points()
- Subpackages
- equiadapt.nbody package
- equiadapt.pointcloud package
- Subpackages
- Module contents
Module contents
- class equiadapt.BaseCanonicalization(canonicalization_network: Module)
Bases:
Module
Base class for all canonicalization methods.
Subclasses implement the canonicalize method to define the specific canonicalization process.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Canonicalizes the input data and, optionally, any targets that must be transformed along with it.
- Parameters:
x – input data
targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation
**kwargs – additional arguments
- Returns:
the canonicalized data, or a tuple of the canonicalized data and targets when targets are given
- forward(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Forward method of the canonicalization: takes the input data and returns its canonicalized version.
- Parameters:
x – input data
targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation
**kwargs – additional arguments
- Returns:
canonicalized version of the input data (and of the targets, if given)
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor
Maps the output of the prediction network on canonicalized data back to the original data orientation.
- Parameters:
x_canonicalized_out – output of the prediction network for canonicalized data
- Returns:
output of the prediction network for the original data orientation, obtained using the group element that was used to canonicalize the original data
- Return type:
torch.Tensor
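The canonicalize / invert_canonicalization contract can be illustrated with a small, self-contained stand-in. This is a sketch under stated assumptions, not code from equiadapt: the class name HypotheticalFlipCanonicalization and its canonicalization_info_dict attribute are invented for illustration, targets are not handled, and a fixed left/right mass comparison replaces a learned canonicalization network. The group is {identity, horizontal flip}. An image and its mirror image map to the same canonical input, and invert_canonicalization flips dense, image-shaped predictions back.

```python
import torch
from torch import nn


class HypotheticalFlipCanonicalization(nn.Module):
    """Toy canonicalizer for the two-element group {identity, horizontal flip}."""

    def __init__(self) -> None:
        super().__init__()
        # Hypothetical: stores the group element chosen for each sample in the last call.
        self.canonicalization_info_dict: dict = {}

    def canonicalize(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, height, width)
        width = x.shape[-1]
        left = x[..., : width // 2].sum(dim=(1, 2, 3))
        right = x[..., width - width // 2 :].sum(dim=(1, 2, 3))
        flip = right > left  # (batch,) boolean group element per sample
        self.canonicalization_info_dict = {"flip": flip}
        flipped = torch.flip(x, dims=[-1])
        return torch.where(flip[:, None, None, None], flipped, x)

    def invert_canonicalization(self, out: torch.Tensor) -> torch.Tensor:
        # Undo the flip on dense outputs shaped like the input (e.g. masks).
        flip = self.canonicalization_info_dict["flip"]
        return torch.where(flip[:, None, None, None], torch.flip(out, dims=[-1]), out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.canonicalize(x)
```

Because the chosen group element depends only on the input, any prediction network applied to the canonical image is invariant to horizontal flips (ties in the mass comparison aside). Dense outputs become flip-equivariant after invert_canonicalization.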
- class equiadapt.ContinuousGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0)
Bases:
BaseCanonicalization
This class represents a continuous group canonicalization method.
Continuous group canonicalization maps the input data to a canonical form by applying an element of a continuous group (for example, rotations by arbitrary angles) derived from the output of the canonicalization network. This class overrides the methods of BaseCanonicalization to implement it.
- canonicalization_network
The network used for canonicalization.
- Type:
torch.nn.Module
- canonicalizationnetworkout_to_groupelement()
Converts the output of the canonicalization network to a group element in a differentiable manner.
- canonicalizationnetworkout_to_groupelement(group_activations: Tensor) → Tensor
Converts the output of the canonicalization network to a group element in a differentiable manner.
- Parameters:
group_activations (torch.Tensor) – The activations for each group element.
- Returns:
The group element.
- Return type:
torch.Tensor
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Canonicalizes the input data.
- Parameters:
x (torch.Tensor) – The input data.
targets (List, optional) – Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
The canonicalized input data and, if given, the targets. It also updates a dictionary containing information about the canonicalization.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_identity_metric() → Tensor
Gets the identity metric.
The identity metric is calculated as 1 minus the mean of the mean squared error between the group element matrix representation and the identity matrix.
- Returns:
The identity metric.
- Return type:
torch.Tensor
- get_prior_regularization_loss() → Tensor
Gets the prior regularization loss.
The prior regularization loss is calculated as the mean squared error between the group element matrix representation and the identity matrix.
- Returns:
The prior regularization loss.
- Return type:
torch.Tensor
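The prior regularization loss and the identity metric can be written directly from their definitions. The following is a minimal sketch. It assumes the group elements are already available as a (batch, n, n) tensor of matrix representations. The function names are hypothetical, and the library's own implementation may obtain and reduce these matrices differently.

```python
import torch
import torch.nn.functional as F


def prior_regularization_loss(group_matrices: torch.Tensor) -> torch.Tensor:
    """MSE between the group-element matrices (batch, n, n) and the identity."""
    n = group_matrices.shape[-1]
    identity = torch.eye(n, dtype=group_matrices.dtype, device=group_matrices.device)
    return F.mse_loss(group_matrices, identity.expand_as(group_matrices))


def identity_metric(group_matrices: torch.Tensor) -> torch.Tensor:
    """1 minus the batch mean of the per-sample MSE to the identity."""
    n = group_matrices.shape[-1]
    identity = torch.eye(n, dtype=group_matrices.dtype, device=group_matrices.device)
    per_sample_mse = ((group_matrices - identity) ** 2).mean(dim=(-2, -1))
    return 1.0 - per_sample_mse.mean()
```

Every sample contributes a matrix of the same size, so the mean of the per-sample errors equals the overall mean squared error. In this sketch the identity metric is therefore 1 minus the prior regularization loss.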
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor
Inverts the canonicalization.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The output for the original data orientation.
- Return type:
torch.Tensor
- class equiadapt.ContinuousGroupImageCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig, in_shape: tuple)
Bases:
ContinuousGroupCanonicalization
This class represents a continuous group image canonicalization model.
The model is designed to be equivariant under a continuous group of transformations, which can include rotations and reflections. Other specific continuous group image canonicalization classes can be derived from this class.
- get_rotation_matrix_from_vector()
This method takes the input vector and returns the rotation matrix.
- transformations_before_canonicalization_network_forward()
Applies transformations to the input image before forwarding it through the canonicalization network.
- get_group_from_out_vectors()
This method takes the output of the canonicalization network and returns the group element.
- invert_canonicalization()
Inverts the canonicalization process on the output of the canonicalized image.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Takes an image as input and returns the canonicalized image.
- Parameters:
x (torch.Tensor) – The input image.
targets (Optional[List]) – The targets, if any.
- Returns:
The canonicalized image, and the targets if given.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_group_from_out_vectors(out_vectors: Tensor) → Tuple[dict, Tensor]
Takes the output of the canonicalization network and returns the group element.
- Parameters:
out_vectors (torch.Tensor) – output of the canonicalization network
- Returns:
the group element (dict) and the group element representation (torch.Tensor)
- Return type:
Tuple[dict, torch.Tensor]
- get_groupelement(x: Tensor) → dict
Maps the input image to the group element.
- Parameters:
x (torch.Tensor) – input image
- Returns:
group element
- Return type:
dict
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor
Inverts the canonicalization process on the output of the canonicalized image.
- Parameters:
x_canonicalized_out (torch.Tensor) – The output of the canonicalized image.
- Returns:
The output corresponding to the original image.
- Return type:
torch.Tensor
- class equiadapt.ContinuousGroupPointcloudCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig)
Bases:
ContinuousGroupCanonicalization
This class represents a continuous group point cloud canonicalization.
- Parameters:
canonicalization_network (torch.nn.Module) – The canonicalization network.
canonicalization_hyperparams (DictConfig) – The hyperparameters for canonicalization.
- device
The device on which the operations are performed.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Takes a point cloud as input and returns the canonicalized point cloud.
- Parameters:
x (torch.Tensor) – The input point cloud.
targets (Optional[List]) – The list of targets (optional).
**kwargs (Any) – Additional keyword arguments.
- Returns:
The canonicalized point cloud.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_groupelement(x: Tensor) → dict
Maps the input point cloud to the group element.
- Parameters:
x (torch.Tensor) – The input point cloud.
- Returns:
The group element.
- Return type:
dict
- Raises:
NotImplementedError – If the method is not implemented.
- class equiadapt.ConvNetwork(in_shape: tuple, out_channels: int, kernel_size: int, num_layers: int = 2, out_vector_size: int = 128)
Bases:
Module
This class represents a convolutional neural network.
The network consists of a sequence of convolutional layers, each followed by batch normalization and a GELU activation function. The number of output channels of the convolutional layers increases after every third layer. The network ends with a fully connected layer.
- forward(x: Tensor) → Tensor
Performs a forward pass through the network.
- Parameters:
x (torch.Tensor) – The input data. It should have the shape (batch_size, in_channels, height, width).
- Returns:
The output of the network. It has the shape (batch_size, out_vector_size).
- Return type:
torch.Tensor
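The architecture described for ConvNetwork can be sketched in plain PyTorch. This is a hypothetical reconstruction from the prose above, not the library's code. The class name is invented. The channel schedule (doubling after every third layer), the "same" padding, the global average pooling before the fully connected layer, and reading the channel count from in_shape[-3] are all assumptions.

```python
import torch
from torch import nn


class HypotheticalConvNetwork(nn.Module):
    """Conv -> BatchNorm -> GELU blocks, widening every third layer, then a linear head."""

    def __init__(self, in_shape, out_channels, kernel_size, num_layers=2, out_vector_size=128):
        super().__init__()
        channels = in_shape[-3]  # assumption: in_shape ends with (C, H, W)
        layers = []
        for i in range(num_layers):
            # Assumption: the width doubles after every third layer.
            layer_out = out_channels * (2 ** (i // 3))
            layers += [
                nn.Conv2d(channels, layer_out, kernel_size, padding=kernel_size // 2),
                nn.BatchNorm2d(layer_out),
                nn.GELU(),
            ]
            channels = layer_out
        self.features = nn.Sequential(*layers)
        # Assumption: global average pooling before the fully connected layer.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels, out_vector_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, in_channels, height, width)
        x = self.pool(self.features(x)).flatten(1)
        return self.fc(x)  # (batch_size, out_vector_size)
```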
- class equiadapt.CustomEquivariantNetwork(in_shape: Tuple[int, int, int, int], out_channels: int, kernel_size: int, group_type: str = 'rotation', num_rotations: int = 4, num_layers: int = 1, device: str = 'cpu')
Bases:
Module
This class represents a custom equivariant network.
The network is equivariant to a specified group, which can be either the rotation group or the roto-reflection group. The network consists of a sequence of equivariant convolutional layers, each followed by a ReLU activation function.
- class equiadapt.DiscreteGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0, gradient_trick: str = 'straight_through')
Bases:
BaseCanonicalization
This class represents a discrete group canonicalization method.
Discrete group canonicalization maps the input data to a canonical form by applying an element of a discrete group (for example, rotations by multiples of 90°), selected from the group activations produced by the canonicalization network. This class overrides the methods of BaseCanonicalization to implement it.
- canonicalization_network
The network used for canonicalization.
- Type:
torch.nn.Module
- gradient_trick
The method used for backpropagation through the discrete operation. Defaults to “straight_through”.
- Type:
str
- groupactivations_to_groupelementonehot()
Converts group activations to one-hot encoded group elements in a differentiable manner.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Canonicalizes the input data.
- Parameters:
x (torch.Tensor) – The input data.
targets (List, optional) – Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
The canonicalized input data and, if given, the targets. It also updates a dictionary containing information about the canonicalization.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_identity_metric() → Tensor
Gets the identity metric.
- Returns:
The identity metric.
- Return type:
torch.Tensor
- get_prior_regularization_loss() → Tensor
Gets the prior regularization loss.
- Returns:
The prior regularization loss.
- Return type:
torch.Tensor
- groupactivations_to_groupelementonehot(group_activations: Tensor) → Tensor
Converts group activations to one-hot encoded group elements in a differentiable manner.
- Parameters:
group_activations (torch.Tensor) – The activations for each group element.
- Returns:
The one-hot encoding of the group elements.
- Return type:
torch.Tensor
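The straight-through trick named by the default gradient_trick can be sketched as follows. The forward pass returns a hard one-hot vector for the highest-scoring group element, while gradients are taken from a softmax over the activations. The function name is hypothetical. Treating beta as a softmax inverse temperature is an assumption about how that parameter is used.

```python
import torch
import torch.nn.functional as F


def groupactivations_to_onehot_straight_through(
    group_activations: torch.Tensor, beta: float = 1.0
) -> torch.Tensor:
    """Hard one-hot in the forward pass, softmax gradient in the backward pass.

    group_activations: (batch, num_group_elements)
    beta: assumed to act as an inverse temperature on the softmax.
    """
    soft = F.softmax(beta * group_activations, dim=-1)
    index = soft.argmax(dim=-1)
    hard = F.one_hot(index, num_classes=group_activations.shape[-1]).to(soft.dtype)
    # Forward value equals `hard`; gradients flow only through `soft`.
    return hard - soft.detach() + soft
```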
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor
Inverts the canonicalization.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The output for the original data orientation.
- Return type:
torch.Tensor
- class equiadapt.DiscreteGroupImageCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig, in_shape: tuple)
Bases:
DiscreteGroupCanonicalization
This class represents a discrete group image canonicalization model.
The model is designed to be equivariant under a discrete group of transformations, which can include rotations and reflections. Other discrete group canonicalizers can be derived from this class.
- groupactivations_to_groupelement()
Takes the activations for each group element as input and returns the group element.
- transformations_before_canonicalization_network_forward()
Applies transformations to the input images before passing them through the canonicalization network.
- invert_canonicalization()
Inverts the canonicalization of the output of the canonicalized image.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Canonicalizes the input images.
- Parameters:
x (torch.Tensor) – The input images.
targets (Optional[List], optional) – The targets for instance segmentation. Defaults to None.
**kwargs (Any) – Additional keyword arguments.
- Returns:
The canonicalized image, and optionally the targets.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_group_activations(x: Tensor) → Tensor
Gets the group activations for the input images.
- Parameters:
x (torch.Tensor) – The input images.
- Returns:
The group activations.
- Return type:
torch.Tensor
- groupactivations_to_groupelement(group_activations: Tensor) → dict
Takes the activations for each group element as input and returns the group element.
- Parameters:
group_activations (torch.Tensor) – activations for each group element.
- Returns:
group element.
- Return type:
dict
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor
Inverts the canonicalization of the output of the canonicalized image.
- Parameters:
x_canonicalized_out (torch.Tensor) – The output of the canonicalized image.
**kwargs (Any) – Additional keyword arguments.
- Returns:
The output corresponding to the original image.
- Return type:
torch.Tensor
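For the group of four 90° rotations (C4) acting on square images, canonicalization and its inversion reduce to two steps. First, choose a rotation per image. Second, undo that rotation on dense outputs. The sketch below is illustrative only. The function names and the convention that activation r scores a rotation by r × 90° are assumptions. score_rotations is a hypothetical stand-in for the canonicalization network: it scores every rotated copy with a fixed, non-equivariant rule. Scoring every rotated copy makes the activations shift cyclically when the input is rotated, and this cyclic shift is what makes the canonical image rotation invariant.

```python
import torch


def score_rotations(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the canonicalization network.

    Scores each rotated copy of x (batch, C, H, W) by the sum of its top-left
    2x2 patch and returns activations of shape (batch, 4).
    """
    return torch.stack(
        [torch.rot90(x, r, dims=(-2, -1))[..., :2, :2].sum(dim=(1, 2, 3)) for r in range(4)],
        dim=-1,
    )


def canonicalize_c4(x: torch.Tensor, group_activations: torch.Tensor):
    """Rotate each image by the multiple of 90 degrees with the highest activation."""
    k = group_activations.argmax(dim=-1)  # chosen group element per sample
    x_canonical = torch.stack(
        [torch.rot90(img, int(r), dims=(-2, -1)) for img, r in zip(x, k)]
    )
    return x_canonical, k


def invert_canonicalization_c4(out: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Rotate dense outputs (e.g. segmentation maps) back to the original orientation."""
    return torch.stack(
        [torch.rot90(o, -int(r), dims=(-2, -1)) for o, r in zip(out, k)]
    )
```

Rotating the input by 90° only changes which rotated copy is evaluated first, so both the original and the rotated input select the same canonical image (ties in the scores aside).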
- class equiadapt.ESCNNEquivariantNetwork(in_shape: tuple, out_channels: int, kernel_size: int, group_type: str = 'rotation', num_rotations: int = 4, num_layers: int = 1)
Bases:
Module
This class represents an Equivariant Convolutional Neural Network (Equivariant CNN).
The network is equivariant to a group of transformations, which can be either rotations or roto-reflections. The network consists of a sequence of equivariant convolutional layers, each followed by batch normalization, a ReLU activation function, and dropout. The number of output channels of the convolutional layers is the same for all layers.
- forward(x: Tensor) → Tensor
Performs a forward pass through the network.
- Parameters:
x (torch.Tensor) – The input data. It should have the shape (batch_size, in_channels, height, width).
- Returns:
The output of the network. It has the shape (batch_size, num_group_elements).
- Return type:
torch.Tensor
- class equiadapt.ESCNNSteerableNetwork(in_shape: tuple, out_channels: int, kernel_size: int = 9, group_type: str = 'rotation', num_layers: int = 1)
Bases:
Module
This class represents a Steerable Equivariant Convolutional Neural Network (Equivariant CNN).
The network is equivariant under all planar rotations. The network consists of a sequence of equivariant convolutional layers, each followed by batch normalization and a FourierELU activation function. The number of output channels of the convolutional layers is the same for all layers.
- class equiadapt.ESCNNWRNEquivariantNetwork(in_shape: tuple, out_channels: int = 64, kernel_size: int = 9, group_type: str = 'rotation', num_layers: int = 12, num_rotations: int = 4)
Bases:
Module
This class represents a Wide Residual Network (WRN) that is equivariant under rotations or roto-reflections.
The network consists of a sequence of equivariant convolutional layers, each followed by batch normalization and a ReLU activation function. The number of output channels of the convolutional layers is the same for all layers. The input is added to the output of the layer (residual connection).
- class equiadapt.ESCNNWideBasic(in_type: FieldType, middle_type: FieldType, out_type: FieldType, kernel_size: int = 3)
Bases:
EquivariantModule
This class represents a wide basic layer for an Equivariant Convolutional Neural Network (Equivariant CNN).
The layer consists of a sequence of equivariant convolutional layers, each followed by batch normalization and a ReLU activation function. The number of output channels of the convolutional layers is the same for all layers. The input is added to the output of the layer (residual connection).
- class equiadapt.ESCNNWideBottleneck(in_type: FieldType, middle_type: FieldType, out_type: FieldType, kernel_size: int = 3)
Bases:
EquivariantModule
This class represents a wide bottleneck layer for an Equivariant Convolutional Neural Network (Equivariant CNN).
The layer consists of a sequence of equivariant convolutional layers, each followed by batch normalization and a ReLU activation function. The number of output channels of the convolutional layers is the same for all layers. The input is added to the output of the layer (residual connection).
- class equiadapt.EquivariantPointcloudCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig)
Bases:
ContinuousGroupPointcloudCanonicalization
This class represents the equivariant point cloud canonicalization module.
- Parameters:
canonicalization_network (torch.nn.Module) – The canonicalization network module.
canonicalization_hyperparams (DictConfig) – The hyperparameters for the canonicalization.
- canonicalization_network
The canonicalization network module.
- Type:
torch.nn.Module
- canonicalization_hyperparams
The hyperparameters for the canonicalization.
- Type:
DictConfig
- class equiadapt.GroupEquivariantImageCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig, in_shape: tuple)
Bases:
DiscreteGroupImageCanonicalization
This class represents a discrete group equivariant image canonicalization model.
The model is designed to be equivariant under a discrete group of transformations, which can include rotations and reflections.
- get_group_activations(x: Tensor) → Tensor
Gets the group activations for the input image.
This method takes an image as input, applies transformations before forwarding it through the canonicalization network, and then returns the group activations.
- Parameters:
x (torch.Tensor) – The input image.
- Returns:
The group activations.
- Return type:
torch.Tensor
- class equiadapt.IdentityCanonicalization(canonicalization_network: Module = Identity())
Bases:
BaseCanonicalization
This class represents an identity canonicalization method.
Identity canonicalization is a no-op; it doesn’t change the input data. It’s useful as a default or placeholder when no other canonicalization method is specified.
- canonicalization_network
The network used for canonicalization. Defaults to torch.nn.Identity.
- Type:
torch.nn.Module
- canonicalize()
Canonicalizes the input data. In this class, it returns the input data unchanged.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Canonicalize the input data.
This method takes the input data and returns it unchanged, along with the targets if provided. It’s a no-op in the IdentityCanonicalization class.
- Parameters:
x – The input data.
targets – (Optional) Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
A tuple containing the unchanged input data and targets if targets are provided, otherwise just the unchanged input data.
- get_identity_metric() → Tensor
Gets the identity metric.
For the IdentityCanonicalization class, this is always 1.
- Returns:
A tensor containing the value 1.
- Return type:
torch.Tensor
- get_prior_regularization_loss() → Tensor
Gets the prior regularization loss.
For the IdentityCanonicalization class, this is always 0.
- Returns:
A tensor containing the value 0.
- Return type:
torch.Tensor
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor
Inverts the canonicalization.
For the IdentityCanonicalization class, this is a no-op and returns the input unchanged.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The unchanged x_canonicalized_out.
- Return type:
torch.Tensor
- class equiadapt.LieParameterization(group_type: str, group_dim: int)
Bases:
Module
A class for parameterizing Lie groups and their representations for a single block.
- Parameters:
group_type (str) – The type of Lie group to parameterize.
group_dim (int) – The dimension of the group.
- get_en_rep(params: Tensor, reflect_indicators: Tensor) → Tensor
Computes the representation of the E(n) group, including rotations, reflections, and translations.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim), where the first part corresponds to rotation/reflection parameters and the last ‘n’ parameters correspond to translation.
reflect_indicators (torch.Tensor) – Indicators of whether to reflect, of shape (batch_size, 1).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_group_rep(params: Tensor) → Tensor
Computes the representation for the specified Lie group.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
- Returns:
The group representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_on_rep(params: Tensor, reflect_indicators: Tensor) → Tensor
Computes the representation of the O(n) group, optionally including reflections.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
reflect_indicators (torch.Tensor) – Indicators of whether to reflect, of shape (batch_size, 1).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_sen_rep(params: Tensor) → Tensor
Computes the representation of the SE(n) group.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
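One standard way to parameterize these groups is the exponential map. n(n−1)/2 parameters fill a skew-symmetric matrix whose matrix exponential lies in SO(n). A reflection extends this to O(n), and appending a translation gives an (n+1) × (n+1) homogeneous matrix for E(n). This is a sketch of that technique using torch.linalg.matrix_exp, not necessarily how LieParameterization computes its representations. The function names, the choice of diag(−1, 1, …, 1) as the reflection, and the parameter layout are assumptions. The layout puts the rotation parameters first and the last n entries as translation, following the get_en_rep description.

```python
import torch


def skew_symmetric(params: torch.Tensor, n: int) -> torch.Tensor:
    """Build (batch, n, n) skew-symmetric matrices from n*(n-1)/2 parameters per sample."""
    rows, cols = torch.triu_indices(n, n, offset=1)
    a = params.new_zeros(params.shape[0], n, n)
    a[:, rows, cols] = params
    return a - a.transpose(-2, -1)


def on_rep(params: torch.Tensor, reflect_indicators: torch.Tensor, n: int) -> torch.Tensor:
    """O(n) element: exp of a skew-symmetric matrix, optionally composed with a reflection."""
    rotation = torch.linalg.matrix_exp(skew_symmetric(params, n))  # lies in SO(n)
    # Reflection diag(-1, 1, ..., 1) applied where reflect_indicators == 1.
    signs = torch.ones(params.shape[0], n, dtype=params.dtype, device=params.device)
    signs[:, 0] = 1.0 - 2.0 * reflect_indicators.reshape(-1).to(params.dtype)
    return rotation @ torch.diag_embed(signs)


def en_rep(params: torch.Tensor, reflect_indicators: torch.Tensor, n: int) -> torch.Tensor:
    """E(n) element as an (n + 1) x (n + 1) homogeneous matrix [[Q, t], [0, 1]]."""
    rot_params, translation = params[:, :-n], params[:, -n:]
    rep = params.new_zeros(params.shape[0], n + 1, n + 1)
    rep[:, :n, :n] = on_rep(rot_params, reflect_indicators, n)
    rep[:, :n, n] = translation
    rep[:, n, n] = 1.0
    return rep
```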
- class equiadapt.OptimizedGroupEquivariantImageCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig, in_shape: tuple)
Bases:
DiscreteGroupImageCanonicalization
This class represents an optimized (discrete) group equivariant image canonicalization model.
The model is designed to be equivariant under a discrete group of transformations, which can include rotations and reflections.
- group_augment()
Augment the input images by applying group transformations (rotations and reflections).
- get_group_activations(x: Tensor) → Tensor
Gets the group activations for the input image.
- Parameters:
x (torch.Tensor) – The input image.
- Returns:
The group activations.
- Return type:
torch.Tensor
- get_optimization_specific_loss() → Tensor
Gets the loss specific to the optimization process.
- Returns:
The loss.
- Return type:
torch.Tensor
- group_augment(x: Tensor) → Tensor
Augment the input images by applying group transformations (rotations and reflections).
- Parameters:
x (torch.Tensor) – The input image.
- Returns:
The augmented image.
- Return type:
torch.Tensor
- rotate_and_maybe_reflect(x: Tensor, degrees: Tensor, reflect: bool = False) → List[Tensor]
Rotate and maybe reflect the input images.
- Parameters:
x (torch.Tensor) – The input image.
degrees (torch.Tensor) – The degrees of rotation.
reflect (bool, optional) – Whether to reflect the image. Defaults to False.
- Returns:
The list of rotated and maybe reflected images.
- Return type:
List[torch.Tensor]
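rotate_and_maybe_reflect and group_augment produce transformed copies of each image. The sketch below restricts the angles to multiples of 90°, so that rotations are exact pixel permutations via torch.rot90 instead of interpolated resampling. The documented method accepts arbitrary degrees. The function names and the order of the copies (rotations first, then rotations of the mirrored image) are assumptions.

```python
from typing import List

import torch


def rotate_and_maybe_reflect_c4(x: torch.Tensor, reflect: bool = False) -> List[torch.Tensor]:
    """Return x rotated by 0, 90, 180 and 270 degrees, plus mirrored copies if reflect.

    x: (batch, channels, height, width) with height == width.
    """
    copies = [torch.rot90(x, k, dims=(-2, -1)) for k in range(4)]
    if reflect:
        mirrored = torch.flip(x, dims=[-1])
        copies += [torch.rot90(mirrored, k, dims=(-2, -1)) for k in range(4)]
    return copies


def group_augment_c4(x: torch.Tensor, reflect: bool = False) -> torch.Tensor:
    """Stack all copies along the batch dimension: (group_size * batch, C, H, W)."""
    return torch.cat(rotate_and_maybe_reflect_c4(x, reflect), dim=0)
```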
- class equiadapt.OptimizedSteerableImageCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig, in_shape: tuple)
Bases:
ContinuousGroupImageCanonicalization
This class represents an optimized steerable image canonicalization model.
The model is designed to be equivariant under a continuous group of transformations, which can include rotations and reflections.
- get_rotation_matrix_from_vector()[source]
This method takes the input vector and returns the rotation matrix.
- get_groupelement(x: Tensor) dict [source]
Maps the input image to the group element.
- Parameters:
x (torch.Tensor) – The input image.
- Returns:
The group element.
- Return type:
dict
- get_optimization_specific_loss() Tensor [source]
Returns the optimization-specific loss.
- Returns:
The optimization-specific loss.
- Return type:
torch.Tensor
- get_rotation_matrix_from_vector(vectors: Tensor) Tensor [source]
Converts the input vectors into rotation matrices.
- Parameters:
vectors (torch.Tensor) – The input vectors.
- Returns:
The rotation matrices.
- Return type:
torch.Tensor
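The input shape of `get_rotation_matrix_from_vector` is not documented here. Assuming, for illustration only, that the canonicalization network predicts one 2D vector per image, a rotation matrix can be obtained by normalizing the vector and reading off its angle. The function name and the `eps` guard are hypothetical:

```python
import torch


def rotation_matrix_from_vector_2d(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch: map 2D vectors (batch, 2) to rotation matrices (batch, 2, 2).

    Normalizing v to unit length gives (cos t, sin t) for some angle t, which
    determines the rotation matrix [[cos t, -sin t], [sin t, cos t]].
    """
    unit = vectors / (vectors.norm(dim=-1, keepdim=True) + eps)
    cos, sin = unit[:, 0], unit[:, 1]
    row0 = torch.stack([cos, -sin], dim=-1)
    row1 = torch.stack([sin, cos], dim=-1)
    return torch.stack([row0, row1], dim=-2)


R = rotation_matrix_from_vector_2d(torch.tensor([[0.0, 2.0]]))
print(R)  # approximately [[0, -1], [1, 0]]: a 90-degree rotation
```

Because only the direction of the vector matters, the network output does not need a fixed length.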
- group_augment(x: Tensor) Tuple[Tensor, Tensor] [source]
Augments the input images by applying random rotations and, if applicable, reflections, and returns the corresponding transformation matrices.
- Parameters:
x (torch.Tensor) – Input images of shape (batch_size, in_channels, height, width).
- Returns:
The augmented images and the corresponding transformation matrices.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
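A continuous-group augmentation of this kind can be sketched with `torch.nn.functional.affine_grid` and `grid_sample`. The uniform sampling of angles in [0, 2π), the random reflection scheme, square images, and the function name are all assumptions for illustration, not the equiadapt implementation:

```python
import math

import torch
import torch.nn.functional as F


def random_rotate_images(x: torch.Tensor, reflect: bool = False):
    """Hypothetical sketch of a continuous-group augmentation.

    x: images of shape (batch, channels, height, width), assumed square.
    Returns the transformed images and the 2x2 orthogonal matrices used.
    """
    batch = x.shape[0]
    angles = torch.rand(batch, dtype=x.dtype, device=x.device) * 2 * math.pi
    cos, sin = torch.cos(angles), torch.sin(angles)
    mats = torch.stack(
        [torch.stack([cos, -sin], dim=-1), torch.stack([sin, cos], dim=-1)], dim=-2
    )
    if reflect:
        # With probability 1/2, negate the first column: the matrix becomes a reflection (det = -1).
        signs = 1.0 - 2.0 * (torch.rand(batch, device=x.device) < 0.5).to(x.dtype)
        column_scale = torch.stack([signs, torch.ones_like(signs)], dim=-1)
        mats = mats * column_scale[:, None, :]
    # affine_grid takes (batch, 2, 3) matrices that map output coordinates to input
    # coordinates, so the image content moves by the inverse of each matrix.
    zeros = torch.zeros(batch, 2, 1, dtype=x.dtype, device=x.device)
    grid = F.affine_grid(torch.cat([mats, zeros], dim=-1), list(x.shape), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False), mats


images = torch.rand(4, 3, 16, 16)
augmented, matrices = random_rotate_images(images, reflect=True)
print(augmented.shape, matrices.shape)  # torch.Size([4, 3, 16, 16]) torch.Size([4, 2, 2])
```

Because `grid_sample` interpolates bilinearly and pads with zeros by default, the augmented images lose some information near the corners.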
- class equiadapt.ResNet18Network(in_shape: tuple, out_channels: int, kernel_size: int, num_layers: int = 2, out_vector_size: int = 128)[source]
Bases:
Module
This class represents a neural network based on the ResNet-18 architecture.
The network uses the ResNet-18 architecture without pre-trained weights. The final fully connected layer of the ResNet-18 model is replaced with a new fully connected layer.
- resnet18
The ResNet-18 model.
- Type:
torchvision.models.ResNet
- class equiadapt.RotationEquivariantConv(in_channels: int, out_channels: int, kernel_size: int, num_rotations: int = 4, stride: int = 1, padding: int = 0, bias: bool = True, device: str = 'cuda')[source]
Bases:
Module
This class represents a rotation equivariant convolutional layer.
The layer is equivariant to a group of rotations. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
- forward(x: Tensor) Tensor [source]
Performs a forward pass through the layer.
- Parameters:
x (torch.Tensor) – The input data. It should have the shape (batch_size, in_channels, num_rotations, height, width).
- Returns:
The output of the layer. It has the shape (batch_size, out_channels, num_rotations, height, width).
- Return type:
torch.Tensor
- get_rotated_permuted_weights(weights: Tensor, num_rotations: int = 4) Tensor [source]
Returns the weights of the layer after rotation and permutation.
- Parameters:
weights (torch.Tensor) – The weights of the layer.
num_rotations (int, optional) – The number of rotations in the group. Defaults to 4.
- Returns:
The weights after rotation and permutation.
- Return type:
torch.Tensor
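To make "rotation and permutation" of the weights concrete, the following is a minimal sketch of a group convolution over C4 (rotations by multiples of 90 degrees) using the input layout described above, (batch_size, in_channels, num_rotations, height, width). For each output rotation r, every filter is rotated by r·90° and its group axis is cyclically shifted by r. The function names are hypothetical and the internals may differ from `get_rotated_permuted_weights`:

```python
import torch
import torch.nn.functional as F


def c4_group_conv(x: torch.Tensor, weights: torch.Tensor, padding: int = 0) -> torch.Tensor:
    """Hypothetical sketch of a C4 group convolution.

    x:       (batch, in_channels, 4, H, W)   features indexed by rotation
    weights: (out_channels, in_channels, 4, k, k)
    returns: (batch, out_channels, 4, H', W')
    """
    out_ch, in_ch, g, k, _ = weights.shape
    filters = []
    for r in range(g):
        # Filter for output rotation r: shift the group axis by r, then rotate by r * 90 degrees.
        w_r = torch.rot90(torch.roll(weights, shifts=r, dims=2), k=r, dims=(-2, -1))
        filters.append(w_r.reshape(out_ch, in_ch * g, k, k))
    # Output channel index = o * g + r, matching the final reshape below.
    expanded = torch.stack(filters, dim=1).reshape(out_ch * g, in_ch * g, k, k)
    batch, _, _, height, width = x.shape
    y = F.conv2d(x.reshape(batch, in_ch * g, height, width), expanded, padding=padding)
    return y.reshape(batch, out_ch, g, y.shape[-2], y.shape[-1])


def act_c4(t: torch.Tensor) -> torch.Tensor:
    """Rotate group features by 90 degrees: rotate the plane, shift the rotation axis by one."""
    return torch.rot90(torch.roll(t, shifts=1, dims=2), k=1, dims=(-2, -1))


x = torch.randn(2, 3, 4, 9, 9, dtype=torch.float64)
w = torch.randn(5, 3, 4, 3, 3, dtype=torch.float64)
y = c4_group_conv(x, w, padding=1)
print(torch.allclose(c4_group_conv(act_c4(x), w, padding=1), act_c4(y)))  # True
```

The final check illustrates equivariance: rotating the input features rotates the output maps and shifts them one step along the rotation axis.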
- class equiadapt.RotationEquivariantConvLift(in_channels: int, out_channels: int, kernel_size: int, num_rotations: int = 4, stride: int = 1, padding: int = 0, bias: bool = True, device: str = 'cuda')[source]
Bases:
Module
This class represents a rotation equivariant convolutional layer with lifting.
The layer is equivariant to a group of rotations. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
- forward(x: Tensor) Tensor [source]
Performs a forward pass through the layer.
- Parameters:
x (torch.Tensor) – The input data. It should have the shape (batch_size, in_channels, height, width).
- Returns:
The output of the layer. It has the shape (batch_size, out_channels, num_rotations, height, width).
- Return type:
torch.Tensor
- get_rotated_weights(weights: Tensor, num_rotations: int = 4) Tensor [source]
Returns the weights of the layer after rotation.
- Parameters:
weights (torch.Tensor) – The weights of the layer.
num_rotations (int, optional) – The number of rotations in the group. Defaults to 4.
- Returns:
The weights after rotation.
- Return type:
torch.Tensor
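A lifting layer only needs rotated copies of each filter, since its input has no rotation axis yet. A minimal sketch for C4, with a hypothetical function name and assuming the input and output shapes described above:

```python
import torch
import torch.nn.functional as F


def c4_lifting_conv(x: torch.Tensor, weights: torch.Tensor, padding: int = 0) -> torch.Tensor:
    """Hypothetical sketch of a C4 lifting convolution.

    x:       (batch, in_channels, H, W)      ordinary images
    weights: (out_channels, in_channels, k, k)
    returns: (batch, out_channels, 4, H', W')
    """
    out_ch, in_ch, k, _ = weights.shape
    # One copy of each filter per rotation; output channel index = o * 4 + r.
    rotated = torch.stack([torch.rot90(weights, k=r, dims=(-2, -1)) for r in range(4)], dim=1)
    y = F.conv2d(x, rotated.reshape(out_ch * 4, in_ch, k, k), padding=padding)
    return y.reshape(x.shape[0], out_ch, 4, y.shape[-2], y.shape[-1])


x = torch.randn(2, 3, 8, 8, dtype=torch.float64)
w = torch.randn(4, 3, 3, 3, dtype=torch.float64)
y = c4_lifting_conv(x, w, padding=1)
y_from_rotated = c4_lifting_conv(torch.rot90(x, k=1, dims=(-2, -1)), w, padding=1)
expected = torch.rot90(torch.roll(y, shifts=1, dims=2), k=1, dims=(-2, -1))
print(y.shape, torch.allclose(y_from_rotated, expected))  # torch.Size([2, 4, 4, 8, 8]) True
```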
- class equiadapt.RotoReflectionEquivariantConv(in_channels: int, out_channels: int, kernel_size: int, num_rotations: int = 4, stride: int = 1, padding: int = 0, bias: bool = True, device: str = 'cuda')[source]
Bases:
Module
This class represents a roto-reflection equivariant convolutional layer.
The layer is equivariant to a group of rotations and reflections. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
- forward(x: Tensor) Tensor [source]
Performs a forward pass through the layer.
- Parameters:
x (torch.Tensor) – The input data. It should have the shape (batch_size, in_channels, num_group_elements, height, width).
- Returns:
The output of the layer. It has the shape (batch_size, out_channels, num_group_elements, height, width).
- Return type:
torch.Tensor
- get_rotoreflected_permuted_weights(weights: Tensor, num_rotations: int = 4) Tensor [source]
Returns the weights of the layer after rotation, reflection, and permutation.
- Parameters:
weights (torch.Tensor) – The weights of the layer.
num_rotations (int, optional) – The number of rotations in the group. Defaults to 4.
- Returns:
The weights after rotation, reflection, and permutation.
- Return type:
torch.Tensor
- class equiadapt.RotoReflectionEquivariantConvLift(in_channels: int, out_channels: int, kernel_size: int, num_rotations: int = 4, stride: int = 1, padding: int = 0, bias: bool = True, device: str = 'cuda')[source]
Bases:
Module
This class represents a roto-reflection equivariant convolutional layer with lifting.
The layer is equivariant to a group of rotations and reflections. The weights of the layer are initialized using the Kaiming uniform initialization method. The layer supports optional bias.
- forward(x: Tensor) Tensor [source]
Performs a forward pass through the layer.
- Parameters:
x (torch.Tensor) – The input data. It should have the shape (batch_size, in_channels, height, width).
- Returns:
The output of the layer. It has the shape (batch_size, out_channels, num_group_elements, height, width).
- Return type:
torch.Tensor
- get_rotoreflected_weights(weights: Tensor, num_rotations: int = 4) Tensor [source]
Returns the weights of the layer after rotation and reflection.
- Parameters:
weights (torch.Tensor) – The weights of the layer.
num_rotations (int, optional) – The number of rotations in the group. Defaults to 4.
- Returns:
The weights after rotation and reflection.
- Return type:
torch.Tensor
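Adding reflections doubles the filter bank: each filter is used as-is and mirrored, and both versions are rotated by multiples of 90 degrees, giving the 8 elements of D4. The group-index ordering below (mirror first, then rotate) and the function name are assumptions for this sketch, not necessarily what `get_rotoreflected_weights` uses:

```python
import torch
import torch.nn.functional as F


def d4_lifting_conv(x: torch.Tensor, weights: torch.Tensor, padding: int = 0) -> torch.Tensor:
    """Hypothetical sketch of a roto-reflection (D4) lifting convolution.

    x:       (batch, in_channels, H, W)
    weights: (out_channels, in_channels, k, k)
    returns: (batch, out_channels, 8, H', W'), where group index m * 4 + r means
             "mirror the filter m times (m in {0, 1}), then rotate it by r * 90 degrees".
    """
    out_ch, in_ch, k, _ = weights.shape
    bank = []
    for m in range(2):
        base = torch.flip(weights, dims=(-1,)) if m == 1 else weights
        bank += [torch.rot90(base, k=r, dims=(-2, -1)) for r in range(4)]
    filters = torch.stack(bank, dim=1).reshape(out_ch * 8, in_ch, k, k)
    y = F.conv2d(x, filters, padding=padding)
    return y.reshape(x.shape[0], out_ch, 8, y.shape[-2], y.shape[-1])


x = torch.randn(1, 2, 7, 7, dtype=torch.float64)
w = torch.randn(3, 2, 3, 3, dtype=torch.float64)
y = d4_lifting_conv(x, w, padding=1).reshape(1, 3, 2, 4, 7, 7)
y_mirror = d4_lifting_conv(torch.flip(x, dims=(-1,)), w, padding=1).reshape(1, 3, 2, 4, 7, 7)
# Mirroring the input swaps m <-> 1 - m, maps r to -r (mod 4) and mirrors each output map.
neg_r = torch.tensor([0, 3, 2, 1])
expected = torch.flip(torch.flip(y, dims=(2,)).index_select(3, neg_r), dims=(-1,))
print(torch.allclose(y_mirror, expected))  # True
```

The index shuffle in the check reflects the group law of D4: conjugating a rotation by a mirror inverts it.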
- class equiadapt.SteerableImageCanonicalization(canonicalization_network: Module, canonicalization_hyperparams: DictConfig, in_shape: tuple)[source]
Bases:
ContinuousGroupImageCanonicalization
This class represents a steerable image canonicalization model.
The model is designed to be equivariant under a continuous group of Euclidean transformations, namely rotations and reflections.
- get_rotation_matrix_from_vector()[source]
This method takes the input vector and returns the rotation matrix.
- class equiadapt.VNBatchNorm(num_features: int, dim: int)[source]
Bases:
Module
Vector Neuron Batch Normalization layer.
VNBatchNorm applies batch normalization to the input features.
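The docstring only states that batch normalization is applied. Assuming the layer follows the Vector Neurons formulation, only the rotation-invariant length of each 3D feature vector is normalized while its direction is kept. The class name is hypothetical, it handles only 4D inputs of shape (B, C, 3, N), and the `dim` argument of the library class is not modeled:

```python
import torch
import torch.nn as nn


class VNBatchNormSketch(nn.Module):
    """Hypothetical sketch of vector-neuron batch norm for features of shape (B, C, 3, N).

    Only the (rotation-invariant) length of each 3D feature vector is batch-normalized;
    its direction is kept, so the layer commutes with rotations of the input.
    """

    def __init__(self, num_features: int, eps: float = 1e-6):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = torch.linalg.norm(x, dim=2) + self.eps  # (B, C, N), rotation invariant
        norm_bn = self.bn(norm)  # standard BatchNorm1d over the C channels
        # Rescale each vector from its old length to its normalized length.
        return x / norm.unsqueeze(2) * norm_bn.unsqueeze(2)
```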
- class equiadapt.VNBilinear(in_channels1: int, in_channels2: int, out_channels: int)[source]
Bases:
Module
Vector Neuron Bilinear layer.
VNBilinear applies a bilinear layer to the input features.
- forward(x: Tensor, labels: Tensor) Tensor [source]
Forward pass of the VNBilinear layer.
- Parameters:
x (torch.Tensor) – Input features of shape [B, N_feat, 3, N_samples, …].
labels (torch.Tensor) – Labels of shape [B, N_feat, N_samples].
- Returns:
Output features after applying the bilinear transformation.
- Return type:
torch.Tensor
- class equiadapt.VNLeakyReLU(in_channels: int, share_nonlinearity: bool = False, negative_slope: float = 0.2)[source]
Bases:
Module
Vector Neuron Leaky ReLU layer.
VNLeakyReLU applies a LeakyReLU activation to the input features.
- class equiadapt.VNLinear(in_channels: int, out_channels: int)[source]
Bases:
Module
Vector Neuron Linear layer.
This layer applies a linear transformation to the input tensor.
- class equiadapt.VNLinearLeakyReLU(in_channels: int, out_channels: int, dim: int = 5, share_nonlinearity: bool = False, negative_slope: float = 0.2)[source]
Bases:
Module
Vector Neuron Linear Leaky ReLU layer.
VNLinearLeakyReLU applies a linear transformation followed by a LeakyReLU activation to the input features.
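Assuming these layers follow the Vector Neurons formulation (features of shape (B, C, 3, N), linear maps that mix channels but never the 3 coordinates, and a ReLU that clips each vector against a learned direction), a linear layer followed by the leaky ReLU can be sketched as below. The class name is hypothetical, and options such as `share_nonlinearity` and `dim` are not modeled:

```python
import torch
import torch.nn as nn


class VNLinearLeakyReLUSketch(nn.Module):
    """Hypothetical vector-neuron linear layer followed by a vector-neuron leaky ReLU.

    Features have shape (B, C, 3, N): C channels, each a 3D vector for each of N points.
    """

    def __init__(self, in_channels: int, out_channels: int, negative_slope: float = 0.2):
        super().__init__()
        # Both maps mix channels only; they never mix the 3 spatial coordinates.
        self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
        self.map_to_dir = nn.Linear(in_channels, out_channels, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Move channels to the last axis, apply the linear map, move them back.
        q = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1)
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)  # learned directions
        dot = (q * d).sum(dim=2, keepdim=True)  # <q, d> is rotation invariant
        d_norm_sq = (d * d).sum(dim=2, keepdim=True)
        # Keep q if it points into the half-space of d, otherwise drop its component along d.
        clipped = torch.where(dot >= 0, q, q - dot / (d_norm_sq + eps) * d)
        return self.negative_slope * q + (1 - self.negative_slope) * clipped
```

Because every operation is either a channel mix or built from dot products of 3D vectors, rotating the input rotates the output by the same matrix.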
- class equiadapt.VNMaxPool(in_channels: int)[source]
Bases:
Module
Vector Neuron Max Pooling layer.
VNMaxPool applies max pooling to the input features.
- class equiadapt.VNSmall(hyperparams: DictConfig)[source]
Bases:
Module
VNSmall is a small variant of the vector neuron equivariant network used for canonicalization of point clouds.
- Parameters:
hyperparams (DictConfig) – Hyperparameters for the network.
- conv_pos
Convolutional layer for positional encoding.
- Type:
- conv1
First convolutional layer.
- Type:
- bn1
Batch normalization layer.
- Type:
- conv2
Second convolutional layer.
- Type:
- dropout
Dropout layer.
- Type:
nn.Dropout
- forward(point_cloud: Tensor) Tensor [source]
Forward pass of the VNSmall network.
For every point cloud in the batch, the network outputs three vectors that transform equivariantly under the SO(3) group.
- Parameters:
point_cloud (torch.Tensor) – Input point cloud tensor of shape (batch_size, num_points, 3).
- Returns:
Output tensor of shape (batch_size, 3, 3).
- Return type:
torch.Tensor
- class equiadapt.VNSoftplus(in_channels: int, share_nonlinearity: bool = False, negative_slope: float = 0.0)[source]
Bases:
Module
Vector Neuron Softplus layer.
VNSoftplus applies a softplus activation to the input features.
- class equiadapt.VNStdFeature(in_channels: int, dim: int = 4, normalize_frame: bool = False, share_nonlinearity: bool = False, negative_slope: float = 0.2)[source]
Bases:
Module
Vector Neuron Standard Feature module.
This module performs standard feature extraction using Vector Neuron layers. It takes point features as input and applies a series of VNLinearLeakyReLU layers followed by a linear layer to produce the standard features.
- Shape:
Input: (B, N_feat, 3, N_samples, …)
Output: x_std of shape (B, N_feat, dim, N_samples, …) and z0 of shape (B, dim, 3)
Example
>>> model = VNStdFeature(in_channels=64, dim=4, normalize_frame=True)
>>> input = torch.randn(2, 64, 3, 100)
>>> output, frame_vectors = model(input)
- forward(x: Tensor) Tuple[Tensor, Tensor] [source]
Forward pass of the VNStdFeature module.
- Parameters:
x (torch.Tensor) – Input point features of shape (B, N_feat, 3, N_samples, …).
- Returns:
Tuple containing the standard features and the frame vectors.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
Note
The frame vectors are computed only if normalize_frame is set to True.
The shape of the standard features depends on the value of dim.
- equiadapt.get_action_on_image_features(feature_map: Tensor, group_info_dict: dict, group_element_dict: dict, induced_rep_type: str = 'regular') Tensor [source]
Applies a group action to the feature map.
- Parameters:
- Returns:
The feature map after the group action has been applied.
- Return type:
torch.Tensor
- equiadapt.get_graph_feature_cross(x: Tensor, k: int = 20, idx: Tensor | None = None) Tensor [source]
Computes the graph feature cross for a given input tensor.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_size, num_dims, num_points).
k (int, optional) – The number of nearest neighbors to consider. Defaults to 20.
idx (torch.Tensor, optional) – The indices of the nearest neighbors. Defaults to None.
- Returns:
The computed graph feature cross tensor of shape (batch_size, num_dims*3, num_points, k).
- Return type:
torch.Tensor
- equiadapt.gram_schmidt(vectors: Tensor) Tensor [source]
Applies the Gram-Schmidt process to orthogonalize a set of three vectors in a batch-wise manner.
- Parameters:
vectors (torch.Tensor) – A batch of vectors of shape (batch_size, n_vectors, vector_dim), where n_vectors is the number of vectors to orthogonalize (here 3).
- Returns:
The orthogonalized vectors of the same shape as the input.
- Return type:
torch.Tensor
Examples
>>> vectors = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
>>> result = gram_schmidt(vectors)
>>> print(result)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000]]])
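For reference, a batched (modified) Gram-Schmidt procedure with the (batch_size, n_vectors, vector_dim) layout described above can be sketched as follows. The function name and the `eps` guard against division by zero are assumptions; the library function may normalize differently:

```python
import torch


def gram_schmidt_sketch(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Orthonormalize vectors of shape (batch_size, n_vectors, vector_dim) batch-wise.

    Uses the modified Gram-Schmidt update: each new vector has its projections onto the
    already-accepted unit vectors removed one at a time, then is normalized.
    """
    basis = []
    for i in range(vectors.shape[1]):
        v = vectors[:, i]
        for u in basis:
            v = v - (v * u).sum(dim=-1, keepdim=True) * u
        basis.append(v / (v.norm(dim=-1, keepdim=True) + eps))
    return torch.stack(basis, dim=1)


frames = gram_schmidt_sketch(torch.randn(4, 3, 3))
print(frames.shape)  # torch.Size([4, 3, 3])
```

If a right-handed frame (determinant +1) is required, a common alternative is to orthonormalize only the first two vectors and take their cross product as the third.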