equiadapt.common package
Submodules
equiadapt.common.basecanonicalization module
This module defines a base class for canonicalization and its subclasses for different types of canonicalization methods.
Canonicalization is a process that transforms the input data into a canonical (standard) form. It can be a cheap alternative to building equivariant models: the input is first transformed into its canonical form, and a standard model then makes the predictions. Canonicalization therefore lets you use any existing architecture (even a pre-trained one) for your task without having to worry about equivariance.
The module contains the following classes:
BaseCanonicalization: This is an abstract base class that defines the interface for all canonicalization methods.
IdentityCanonicalization: This class represents an identity canonicalization method, which is a no-op; it doesn’t change the input data.
DiscreteGroupCanonicalization: This class represents a discrete group canonicalization method, which transforms the input data into a canonical form using a discrete group.
ContinuousGroupCanonicalization: This class represents a continuous group canonicalization method, which transforms the input data into a canonical form using a continuous group.
Each class has methods to perform the canonicalization, invert it, and calculate the prior regularization loss and identity metric.
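The canonicalize, predict, invert workflow described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not equiadapt's implementation: the group is the set of cyclic shifts of a 1-D signal, and `canonicalize`, `prediction_network`, `invert_canonicalization`, and `equivariant_predict` are hypothetical helpers defined only for this example.

```python
import torch


def canonicalize(x: torch.Tensor):
    """Shift each row so that its maximum lands at index 0 (the canonical form)."""
    shifts = x.argmax(dim=1)  # the group element chosen for each sample
    x_canon = torch.stack([torch.roll(row, -int(s)) for row, s in zip(x, shifts)])
    return x_canon, shifts


def invert_canonicalization(y_canon: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Undo the shift on the prediction so it matches the original orientation."""
    return torch.stack([torch.roll(row, int(s)) for row, s in zip(y_canon, shifts)])


def prediction_network(x: torch.Tensor) -> torch.Tensor:
    """A deliberately non-equivariant model: position-dependent weights."""
    weights = torch.arange(1, x.shape[1] + 1, dtype=x.dtype)
    return x * weights


def equivariant_predict(x: torch.Tensor) -> torch.Tensor:
    x_canon, shifts = canonicalize(x)
    return invert_canonicalization(prediction_network(x_canon), shifts)


x = torch.tensor([[0.1, 0.9, 0.3, 0.2]])
x_shifted = torch.roll(x, 2, dims=1)

out = equivariant_predict(x)
out_shifted = equivariant_predict(x_shifted)

# Shifting the input shifts the output by the same amount.
print(torch.allclose(out_shifted, torch.roll(out, 2, dims=1)))  # prints True
```

The prediction network itself is not shift-equivariant; the equivariance of the whole pipeline comes only from canonicalizing the input and inverting the canonicalization on the output.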
- class equiadapt.common.basecanonicalization.BaseCanonicalization(canonicalization_network: Module)[source]
Bases: Module
This is the base class for canonicalization.
This class is used as a base for all canonicalization methods. Subclasses should implement the canonicalize method to define the specific canonicalization process.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Takes the input data and, optionally, targets that need to be canonicalized.
- Parameters:
x – input data
targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation
**kwargs – additional arguments
- Returns:
the canonicalized version of the data and targets
- forward(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Forward method for the canonicalization which takes the input data and returns the canonicalized version of the data
- Parameters:
x – input data
targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation
**kwargs – additional arguments
- Returns:
canonicalized version of the input data
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Takes the prediction network's output for the canonicalized data and returns the output for the original data orientation.
- Parameters:
x_canonicalized_out – output of the prediction network for canonicalized data
**kwargs – additional arguments
- Returns:
output of the prediction network for the original data orientation, obtained by using the group element that canonicalized the original data
- Return type:
torch.Tensor
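To make the interface above concrete, here is a hedged sketch of a class exposing the same three methods. It does not import equiadapt; `ToyFlipCanonicalization` and `EndDifference` are hypothetical names, the group is just "flip the last axis or not", and the targets are passed through unchanged (a real method would transform them with the same group element).

```python
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import nn


class EndDifference(nn.Module):
    """Hypothetical canonicalization network: one score per sample."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, -1] - x[:, 0]


class ToyFlipCanonicalization(nn.Module):
    """Hypothetical stand-in mirroring the documented interface (not equiadapt code)."""

    def __init__(self, canonicalization_network: nn.Module) -> None:
        super().__init__()
        self.canonicalization_network = canonicalization_network
        self.flip_mask: Optional[torch.Tensor] = None  # remembers the chosen group element

    def canonicalize(
        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        scores = self.canonicalization_network(x)
        self.flip_mask = scores < 0  # flip samples whose last entry is smaller than the first
        x_canon = torch.where(self.flip_mask[:, None], x.flip(-1), x)
        if targets is None:
            return x_canon
        return x_canon, targets  # simplification: targets are not transformed here

    def forward(
        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        return self.canonicalize(x, targets, **kwargs)

    def invert_canonicalization(self, x_canonicalized_out: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        assert self.flip_mask is not None, "call canonicalize() before inverting"
        # A flip is its own inverse, so apply it again to the flipped samples.
        return torch.where(self.flip_mask[:, None], x_canonicalized_out.flip(-1), x_canonicalized_out)


canonicalizer = ToyFlipCanonicalization(EndDifference())
x = torch.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0]])
x_canon = canonicalizer(x)                                # tensor only, no targets given
restored = canonicalizer.invert_canonicalization(x_canon)
with_targets = canonicalizer(x, targets=[{"label": 0}, {"label": 1}])
```

Storing the chosen group element on the module is what lets `invert_canonicalization` be called later with only the network output as its argument.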
- class equiadapt.common.basecanonicalization.ContinuousGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0)[source]
Bases: BaseCanonicalization
This class represents a continuous group canonicalization method.
Continuous group canonicalization is a method that transforms the input data into a canonical form using a continuous group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for continuous group canonicalization.
- canonicalization_network
The network used for canonicalization.
- Type:
torch.nn.Module
- canonicalizationnetworkout_to_groupelement()[source]
Converts the output of the canonicalization network to a group element in a differentiable manner.
- canonicalizationnetworkout_to_groupelement(group_activations: Tensor) Tensor[source]
Converts the output of the canonicalization network to a group element in a differentiable manner.
- Parameters:
group_activations (torch.Tensor) – The activations for each group element.
- Returns:
The group element.
- Return type:
torch.Tensor
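As an illustration of turning a network output into a group element differentiably, the sketch below maps a 2-D vector to an SO(2) rotation matrix by normalizing it to (cos, sin). This is one common construction and only an assumption about what a subclass might do; `vector_to_rotation_2d` is a hypothetical helper, not an equiadapt function.

```python
import torch
import torch.nn.functional as F


def vector_to_rotation_2d(group_activations: torch.Tensor) -> torch.Tensor:
    """Map activations of shape (batch_size, 2) to rotation matrices of shape (batch_size, 2, 2)."""
    unit = F.normalize(group_activations, dim=-1)  # (cos, sin) on the unit circle
    cos, sin = unit[:, 0], unit[:, 1]
    row0 = torch.stack([cos, -sin], dim=-1)
    row1 = torch.stack([sin, cos], dim=-1)
    return torch.stack([row0, row1], dim=-2)


activations = torch.tensor([[3.0, 4.0]], requires_grad=True)
rotation = vector_to_rotation_2d(activations)

# Every step is a differentiable tensor operation, so gradients reach the activations.
rotation.sum().backward()
print(activations.grad is not None)
```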
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Canonicalizes the input data.
- Parameters:
x (torch.Tensor) – The input data.
targets (List, optional) – Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
The canonicalized input data and targets. Simultaneously, it updates a dictionary containing the information about the canonicalization.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_identity_metric() Tensor[source]
Gets the identity metric.
The identity metric is calculated as 1 minus the mean of the mean squared error between the group element matrix representation and the identity matrix.
- Returns:
The identity metric.
- Return type:
torch.Tensor
- get_prior_regularization_loss() Tensor[source]
Gets the prior regularization loss.
The prior regularization loss is calculated as the mean squared error between the group element matrix representation and the identity matrix.
- Returns:
The prior regularization loss.
- Return type:
torch.Tensor
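The two formulas stated above (mean squared error to the identity matrix, and 1 minus its mean) can be written out as follows. This is a sketch of the description only, with hypothetical free functions that take the batch of group element matrices explicitly; the library methods take no arguments and read that information from internal state.

```python
import torch
import torch.nn.functional as F


def prior_regularization_loss(group_element_matrices: torch.Tensor) -> torch.Tensor:
    """MSE between each (n, n) group element matrix and the identity matrix."""
    n = group_element_matrices.shape[-1]
    identity = torch.eye(
        n, dtype=group_element_matrices.dtype, device=group_element_matrices.device
    ).expand_as(group_element_matrices)
    return F.mse_loss(group_element_matrices, identity)


def identity_metric(group_element_matrices: torch.Tensor) -> torch.Tensor:
    """1 minus the batch mean of the per-sample MSE to the identity matrix."""
    n = group_element_matrices.shape[-1]
    identity = torch.eye(
        n, dtype=group_element_matrices.dtype, device=group_element_matrices.device
    )
    per_sample_mse = ((group_element_matrices - identity) ** 2).mean(dim=(-2, -1))
    return 1.0 - per_sample_mse.mean()


# One identity element and one 90-degree rotation.
matrices = torch.stack([torch.eye(2), torch.tensor([[0.0, -1.0], [1.0, 0.0]])])
loss = prior_regularization_loss(matrices)
metric = identity_metric(matrices)
```

A batch made only of identity matrices gives a loss of 0 and a metric of 1, which matches the values documented for IdentityCanonicalization.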
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Inverts the canonicalization.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The output for the original data orientation.
- Return type:
torch.Tensor
- class equiadapt.common.basecanonicalization.DiscreteGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0, gradient_trick: str = 'straight_through')[source]
Bases: BaseCanonicalization
This class represents a discrete group canonicalization method.
Discrete group canonicalization is a method that transforms the input data into a canonical form using a discrete group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for discrete group canonicalization.
- canonicalization_network
The network used for canonicalization.
- Type:
torch.nn.Module
- gradient_trick
The method used for backpropagation through the discrete operation. Defaults to “straight_through”.
- Type:
str
- groupactivations_to_groupelementonehot()[source]
Converts group activations to one-hot encoded group elements in a differentiable manner.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Canonicalizes the input data.
- Parameters:
x (torch.Tensor) – The input data.
targets (List, optional) – Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
The canonicalized input data and targets. Simultaneously, it updates a dictionary containing the information about the canonicalization.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_identity_metric() Tensor[source]
Gets the identity metric.
- Returns:
The identity metric.
- Return type:
torch.Tensor
- get_prior_regularization_loss() Tensor[source]
Gets the prior regularization loss.
- Returns:
The prior regularization loss.
- Return type:
torch.Tensor
- groupactivations_to_groupelementonehot(group_activations: Tensor) Tensor[source]
Converts group activations to one-hot encoded group elements in a differentiable manner.
- Parameters:
group_activations (torch.Tensor) – The activations for each group element.
- Returns:
The one-hot encoding of the group elements.
- Return type:
torch.Tensor
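A common formulation of the straight-through trick, which the default `gradient_trick` value suggests, is sketched below: the forward pass emits a hard one-hot vector while the backward pass uses the softmax gradient. Whether equiadapt implements it exactly this way is an assumption; `straight_through_onehot` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def straight_through_onehot(group_activations: torch.Tensor) -> torch.Tensor:
    """Hard one-hot in the forward pass, softmax gradient in the backward pass."""
    probs = F.softmax(group_activations, dim=-1)
    hard = F.one_hot(probs.argmax(dim=-1), num_classes=probs.shape[-1]).to(probs.dtype)
    # (probs - probs.detach()) is zero in value but carries the gradient of probs.
    return hard + probs - probs.detach()


activations = torch.tensor([[0.1, 2.0, -1.0, 0.3]], requires_grad=True)
onehot = straight_through_onehot(activations)

# Any downstream loss now backpropagates into the activations.
weights = torch.tensor([1.0, 2.0, 3.0, 4.0])
(onehot * weights).sum().backward()
```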
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Inverts the canonicalization.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The output for the original data orientation.
- Return type:
torch.Tensor
- class equiadapt.common.basecanonicalization.IdentityCanonicalization(canonicalization_network: Module = Identity())[source]
Bases: BaseCanonicalization
This class represents an identity canonicalization method.
Identity canonicalization is a no-op; it doesn’t change the input data. It’s useful as a default or placeholder when no other canonicalization method is specified.
- canonicalization_network
The network used for canonicalization. Defaults to torch.nn.Identity.
- Type:
torch.nn.Module
- canonicalize()[source]
Canonicalizes the input data. In this class, it returns the input data unchanged.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Canonicalize the input data.
This method takes the input data and returns it unchanged, along with the targets if provided. It’s a no-op in the IdentityCanonicalization class.
- Parameters:
x – The input data.
targets – (Optional) Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
A tuple containing the unchanged input data and targets if targets are provided, otherwise just the unchanged input data.
- get_identity_metric() Tensor[source]
Gets the identity metric.
For the IdentityCanonicalization class, this is always 1.
- Returns:
A tensor containing the value 1.
- Return type:
torch.Tensor
- get_prior_regularization_loss() Tensor[source]
Gets the prior regularization loss.
For the IdentityCanonicalization class, this is always 0.
- Returns:
A tensor containing the value 0.
- Return type:
torch.Tensor
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Inverts the canonicalization.
For the IdentityCanonicalization class, this is a no-op and returns the input unchanged.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The unchanged x_canonicalized_out.
- Return type:
torch.Tensor
equiadapt.common.utils module
- class equiadapt.common.utils.LieParameterization(group_type: str, group_dim: int)[source]
Bases: Module
A class for parameterizing Lie groups and their representations for a single block.
- Parameters:
- get_en_rep(params: Tensor, reflect_indicators: Tensor) Tensor[source]
Computes the representation for the E(n) group, including rotations, reflections, and translations.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim), where the first part corresponds to rotation/reflection parameters and the last ‘n’ parameters correspond to translation.
reflect_indicators (torch.Tensor) – Indicators of whether to reflect, of shape (batch_size, 1).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_group_rep(params: Tensor) Tensor[source]
Computes the representation for the specified Lie group.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
- Returns:
The group representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_on_rep(params: Tensor, reflect_indicators: Tensor) Tensor[source]
Computes the representation for O(n) group, optionally including reflections.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
reflect_indicators (torch.Tensor) – Indicators of whether to reflect, of shape (batch_size, 1).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
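One way to obtain an O(n) representation from unconstrained parameters is to exponentiate a skew-symmetric matrix (which yields a rotation) and then optionally multiply by a fixed reflection. The sketch below follows that route; it is an assumption about the approach, not the library's code, and `skew_symmetric` and `on_rep` are hypothetical helpers that assume param_dim = n(n-1)/2.

```python
import torch


def skew_symmetric(params: torch.Tensor, n: int) -> torch.Tensor:
    """Fill the strict upper triangle with params and antisymmetrize: (batch, n, n)."""
    matrices = params.new_zeros(params.shape[0], n, n)
    rows, cols = torch.triu_indices(n, n, offset=1)
    matrices[:, rows, cols] = params
    return matrices - matrices.transpose(1, 2)


def on_rep(params: torch.Tensor, reflect_indicators: torch.Tensor, n: int) -> torch.Tensor:
    """Rotation from the matrix exponential, times a reflection where the indicator is 1."""
    rotation = torch.matrix_exp(skew_symmetric(params, n))
    reflection = torch.eye(n, dtype=params.dtype, device=params.device).repeat(params.shape[0], 1, 1)
    # Flip the sign of the first axis for samples flagged for reflection.
    reflection[:, 0, 0] = 1.0 - 2.0 * reflect_indicators[:, 0]
    return rotation @ reflection


params = torch.tensor([[0.3, -1.2, 0.7], [0.5, 0.1, -0.4]])  # n = 3, so param_dim = 3
reflect_indicators = torch.tensor([[0.0], [1.0]])
rep = on_rep(params, reflect_indicators, n=3)

print(torch.linalg.det(rep))  # close to +1 without reflection, -1 with reflection
```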
- get_sen_rep(params: Tensor) Tensor[source]
Computes the representation for the SE(n) group.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
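A standard way to represent SE(n) as a single matrix is the homogeneous form, with the rotation in the top-left block and the translation in the last column, so that rep_dim = n + 1. The sketch below builds it for n = 2; treating `params` as (angle, t_x, t_y) is an assumption made for this example, and `se2_rep` is a hypothetical helper rather than the library method.

```python
import math

import torch


def se2_rep(params: torch.Tensor) -> torch.Tensor:
    """Homogeneous SE(2) matrices of shape (batch_size, 3, 3) from (angle, t_x, t_y)."""
    theta, translation = params[:, 0], params[:, 1:]
    cos, sin = torch.cos(theta), torch.sin(theta)
    rep = torch.zeros(params.shape[0], 3, 3, dtype=params.dtype, device=params.device)
    rep[:, 0, 0] = cos
    rep[:, 0, 1] = -sin
    rep[:, 1, 0] = sin
    rep[:, 1, 1] = cos
    rep[:, :2, 2] = translation  # translation sits in the last column
    rep[:, 2, 2] = 1.0
    return rep


params = torch.tensor([[math.pi / 2, 1.0, 2.0]])
rep = se2_rep(params)

# Acting on the homogeneous point (1, 0, 1): rotate by 90 degrees, then translate by (1, 2).
point = torch.tensor([1.0, 0.0, 1.0])
moved = rep[0] @ point
```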
- equiadapt.common.utils.gram_schmidt(vectors: Tensor) Tensor[source]
Applies the Gram-Schmidt process to orthogonalize a set of three vectors in a batch-wise manner.
- Parameters:
vectors (torch.Tensor) – A batch of vectors of shape (batch_size, n_vectors, vector_dim), where n_vectors is the number of vectors to orthogonalize (here 3).
- Returns:
The orthogonalized vectors of the same shape as the input.
- Return type:
torch.Tensor
Examples
>>> vectors = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]])
>>> result = gram_schmidt(vectors)
>>> print(result)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000]]])
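For readers who want to see the process itself, here is a hedged, batch-wise sketch of classical Gram-Schmidt. It is not the library's implementation: `gram_schmidt_sketch` is a hypothetical name, and it assumes floating-point, linearly independent inputs and normalizes each vector to unit length.

```python
import torch


def gram_schmidt_sketch(vectors: torch.Tensor) -> torch.Tensor:
    """Orthonormalize vectors of shape (batch_size, n_vectors, vector_dim) along dim 1."""
    basis = []
    for k in range(vectors.shape[1]):
        v = vectors[:, k]
        for b in basis:
            # Remove the component of v along each already accepted basis vector.
            v = v - (v * b).sum(dim=-1, keepdim=True) * b
        basis.append(v / v.norm(dim=-1, keepdim=True))
    return torch.stack(basis, dim=1)


vectors = torch.tensor([[[1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]])
result = gram_schmidt_sketch(vectors)

# The rows are orthonormal, so result @ result^T is close to the identity.
gram = result @ result.transpose(1, 2)
```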
Module contents
- class equiadapt.common.BaseCanonicalization(canonicalization_network: Module)[source]
Bases: Module
This is the base class for canonicalization.
This class is used as a base for all canonicalization methods. Subclasses should implement the canonicalize method to define the specific canonicalization process.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Takes the input data and, optionally, targets that need to be canonicalized.
- Parameters:
x – input data
targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation
**kwargs – additional arguments
- Returns:
the canonicalized version of the data and targets
- forward(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Forward method for the canonicalization which takes the input data and returns the canonicalized version of the data
- Parameters:
x – input data
targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation
**kwargs – additional arguments
- Returns:
canonicalized version of the input data
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Takes the prediction network's output for the canonicalized data and returns the output for the original data orientation.
- Parameters:
x_canonicalized_out – output of the prediction network for canonicalized data
**kwargs – additional arguments
- Returns:
output of the prediction network for the original data orientation, obtained by using the group element that canonicalized the original data
- Return type:
torch.Tensor
- class equiadapt.common.ContinuousGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0)[source]
Bases: BaseCanonicalization
This class represents a continuous group canonicalization method.
Continuous group canonicalization is a method that transforms the input data into a canonical form using a continuous group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for continuous group canonicalization.
- canonicalization_network
The network used for canonicalization.
- Type:
torch.nn.Module
- canonicalizationnetworkout_to_groupelement()[source]
Converts the output of the canonicalization network to a group element in a differentiable manner.
- canonicalizationnetworkout_to_groupelement(group_activations: Tensor) Tensor[source]
Converts the output of the canonicalization network to a group element in a differentiable manner.
- Parameters:
group_activations (torch.Tensor) – The activations for each group element.
- Returns:
The group element.
- Return type:
torch.Tensor
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Canonicalizes the input data.
- Parameters:
x (torch.Tensor) – The input data.
targets (List, optional) – Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
The canonicalized input data and targets. Simultaneously, it updates a dictionary containing the information about the canonicalization.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_identity_metric() Tensor[source]
Gets the identity metric.
The identity metric is calculated as 1 minus the mean of the mean squared error between the group element matrix representation and the identity matrix.
- Returns:
The identity metric.
- Return type:
torch.Tensor
- get_prior_regularization_loss() Tensor[source]
Gets the prior regularization loss.
The prior regularization loss is calculated as the mean squared error between the group element matrix representation and the identity matrix.
- Returns:
The prior regularization loss.
- Return type:
torch.Tensor
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Inverts the canonicalization.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The output for the original data orientation.
- Return type:
torch.Tensor
- class equiadapt.common.DiscreteGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0, gradient_trick: str = 'straight_through')[source]
Bases: BaseCanonicalization
This class represents a discrete group canonicalization method.
Discrete group canonicalization is a method that transforms the input data into a canonical form using a discrete group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for discrete group canonicalization.
- canonicalization_network
The network used for canonicalization.
- Type:
torch.nn.Module
- gradient_trick
The method used for backpropagation through the discrete operation. Defaults to “straight_through”.
- Type:
str
- groupactivations_to_groupelementonehot()[source]
Converts group activations to one-hot encoded group elements in a differentiable manner.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Canonicalizes the input data.
- Parameters:
x (torch.Tensor) – The input data.
targets (List, optional) – Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
The canonicalized input data and targets. Simultaneously, it updates a dictionary containing the information about the canonicalization.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, List]]
- get_identity_metric() Tensor[source]
Gets the identity metric.
- Returns:
The identity metric.
- Return type:
torch.Tensor
- get_prior_regularization_loss() Tensor[source]
Gets the prior regularization loss.
- Returns:
The prior regularization loss.
- Return type:
torch.Tensor
- groupactivations_to_groupelementonehot(group_activations: Tensor) Tensor[source]
Converts group activations to one-hot encoded group elements in a differentiable manner.
- Parameters:
group_activations (torch.Tensor) – The activations for each group element.
- Returns:
The one-hot encoding of the group elements.
- Return type:
torch.Tensor
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Inverts the canonicalization.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The output for the original data orientation.
- Return type:
torch.Tensor
- class equiadapt.common.IdentityCanonicalization(canonicalization_network: Module = Identity())[source]
Bases: BaseCanonicalization
This class represents an identity canonicalization method.
Identity canonicalization is a no-op; it doesn’t change the input data. It’s useful as a default or placeholder when no other canonicalization method is specified.
- canonicalization_network
The network used for canonicalization. Defaults to torch.nn.Identity.
- Type:
torch.nn.Module
- canonicalize()[source]
Canonicalizes the input data. In this class, it returns the input data unchanged.
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) Tensor | Tuple[Tensor, List][source]
Canonicalize the input data.
This method takes the input data and returns it unchanged, along with the targets if provided. It’s a no-op in the IdentityCanonicalization class.
- Parameters:
x – The input data.
targets – (Optional) Additional targets that need to be canonicalized.
**kwargs – Additional arguments.
- Returns:
A tuple containing the unchanged input data and targets if targets are provided, otherwise just the unchanged input data.
- get_identity_metric() Tensor[source]
Gets the identity metric.
For the IdentityCanonicalization class, this is always 1.
- Returns:
A tensor containing the value 1.
- Return type:
torch.Tensor
- get_prior_regularization_loss() Tensor[source]
Gets the prior regularization loss.
For the IdentityCanonicalization class, this is always 0.
- Returns:
A tensor containing the value 0.
- Return type:
torch.Tensor
- invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) Tensor[source]
Inverts the canonicalization.
For the IdentityCanonicalization class, this is a no-op and returns the input unchanged.
- Parameters:
x_canonicalized_out (torch.Tensor) – The canonicalized output.
**kwargs – Additional arguments.
- Returns:
The unchanged x_canonicalized_out.
- Return type:
torch.Tensor
- class equiadapt.common.LieParameterization(group_type: str, group_dim: int)[source]
Bases: Module
A class for parameterizing Lie groups and their representations for a single block.
- Parameters:
- get_en_rep(params: Tensor, reflect_indicators: Tensor) Tensor[source]
Computes the representation for the E(n) group, including rotations, reflections, and translations.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim), where the first part corresponds to rotation/reflection parameters and the last ‘n’ parameters correspond to translation.
reflect_indicators (torch.Tensor) – Indicators of whether to reflect, of shape (batch_size, 1).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_group_rep(params: Tensor) Tensor[source]
Computes the representation for the specified Lie group.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
- Returns:
The group representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_on_rep(params: Tensor, reflect_indicators: Tensor) Tensor[source]
Computes the representation for O(n) group, optionally including reflections.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
reflect_indicators (torch.Tensor) – Indicators of whether to reflect, of shape (batch_size, 1).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- get_sen_rep(params: Tensor) Tensor[source]
Computes the representation for the SE(n) group.
- Parameters:
params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).
- Returns:
The representation of shape (batch_size, rep_dim, rep_dim).
- Return type:
torch.Tensor
- equiadapt.common.gram_schmidt(vectors: Tensor) Tensor[source]
Applies the Gram-Schmidt process to orthogonalize a set of three vectors in a batch-wise manner.
- Parameters:
vectors (torch.Tensor) – A batch of vectors of shape (batch_size, n_vectors, vector_dim), where n_vectors is the number of vectors to orthogonalize (here 3).
- Returns:
The orthogonalized vectors of the same shape as the input.
- Return type:
torch.Tensor
Examples
>>> vectors = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]])
>>> result = gram_schmidt(vectors)
>>> print(result)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000]]])