equiadapt.nbody.canonicalization package
Submodules
equiadapt.nbody.canonicalization.euclidean_group module
- class equiadapt.nbody.canonicalization.euclidean_group.EuclideanGroupNBody(canonicalization_network: Module)
Bases: ContinuousGroupCanonicalization
A class representing the continuous (Euclidean) group used for N-body canonicalization.
- Parameters:
canonicalization_network (torch.nn.Module) – The canonicalization network.
canonicalization_hyperparams (dict) – Hyperparameters for the canonicalization.
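To make the Euclidean group concrete: an element g = (R, t), with R a 3x3 rotation matrix and t a translation vector, moves every particle position to R·loc + t and every velocity to R·vel, because translating a whole system does not change its velocities. The following is a minimal, framework-free sketch of that action. Plain Python lists stand in for tensors, and the helper names (`matvec`, `act`, `rot_z`) are hypothetical, not part of equiadapt.

```python
import math

Vec = list[float]
Mat = list[list[float]]


def matvec(R: Mat, v: Vec) -> Vec:
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(R[i][j] * v[j] for j in range(3)) for i in range(3)]


def act(R: Mat, t: Vec, locs: list[Vec], vels: list[Vec]):
    """Apply g = (R, t) to an N-body state.

    Positions are rotated and then translated; velocities are only rotated,
    since a translation does not change a velocity.
    """
    new_locs = [[a + b for a, b in zip(matvec(R, x), t)] for x in locs]
    new_vels = [matvec(R, v) for v in vels]
    return new_locs, new_vels


def rot_z(theta: float) -> Mat:
    """Rotation by `theta` radians about the z-axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


# Two particles, rotated 90 degrees about z and shifted by 5 along x.
locs = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
vels = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
new_locs, new_vels = act(rot_z(math.pi / 2), [5.0, 0.0, 0.0], locs, vels)
```

Pairwise distances between particles are unchanged by the transformation (here both before and after equal sqrt(5)), which is the structure a Euclidean-equivariant N-body model can exploit.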
- canonicalize(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, Tensor]
Canonicalize the input data.
- Parameters:
x – Node attributes.
targets – Target data.
**kwargs – Additional keyword arguments. Includes locs, edges, vel, edge_attr, and charges.
- Returns:
The canonicalized location and velocity.
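Canonicalization applies the inverse of the predicted group element. A minimal sketch, assuming the group element consists of a rotation R and a translation t, maps each location to R^T (loc - t) and each velocity to R^T vel (for a rotation, the inverse is the transpose). The example uses plain Python lists in place of tensors and a hypothetical helper `canonicalize_state`; it is not the library's implementation. It checks the property that makes canonicalization useful: if the system is moved by another element (Q, s) and the network predicts the equivariantly transformed element (QR, Qt + s), the canonical output does not change.

```python
import math

Vec = list[float]
Mat = list[list[float]]


def matvec(R: Mat, v: Vec) -> Vec:
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(R[i][j] * v[j] for j in range(3)) for i in range(3)]


def matmul(A: Mat, B: Mat) -> Mat:
    """Multiply two 3x3 matrices."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def transpose(R: Mat) -> Mat:
    return [list(row) for row in zip(*R)]


def rot_x(theta: float) -> Mat:
    c, s = math.cos(theta), math.sin(theta)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_z(theta: float) -> Mat:
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def canonicalize_state(R: Mat, t: Vec, locs: list[Vec], vels: list[Vec]):
    """Hypothetical helper: apply the inverse of g = (R, t).

    loc -> R^T (loc - t), vel -> R^T vel.
    """
    Rt = transpose(R)
    canon_locs = [matvec(Rt, [a - b for a, b in zip(x, t)]) for x in locs]
    canon_vels = [matvec(Rt, v) for v in vels]
    return canon_locs, canon_vels


# Group element assumed to be predicted for the original input.
R, t = rot_z(0.3), [1.0, -2.0, 0.5]
locs = [[1.0, 2.0, 3.0], [-1.0, 0.0, 4.0]]
vels = [[0.5, 0.0, -1.0], [0.0, 2.0, 0.0]]

# Move the whole system by another Euclidean element (Q, s).
Q, s = matmul(rot_x(1.1), rot_z(0.7)), [3.0, 0.0, -1.0]
locs2 = [[a + b for a, b in zip(matvec(Q, x), s)] for x in locs]
vels2 = [matvec(Q, v) for v in vels]

# An equivariant network would predict (Q R, Q t + s) for the moved system.
R2 = matmul(Q, R)
t2 = [a + b for a, b in zip(matvec(Q, t), s)]

canon1 = canonicalize_state(R, t, locs, vels)
canon2 = canonicalize_state(R2, t2, locs2, vels2)
```

Algebraically, (QR)^T (Q loc + s - Qt - s) = R^T Q^T Q (loc - t) = R^T (loc - t), so `canon1` and `canon2` agree up to floating-point error even though the inputs differ.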
- forward(x: Tensor, targets: List | None = None, **kwargs: Any) → Tensor | Tuple[Tensor, List]
Forward pass of the continuous group.
- Parameters:
x – Node attributes.
targets – Target data.
**kwargs – Additional keyword arguments. Includes locs, edges, vel, edge_attr, and charges.
- Returns:
The result of the canonicalization.
- get_groupelement(nodes: Tensor, loc: Tensor, edges: Tensor, vel: Tensor, edge_attr: Tensor, charges: Tensor) → Dict[str, Tensor]
Get the group element information.
- Parameters:
nodes – Nodes data.
loc – Location data.
edges – Edges data.
vel – Velocity data.
edge_attr – Edge attributes data.
charges – Charges data.
- Returns:
A dictionary containing the group element information.
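One way such group-element information can be built is sketched below, assuming the canonicalization network emits two raw 3-vectors and a translation vector per system. Gram-Schmidt orthonormalization, completed by a cross product, turns the two raw vectors into a proper rotation matrix; this construction is an assumption here, not a description of equiadapt's code. The helper `get_group_element` and the dictionary keys `"rotation_matrix"` and `"translation_vectors"` are hypothetical and only mirror the shape of the documented return value (a dict of named components).

```python
import math

Vec = list[float]
Mat = list[list[float]]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def normalize(a: Vec) -> Vec:
    n = math.sqrt(dot(a, a))
    return [x / n for x in a]


def cross(a: Vec, b: Vec) -> Vec:
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]


def rotation_from_vectors(v1: Vec, v2: Vec) -> Mat:
    """Gram-Schmidt on two non-parallel raw vectors.

    The cross product completes a right-handed orthonormal frame, so the
    resulting matrix is a rotation (determinant +1).
    """
    e1 = normalize(v1)
    e2 = normalize([b - dot(v2, e1) * a for a, b in zip(e1, v2)])
    e3 = cross(e1, e2)
    # e1, e2, e3 become the columns of the rotation matrix.
    return [[e1[i], e2[i], e3[i]] for i in range(3)]


def get_group_element(raw_v1: Vec, raw_v2: Vec, raw_translation: Vec) -> dict:
    """Hypothetical helper: package the group element as a dict of named parts."""
    return {
        "rotation_matrix": rotation_from_vectors(raw_v1, raw_v2),
        "translation_vectors": list(raw_translation),
    }


g = get_group_element([2.0, 0.0, 0.0], [1.0, 3.0, 0.0], [0.5, -1.0, 2.0])
# g["rotation_matrix"] is the identity here, because the raw vectors already
# point along x and (after removing the x component) along y.
```

In a batched PyTorch version the same steps would operate on tensors of shape (batch, 3); the plain-list form is used only to keep the sketch dependency-free.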