equiadapt.common package

Submodules

equiadapt.common.basecanonicalization module

This module defines a base class for canonicalization and its subclasses for different types of canonicalization methods.

Canonicalization is a process that transforms the input data into a canonical (standard) form. This can be a cheap alternative to building equivariant models: the input data is first transformed into a canonical form, and a standard model is then used to make predictions. Canonicalization allows you to use any existing architecture (even pre-trained ones) for your task without having to worry about equivariance.
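The canonicalize, predict, and invert pattern described above can be sketched in plain PyTorch. This is a minimal illustration, not equiadapt's API. canonicalize_c4 and invert_c4 are hypothetical helpers. The group is assumed to be the four 90-degree rotations of a square image. A hand-written quadrant-sum rule stands in for the canonicalization network.

```python
from typing import Tuple

import torch
import torch.nn as nn


def canonicalize_c4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate each square image in a (B, C, H, W) batch to a canonical orientation.

    Hypothetical rule (a stand-in for a canonicalization network): among the
    four 90-degree rotations, keep the one whose top-left quadrant has the
    largest sum.
    """
    # (B, 4, C, H, W): all four rotations of every image
    rotations = torch.stack([torch.rot90(x, k, dims=(-2, -1)) for k in range(4)], dim=1)
    h, w = x.shape[-2] // 2, x.shape[-1] // 2
    scores = rotations[..., :h, :w].sum(dim=(-3, -2, -1))  # (B, 4)
    k = scores.argmax(dim=1)  # chosen group element (number of quarter turns) per image
    canonical = rotations[torch.arange(x.shape[0]), k]
    return canonical, k


def invert_c4(y: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Undo the per-image rotation on a spatial output such as a segmentation map."""
    return torch.stack([torch.rot90(yi, -int(ki), dims=(-2, -1)) for yi, ki in zip(y, k)])


# Any off-the-shelf, non-equivariant model can consume the canonical images.
torch.manual_seed(0)
predictor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

x = torch.randn(2, 3, 8, 8)
x_rotated = torch.rot90(x, 1, dims=(-2, -1))

logits = predictor(canonicalize_c4(x)[0])
logits_rotated = predictor(canonicalize_c4(x_rotated)[0])
print(torch.allclose(logits, logits_rotated))  # True: predictions ignore the input orientation
```

Every rotated copy of an image maps to the same canonical image. As a result, the unmodified linear model produces identical predictions for both inputs. invert_c4 is only needed when the prediction itself lives in image space.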

The module contains the following classes:

  • BaseCanonicalization: This is an abstract base class that defines the interface for all canonicalization methods.

  • IdentityCanonicalization: This class represents an identity canonicalization method, which is a no-op; it doesn’t change the input data.

  • DiscreteGroupCanonicalization: This class represents a discrete group canonicalization method, which transforms the input data into a canonical form using a discrete group.

  • ContinuousGroupCanonicalization: This class represents a continuous group canonicalization method, which transforms the input data into a canonical form using a continuous group.

Each class has methods to perform the canonicalization, invert it, and calculate the prior regularization loss and identity metric.

class equiadapt.common.basecanonicalization.BaseCanonicalization(canonicalization_network: Module)

Bases: Module

This is the base class for canonicalization.

This class is used as a base for all canonicalization methods. Subclasses should implement the canonicalize method to define the specific canonicalization process.

canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]

This method takes input data and, optionally, targets that need to be canonicalized.

Parameters:
  • x – input data

  • targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation

  • **kwargs – additional arguments

Returns:

the canonicalized version of the data and targets

forward(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]

Forward method for the canonicalization, which takes the input data and returns the canonicalized version of the data.

Parameters:
  • x – input data

  • targets – (optional) additional targets that need to be canonicalized, such as boxes for promptable instance segmentation

  • **kwargs – additional arguments

Returns:

the canonicalized version of the input data (and of the targets, if provided)

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, List]]

invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor

This method takes the output of the prediction network for the canonicalized data and returns the corresponding output for the original data orientation.

Parameters:
  • x_canonicalized_out – output of the prediction network for the canonicalized data

  • **kwargs – additional arguments

Returns:

output of the prediction network for the original data orientation, computed using the group element that was used to canonicalize the original data

Return type:

torch.Tensor
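The calling convention described above can be illustrated with a small stand-alone module. It has three parts:

  • canonicalize returns either the canonicalized tensor or a (tensor, targets) tuple.
  • forward performs canonicalization.
  • invert_canonicalization maps predictions back to the original orientation.

The sketch does not subclass equiadapt's BaseCanonicalization, because its constructor and internal bookkeeping are not described here. FlipCanonicalization, its flip_mask attribute, and the half-sum rule are all hypothetical. forward delegating to canonicalize is an assumption.

```python
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn


class FlipCanonicalization(nn.Module):
    """Hypothetical canonicalizer for the group {identity, horizontal flip}.

    Mirrors the documented BaseCanonicalization method names and return
    conventions; it is a standalone sketch, not an equiadapt subclass.
    """

    def __init__(self) -> None:
        super().__init__()
        # Group element chosen for each sample by the last canonicalize call.
        self.flip_mask: Optional[torch.Tensor] = None

    def canonicalize(
        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        half = x.shape[-1] // 2
        left = x[..., :half].sum(dim=(-3, -2, -1))
        right = x[..., -half:].sum(dim=(-3, -2, -1))
        # Flip every (C, H, W) sample whose left half outweighs its right half.
        self.flip_mask = left > right
        mask = self.flip_mask[:, None, None, None]
        x_canonical = torch.where(mask, x.flip(-1), x)
        if targets is not None:
            # A real implementation would also flip targets such as boxes.
            return x_canonical, targets
        return x_canonical

    def forward(
        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        return self.canonicalize(x, targets, **kwargs)

    def invert_canonicalization(self, x_canonicalized_out: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # A flip is its own inverse: reapply it to the samples that were flipped.
        # Assumes a spatial output aligned with the input (e.g. a segmentation map)
        # and a preceding call to canonicalize.
        mask = self.flip_mask[:, None, None, None]
        return torch.where(mask, x_canonicalized_out.flip(-1), x_canonicalized_out)
```

Storing the chosen group element on the module plays the same role as the canonicalization information dictionary mentioned for the subclasses below. invert_canonicalization can undo exactly the transformation applied to each sample.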

class equiadapt.common.basecanonicalization.ContinuousGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0)

Bases: BaseCanonicalization

This class represents a continuous group canonicalization method.

Continuous group canonicalization is a method that transforms the input data into a canonical form using a continuous group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for continuous group canonicalization.

canonicalization_network

The network used for canonicalization.

Type:

torch.nn.Module

beta

A parameter for the softmax function. Defaults to 1.0.

Type:

float

__init__()

Initializes the ContinuousGroupCanonicalization instance.

canonicalizationnetworkout_to_groupelement()

Converts the output of the canonicalization network to a group element in a differentiable manner.

canonicalize()

Canonicalizes the input data.

invert_canonicalization()

Inverts the canonicalization.

get_prior_regularization_loss()

Gets the prior regularization loss.

get_identity_metric()

Gets the identity metric.

canonicalizationnetworkout_to_groupelement(group_activations: Tensor) → Tensor

Converts the output of the canonicalization network to a group element in a differentiable manner.

Parameters:

group_activations (torch.Tensor) – The activations for each group element.

Returns:

The group element.

Return type:

torch.Tensor
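For a continuous group, the network output must be turned into a group element without breaking differentiability. The sketch below shows one common construction for 2D rotations (SO(2)). It assumes the canonicalization network emits one 2D vector per sample. vector_to_rotation_2d is a hypothetical helper. equiadapt's canonicalizationnetworkout_to_groupelement may use a different parameterization.

```python
import torch


def vector_to_rotation_2d(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Map a batch of 2D vectors (B, 2) to rotation matrices (B, 2, 2).

    The vector is normalized to (cos t, sin t) and placed into
    [[cos t, -sin t], [sin t, cos t]]. Every step is differentiable as long as
    v is not (close to) zero.
    """
    u = v / (v.norm(dim=-1, keepdim=True) + eps)
    cos, sin = u[..., 0], u[..., 1]
    row0 = torch.stack([cos, -sin], dim=-1)
    row1 = torch.stack([sin, cos], dim=-1)
    return torch.stack([row0, row1], dim=-2)


# Stand-in for the canonicalization network's output for two samples.
activations = torch.tensor([[3.0, 4.0], [0.0, 2.0]], requires_grad=True)
R = vector_to_rotation_2d(activations)
R.sum().backward()  # gradients flow back to the network output
```

Only the direction of the vector matters, so the network can predict any non-zero vector. The normalization guarantees a valid rotation.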

canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]

Canonicalizes the input data.

Parameters:
  • x (torch.Tensor) – The input data.

  • targets (List, optional) – Additional targets that need to be canonicalized.

  • **kwargs – Additional arguments.

Returns:

The canonicalized input data, together with the canonicalized targets if targets are provided. As a side effect, it updates a dictionary containing information about the canonicalization.

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, List]]

get_identity_metric() → Tensor

Gets the identity metric.

The identity metric is calculated as 1 minus the mean of the mean squared error between the group element matrix representation and the identity matrix.

Returns:

The identity metric.

Return type:

torch.Tensor

get_prior_regularization_loss() → Tensor

Gets the prior regularization loss.

The prior regularization loss is calculated as the mean squared error between the group element matrix representation and the identity matrix.

Returns:

The prior regularization loss.

Return type:

torch.Tensor
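Both quantities reduce to a comparison with the identity matrix. The helpers below are a minimal sketch of the two formulas as stated here. They assume the group elements are already available as a (batch_size, d, d) tensor of matrix representations. The function names are hypothetical. equiadapt's own get_prior_regularization_loss and get_identity_metric take no arguments.

```python
import torch
import torch.nn.functional as F


def prior_regularization_loss(rep: torch.Tensor) -> torch.Tensor:
    """MSE between a batch of group-element matrices (B, d, d) and the identity."""
    identity = torch.eye(rep.shape[-1], dtype=rep.dtype, device=rep.device).expand_as(rep)
    return F.mse_loss(rep, identity)


def identity_metric(rep: torch.Tensor) -> torch.Tensor:
    """1 minus the mean of the per-sample MSE to the identity."""
    identity = torch.eye(rep.shape[-1], dtype=rep.dtype, device=rep.device)
    per_sample_mse = ((rep - identity) ** 2).mean(dim=(-2, -1))
    return 1 - per_sample_mse.mean()


# Two 2D rotations: by 0 degrees and by 90 degrees.
rep = torch.tensor([[[1.0, 0.0], [0.0, 1.0]],
                    [[0.0, -1.0], [1.0, 0.0]]])
print(prior_regularization_loss(rep))  # tensor(0.5000)
print(identity_metric(rep))            # tensor(0.5000)
```

A batch made entirely of identity elements gives a loss of 0 and a metric of 1, which matches the values documented for IdentityCanonicalization below.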

invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor

Inverts the canonicalization.

Parameters:
  • x_canonicalized_out (torch.Tensor) – The canonicalized output.

  • **kwargs – Additional arguments.

Returns:

The output for the original data orientation.

Return type:

torch.Tensor

class equiadapt.common.basecanonicalization.DiscreteGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0, gradient_trick: str = 'straight_through')

Bases: BaseCanonicalization

This class represents a discrete group canonicalization method.

Discrete group canonicalization is a method that transforms the input data into a canonical form using a discrete group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for discrete group canonicalization.

canonicalization_network

The network used for canonicalization.

Type:

torch.nn.Module

beta

A parameter for the softmax function. Defaults to 1.0.

Type:

float

gradient_trick

The method used for backpropagation through the discrete operation. Defaults to “straight_through”.

Type:

str

__init__()

Initializes the DiscreteGroupCanonicalization instance.

groupactivations_to_groupelementonehot()

Converts group activations to one-hot encoded group elements in a differentiable manner.

canonicalize()

Canonicalizes the input data.

invert_canonicalization()

Inverts the canonicalization.

get_prior_regularization_loss()

Gets the prior regularization loss.

get_identity_metric()

Gets the identity metric.

canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]

Canonicalizes the input data.

Parameters:
  • x (torch.Tensor) – The input data.

  • targets (List, optional) – Additional targets that need to be canonicalized.

  • **kwargs – Additional arguments.

Returns:

The canonicalized input data, together with the canonicalized targets if targets are provided. As a side effect, it updates a dictionary containing information about the canonicalization.

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, List]]

get_identity_metric() → Tensor

Gets the identity metric.

Returns:

The identity metric.

Return type:

torch.Tensor

get_prior_regularization_loss() → Tensor

Gets the prior regularization loss.

Returns:

The prior regularization loss.

Return type:

torch.Tensor

groupactivations_to_groupelementonehot(group_activations: Tensor) → Tensor

Converts group activations to one-hot encoded group elements in a differentiable manner.

Parameters:

group_activations (torch.Tensor) – The activations for each group element.

Returns:

The one-hot encoding of the group elements.

Return type:

torch.Tensor
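With gradient_trick='straight_through', the choice of a single group element stays discrete in the forward pass, while gradients still reach the canonicalization network. Below is a minimal sketch of the standard straight-through estimator. It assumes beta scales the activations before the softmax, which is how the beta attribute description reads. straight_through_onehot is a hypothetical helper, and equiadapt's exact formulation may differ.

```python
import torch
import torch.nn.functional as F


def straight_through_onehot(group_activations: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Hard one-hot in the forward pass, softmax gradients in the backward pass.

    group_activations: (B, G) scores, one per element of a discrete group of size G.
    """
    soft = F.softmax(beta * group_activations, dim=-1)
    hard = F.one_hot(soft.argmax(dim=-1), num_classes=soft.shape[-1]).to(soft.dtype)
    # The forward value equals `hard` (up to rounding); gradients are those of `soft`.
    return hard + soft - soft.detach()


torch.manual_seed(0)
# Made-up scores for the four 90-degree rotations of one image.
acts = torch.tensor([[0.1, 2.0, -1.0, 0.5]], requires_grad=True)
onehot = straight_through_onehot(acts, beta=1.0)

# Select the canonical copy by weighting all group-transformed candidates.
x = torch.randn(1, 3, 8, 8)
candidates = torch.stack([torch.rot90(x, k, dims=(-2, -1)) for k in range(4)], dim=1)
x_canonical = (onehot[:, :, None, None, None] * candidates).sum(dim=1)

# A downstream loss that depends on orientation, so the softmax receives a gradient.
x_canonical[..., :4, :4].sum().backward()
```

The weighted sum keeps the selection differentiable. In the forward pass, however, it returns exactly one candidate, here the copy rotated by one quarter turn.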

invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor

Inverts the canonicalization.

Parameters:
  • x_canonicalized_out (torch.Tensor) – The canonicalized output.

  • **kwargs – Additional arguments.

Returns:

The output for the original data orientation.

Return type:

torch.Tensor

class equiadapt.common.basecanonicalization.IdentityCanonicalization(canonicalization_network: Module = Identity())

Bases: BaseCanonicalization

This class represents an identity canonicalization method.

Identity canonicalization is a no-op; it doesn’t change the input data. It’s useful as a default or placeholder when no other canonicalization method is specified.

canonicalization_network

The network used for canonicalization. Defaults to torch.nn.Identity.

Type:

torch.nn.Module

__init__()

Initializes the IdentityCanonicalization instance.

canonicalize()

Canonicalizes the input data. In this class, it returns the input data unchanged.

canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]

Canonicalize the input data.

This method takes the input data and returns it unchanged, along with the targets if provided. It’s a no-op in the IdentityCanonicalization class.

Parameters:
  • x – The input data.

  • targets – (Optional) Additional targets that need to be canonicalized.

  • **kwargs – Additional arguments.

Returns:

A tuple containing the unchanged input data and targets if targets are provided, otherwise just the unchanged input data.

get_identity_metric() → Tensor

Gets the identity metric.

For the IdentityCanonicalization class, this is always 1.

Returns:

A tensor containing the value 1.

Return type:

torch.Tensor

get_prior_regularization_loss() → Tensor

Gets the prior regularization loss.

For the IdentityCanonicalization class, this is always 0.

Returns:

A tensor containing the value 0.

Return type:

torch.Tensor

invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor

Inverts the canonicalization.

For the IdentityCanonicalization class, this is a no-op and returns the input unchanged.

Parameters:
  • x_canonicalized_out (torch.Tensor) – The canonicalized output.

  • **kwargs – Additional arguments.

Returns:

The unchanged x_canonicalized_out.

Return type:

torch.Tensor

equiadapt.common.utils module

class equiadapt.common.utils.LieParameterization(group_type: str, group_dim: int)

Bases: Module

A class for parameterizing Lie groups and their representations for a single block.

Parameters:
  • group_type (str) – The type of Lie group (e.g., ‘SOn’, ‘SEn’, ‘On’, ‘En’).

  • group_dim (int) – The dimension n of the group (the n in SO(n), SE(n), O(n), or E(n)), i.e. the dimension of the space the group acts on.

group_type

Type of Lie group.

Type:

str

group_dim

Dimension n of the group (the n in SO(n), SE(n), O(n), or E(n)).

Type:

int

get_en_rep(params: Tensor, reflect_indicators: Tensor) → Tensor

Computes the representation for the E(n) group, including rotations, reflections, and translations.

Parameters:
  • params (torch.Tensor) – Input parameters of shape (batch_size, param_dim), where the first part corresponds to rotation/reflection parameters and the last n parameters correspond to translation.

  • reflect_indicators (torch.Tensor) – Indicators of whether to reflect.

Returns:

The representation of shape (batch_size, rep_dim, rep_dim).

Return type:

torch.Tensor

get_group_rep(params: Tensor) → Tensor

Computes the representation for the specified Lie group.

Parameters:

params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).

Returns:

The group representation of shape (batch_size, rep_dim, rep_dim).

Return type:

torch.Tensor

get_on_rep(params: Tensor, reflect_indicators: Tensor) → Tensor

Computes the representation for the O(n) group, optionally including reflections.

Parameters:
  • params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).

  • reflect_indicators (torch.Tensor) – Indicators of whether to reflect, of shape (batch_size, 1).

Returns:

The representation of shape (batch_size, rep_dim, rep_dim).

Return type:

torch.Tensor
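An O(n) element can be written as an SO(n) element, optionally composed with a fixed reflection. The sketch below rests on three assumptions:

  • params are coordinates in the Lie algebra so(n), one per index pair i < j.
  • The rotation is obtained with the matrix exponential.
  • A reflection flips the first axis.

on_rep_sketch is hypothetical and may not match get_on_rep's conventions.

```python
import torch


def on_rep_sketch(params: torch.Tensor, reflect_indicators: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical O(n) representation of shape (B, n, n).

    params: (B, n * (n - 1) // 2) coordinates in the Lie algebra so(n).
    reflect_indicators: (B, 1) tensor of 0s and 1s.
    """
    rows, cols = torch.triu_indices(n, n, offset=1)
    skew = torch.zeros(params.shape[0], n, n, dtype=params.dtype)
    skew[:, rows, cols] = params
    skew = skew - skew.transpose(-1, -2)  # skew-symmetric, so exp(skew) is in SO(n)
    rotation = torch.linalg.matrix_exp(skew)
    # Reflect across the first axis, diag(-1, 1, ..., 1), where the indicator is 1.
    signs = torch.ones(params.shape[0], n, dtype=params.dtype)
    signs[:, 0] = 1 - 2 * reflect_indicators[:, 0].to(params.dtype)
    return rotation @ torch.diag_embed(signs)


params = torch.tensor([[0.3], [0.3]])  # n = 2: a single Lie-algebra coordinate
reflect = torch.tensor([[0], [1]])
Q = on_rep_sketch(params, reflect, n=2)
dets = torch.linalg.det(Q)  # determinants: +1 (pure rotation) and -1 (with reflection)
```

The determinant separates the two connected components of O(n). This is why a separate discrete indicator is needed: no continuous Lie-algebra parameter can reach a reflection.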

get_sen_rep(params: Tensor) → Tensor

Computes the representation for the SE(n) group.

Parameters:

params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).

Returns:

The representation of shape (batch_size, rep_dim, rep_dim).

Return type:

torch.Tensor
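Groups that include translations (SE(n), E(n)) are commonly represented by (n+1) x (n+1) homogeneous matrices. A rotation and a translation then act on a point through a single matrix product. The sketch assumes this is what rep_dim means for get_sen_rep and get_en_rep, which this page does not state. homogeneous_rep is a hypothetical helper.

```python
import torch


def homogeneous_rep(rotation: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
    """Pack (B, n, n) rotations and (B, n) translations into (B, n+1, n+1) matrices [[R, t], [0, 1]]."""
    batch, n = translation.shape
    rep = torch.zeros(batch, n + 1, n + 1, dtype=rotation.dtype)
    rep[:, :n, :n] = rotation
    rep[:, :n, n] = translation
    rep[:, n, n] = 1.0
    return rep


# A 90-degree rotation in the plane followed by a shift of (1, 0).
R = torch.tensor([[[0.0, -1.0], [1.0, 0.0]]])
t = torch.tensor([[1.0, 0.0]])
g = homogeneous_rep(R, t)
point = torch.tensor([1.0, 0.0, 1.0])  # the point (1, 0) in homogeneous coordinates
print(g[0] @ point)  # tensor([1., 1., 1.])
```

With this layout, composing two group elements is an ordinary matrix product. For the same reason, the translation parameters in get_en_rep can simply be appended after the rotation/reflection parameters.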

get_son_bases() → Tensor

Generates a basis of the Lie algebra so(n) of the group SO(n).

Returns:

The so(n) basis, of shape (num_params, group_dim, group_dim).

Return type:

torch.Tensor

get_son_rep(params: Tensor) → Tensor

Computes the representation for the SO(n) group.

Parameters:

params (torch.Tensor) – Input parameters of shape (batch_size, param_dim).

Returns:

The representation of shape (batch_size, rep_dim, rep_dim).

Return type:

torch.Tensor
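A standard construction for SO(n) works in two steps. First, the Lie algebra so(n) of skew-symmetric n x n matrices has the basis E_ij - E_ji for i < j, which gives n(n-1)/2 elements. Second, the matrix exponential maps any linear combination of these basis elements to a rotation. The sketch assumes get_son_bases and get_son_rep follow this pattern, in which case num_params would equal n(n-1)/2. son_bases and son_rep are hypothetical names. The ordering and sign convention of the basis may differ from equiadapt's.

```python
import torch


def son_bases(n: int) -> torch.Tensor:
    """Basis of the Lie algebra so(n): E_ij - E_ji for i < j, shape (n*(n-1)//2, n, n)."""
    bases = []
    for i in range(n):
        for j in range(i + 1, n):
            b = torch.zeros(n, n)
            b[i, j], b[j, i] = 1.0, -1.0
            bases.append(b)
    return torch.stack(bases)


def son_rep(params: torch.Tensor, bases: torch.Tensor) -> torch.Tensor:
    """Map (B, num_params) coordinates to SO(n) matrices via the matrix exponential."""
    algebra = torch.einsum("bp,pij->bij", params, bases)  # skew-symmetric (B, n, n)
    return torch.linalg.matrix_exp(algebra)


bases = son_bases(3)
print(bases.shape)  # torch.Size([3, 3, 3])
```

The exponential of a skew-symmetric matrix is always orthogonal with determinant +1. Any unconstrained network output can therefore be used as params.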

equiadapt.common.utils.gram_schmidt(vectors: Tensor) → Tensor

Applies the Gram-Schmidt process to orthogonalize a set of three vectors in a batch-wise manner.

Parameters:

vectors (torch.Tensor) – A batch of vectors of shape (batch_size, n_vectors, vector_dim), where n_vectors is the number of vectors to orthogonalize (here 3).

Returns:

The orthogonalized vectors of the same shape as the input.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from equiadapt.common.utils import gram_schmidt
>>> vectors = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
>>> result = gram_schmidt(vectors)
>>> print(result)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000]]])
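Classical Gram-Schmidt handles each vector in turn. It subtracts the vector's projections onto the previously normalized vectors and then normalizes the result. gram_schmidt_sketch below is a hypothetical stand-alone version of that recursion, not equiadapt's implementation. It assumes the outputs are normalized to unit length, which the identity-in, identity-out example above is consistent with but does not prove. Stacking three orthonormal 3D vectors as rows gives an orthogonal 3x3 matrix. This is one way raw vector outputs of a network can be turned into a 3D rotation or reflection.

```python
import torch


def gram_schmidt_sketch(vectors: torch.Tensor) -> torch.Tensor:
    """Orthonormalize the vectors of each batch element, shape (B, K, D) -> (B, K, D), K <= D."""
    basis = []
    for k in range(vectors.shape[1]):
        v = vectors[:, k]
        for u in basis:
            # Remove the component of v along each previously accepted direction.
            v = v - (v * u).sum(dim=-1, keepdim=True) * u
        basis.append(v / v.norm(dim=-1, keepdim=True))
    return torch.stack(basis, dim=1)


vectors = torch.tensor([[[2.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]])
q = gram_schmidt_sketch(vectors)  # the 3x3 identity for this input
```

The order of the input vectors matters. The first output always points along the first input, so a network can control the orientation of the resulting frame through the order of the vectors it predicts.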

Module contents

The equiadapt.common package exposes the following names at package level. Their documentation is identical to the corresponding entries in the equiadapt.common.basecanonicalization and equiadapt.common.utils modules above.

  • class equiadapt.common.BaseCanonicalization(canonicalization_network: Module)

  • class equiadapt.common.ContinuousGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0)

  • class equiadapt.common.DiscreteGroupCanonicalization(canonicalization_network: Module, beta: float = 1.0, gradient_trick: str = 'straight_through')

  • class equiadapt.common.IdentityCanonicalization(canonicalization_network: Module = Identity())

  • class equiadapt.common.LieParameterization(group_type: str, group_dim: int)

  • equiadapt.common.gram_schmidt(vectors: Tensor) → Tensor