equiadapt.nbody.canonicalization package

Submodules

equiadapt.nbody.canonicalization.euclidean_group module

class equiadapt.nbody.canonicalization.euclidean_group.EuclideanGroupNBody(canonicalization_network: Module)

Bases: ContinuousGroupCanonicalization

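Canonicalization for N-body data with respect to the continuous Euclidean group: a canonicalization network predicts a group element, which is then used to map the input to a canonical frame.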

Parameters:
  • canonicalization_network (torch.nn.Module) – The canonicalization network.

canonicalization_info_dict

A dictionary containing the group element information.

Type:

dict

canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, Tensor]

Canonicalize the input data.

Parameters:
  • x – Node attributes.

  • targets – Target data.

  • **kwargs – Additional keyword arguments. Includes locs, edges, vel, edge_attr, and charges.

Returns:

The canonicalized location and velocity.
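As an illustration of what canonicalizing locations and velocities can look like, here is a minimal PyTorch sketch. It rests on assumptions that are not taken from the library: the group element is a per-node orthogonal matrix `rotation` whose rows are the frame axes plus a per-node `translation`, nodes from all graphs in a batch are flattened into one dimension, and positions are row vectors. The helper name `canonicalize_positions_velocities` is hypothetical.

```python
from typing import Tuple

import torch


def canonicalize_positions_velocities(
    loc: torch.Tensor,          # (N, 3) positions, all graphs in the batch flattened (assumed layout)
    vel: torch.Tensor,          # (N, 3) velocities
    rotation: torch.Tensor,     # (N, 3, 3) per-node orthogonal frame, rows are the axes (assumed)
    translation: torch.Tensor,  # (N, 3) per-node frame origin (assumed)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Express positions and velocities in the predicted frame (hypothetical helper)."""
    # For an orthogonal matrix the inverse is the transpose.
    rotation_inverse = rotation.transpose(1, 2)
    # Positions: move to the frame origin, then take components along each frame axis.
    canonical_loc = torch.bmm((loc - translation).unsqueeze(1), rotation_inverse).squeeze(1)
    # Velocities are differences of positions, so the translation cancels; rotate only.
    canonical_vel = torch.bmm(vel.unsqueeze(1), rotation_inverse).squeeze(1)
    return canonical_loc, canonical_vel
```

Under this convention, if the input is transformed by an orthogonal matrix `Q` and a shift `s`, and the predicted frame transforms with it (`rotation @ Q.T`, `translation @ Q.T + s`), the canonical outputs do not change. That invariance is what makes the canonicalized data usable by a non-equivariant prediction network.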

forward(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]

Forward pass: canonicalize the input N-body data.

Parameters:
  • x – Node attributes.

  • targets – Target data.

  • **kwargs – Additional keyword arguments. Includes locs, edges, vel, edge_attr, and charges.

Returns:

The result of the canonicalization.

get_groupelement(nodes: Tensor, loc: Tensor, edges: Tensor, vel: Tensor, edge_attr: Tensor, charges: Tensor) → Dict[str, Tensor]

Get the group element information.

Parameters:
  • nodes – Node attributes.

  • loc – Location data.

  • edges – Edges data.

  • vel – Velocity data.

  • edge_attr – Edge attributes data.

  • charges – Charges data.

Returns:

A dictionary containing the group element information.
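The exact keys of this dictionary are not listed here. As a hypothetical sketch, a Euclidean group element could be packaged from two network outputs (three vectors per node and a translation per node) as below. The helper `build_group_element` and its key names are assumptions. The orthonormalization uses a sign-corrected `torch.linalg.qr` as a stand-in for Gram-Schmidt: for linearly independent inputs, Gram-Schmidt and a QR factorization whose R factor has a positive diagonal yield the same orthonormal vectors in exact arithmetic.

```python
from typing import Dict

import torch


def build_group_element(
    rotation_vectors: torch.Tensor,     # (N, 3, 3): three predicted vectors per node, along dim 1 (assumed)
    translation_vectors: torch.Tensor,  # (N, 3): predicted frame origin per node (assumed)
) -> Dict[str, torch.Tensor]:
    """Package a predicted Euclidean group element (hypothetical helper and keys)."""
    # QR of the matrix whose columns are the predicted vectors.
    q, r = torch.linalg.qr(rotation_vectors.transpose(1, 2))
    # Flip column signs so diag(r) > 0, which reproduces the Gram-Schmidt result.
    signs = torch.sign(torch.diagonal(r, dim1=-2, dim2=-1))
    q = q * signs.unsqueeze(-2)
    rotation_matrix = q.transpose(1, 2)  # rows are the orthonormal frame axes
    return {
        "rotation_matrix": rotation_matrix,
        # The inverse of an orthogonal matrix is its transpose, so no solve is needed.
        "rotation_matrix_inverse": rotation_matrix.transpose(1, 2),
        "translation_vectors": translation_vectors,
    }
```

Caching the inverse alongside the matrix lets both the canonicalization step and its inversion reuse the same group element without recomputing anything.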

invert_canonicalization(x_canonicalized_out: Tensor, **kwargs: Any) → Tensor

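Invert the canonicalization: take an output computed in the canonical frame and map it back to the frame of the original input.

Parameters:

x_canonicalized_out – The output computed on the canonicalized input.

Returns:

The output expressed in the frame of the original input.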
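For position outputs, the inversion can be sketched as undoing the rotation and then the translation. This minimal sketch assumes a forward step of the form `canonical = (loc - translation) @ rotation.T` per node, with orthonormal rows in `rotation`; both that convention and the helper name `invert_position_canonicalization` are hypothetical, not the library's implementation.

```python
import torch


def invert_position_canonicalization(
    canonical_out: torch.Tensor,  # (N, 3) positions predicted in the canonical frame
    rotation: torch.Tensor,       # (N, 3, 3) per-node frame used to canonicalize; rows orthonormal (assumed)
    translation: torch.Tensor,    # (N, 3) per-node frame origin used to canonicalize (assumed)
) -> torch.Tensor:
    """Map canonical-frame positions back to the input frame (hypothetical helper)."""
    # Undo the rotation: (x - t) @ R.T @ R = x - t, because R.T @ R = I.
    restored = torch.bmm(canonical_out.unsqueeze(1), rotation).squeeze(1)
    # Undo the shift to the frame origin.
    return restored + translation
```

For vector-valued outputs such as velocities or forces, only the rotation should be undone; adding `translation` is correct for positions only.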

modified_gram_schmidt(vectors: Tensor) → Tensor

Apply the modified Gram-Schmidt process to the input vectors.

Parameters:

vectors – Input vectors.

Returns:

The orthonormalized vectors.
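The modified Gram-Schmidt process can be sketched in PyTorch as follows. This is a generic batched version written for illustration and is not taken from the library's source. It assumes input of shape (batch, k, d) with the k vectors along dimension 1, and it assumes they are linearly independent (there is no guard against zero norms).

```python
import torch


def modified_gram_schmidt(vectors: torch.Tensor) -> torch.Tensor:
    """Orthonormalize k vectors per batch element (generic sketch).

    vectors: (B, k, d) with k <= d; row i of each batch element is the i-th input vector.
    Returns a (B, k, d) tensor whose rows are orthonormal.
    """
    basis = []
    for i in range(vectors.shape[1]):
        v = vectors[:, i]
        # Modified GS: remove each projection from the running remainder in turn,
        # rather than projecting the original vector onto every basis vector at once.
        for u in basis:
            v = v - (v * u).sum(dim=-1, keepdim=True) * u
        # Normalize; assumes the remainder is non-zero (linearly independent inputs).
        v = v / v.norm(dim=-1, keepdim=True)
        basis.append(v)
    return torch.stack(basis, dim=1)
```

In exact arithmetic this matches classical Gram-Schmidt, but in floating point it loses less orthogonality. With three 3D vectors, the stacked rows form an orthogonal matrix whose determinant may be -1 (a reflection), which is still an element of the Euclidean group E(3).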

Module contents