# Source code for equiadapt.nbody.canonicalization_networks.custom_group_equivariant_layers

import torch
import torch.nn as nn

# Small positive constant added to denominators (vector norms / squared
# norms) by the layers below so that zero-length vectors do not cause a
# division by zero.
EPS: float = 1.0e-6


class VNSoftplus(nn.Module):
    """Custom module implementing the VN Softplus layer.

    A soft counterpart of the VN LeakyReLU: a learned direction ``d`` is
    computed per channel, and each input vector ``x`` is blended between
    itself and its projection onto the plane orthogonal to ``d``. The blend
    weight is ``cos^2(theta / 2)``, where ``theta`` is the angle between
    ``x`` and ``d``. It is 1 when they are aligned and 0 when they point in
    opposite directions.
    """

    def __init__(
        self,
        in_channels: int,
        share_nonlinearity: bool = False,
        negative_slope: float = 0.0,
    ) -> None:
        """
        Initialize the VNSoftplus layer.

        Args:
            in_channels (int): Number of input channels.
            share_nonlinearity (bool, optional): Whether to share the
                nonlinearity across channels. When True, a single learned
                direction (one output of ``map_to_dir``) is broadcast to all
                channels. Defaults to False.
            negative_slope (float, optional): Weight of the identity
                (pass-through) term mixed into the output; the soft-masked
                term is weighted by ``1 - negative_slope``. Defaults to 0.0.
        """
        super().__init__()
        if share_nonlinearity:
            self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
        else:
            self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the VNSoftplus layer.

        Channels live on dim 1 (mixed by ``map_to_dir``), and vector
        components live on dim 2 (reduced by the dot products). NOTE(review):
        presumably ``x`` is ``(batch, channels, vector_dim, ...)``; confirm
        with callers.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor with the same shape as ``x``.
        """
        # Learned direction(s): a bias-free linear map over the channel dim.
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)
        cos_angle = dotprod / (
            torch.norm(x, dim=2, keepdim=True) * torch.norm(d, dim=2, keepdim=True)
            + EPS
        )
        # cos^2(theta / 2) == (1 + cos(theta)) / 2. The half-angle identity
        # avoids torch.acos. acos returns NaN when rounding pushes cos_angle
        # just outside [-1, 1] (e.g. (anti)parallel x and d), and its
        # gradient blows up near +/-1. The clamp keeps the mask in [0, 1].
        mask = (1 + cos_angle.clamp(-1.0, 1.0)) / 2
        d_norm_sq = (d * d).sum(2, keepdim=True)
        # x minus its component along d, i.e. the projection of x onto the
        # plane orthogonal to d.
        x_orth = x - (dotprod / (d_norm_sq + EPS)) * d
        x_out = self.negative_slope * x + (1 - self.negative_slope) * (
            mask * x + (1 - mask) * x_orth
        )
        return x_out
class VNLeakyReLU(nn.Module):
    """Custom module implementing the VN LeakyReLU layer.

    A learned direction ``d`` is computed per channel. Input vectors with a
    non-negative dot product against ``d`` pass through unchanged. Vectors
    with a negative dot product are replaced by their projection onto the
    plane orthogonal to ``d``. ``negative_slope`` mixes back a fraction of
    the untouched input.
    """

    def __init__(
        self,
        in_channels: int,
        share_nonlinearity: bool = False,
        negative_slope: float = 0.2,
    ) -> None:
        """
        Initialize the VNLeakyReLU layer.

        Args:
            in_channels (int): Number of input channels.
            share_nonlinearity (bool, optional): Whether to share the
                nonlinearity across channels. When True, a single learned
                direction (one output of ``map_to_dir``) is broadcast to all
                channels. Defaults to False.
            negative_slope (float, optional): Negative slope of the
                LeakyReLU activation, i.e. the weight of the identity term
                mixed into the output. Defaults to 0.2.
        """
        super().__init__()
        if share_nonlinearity:
            self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
        else:
            self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the VNLeakyReLU layer.

        Channels live on dim 1 (mixed by ``map_to_dir``), and vector
        components live on dim 2 (reduced by the dot product). NOTE(review):
        presumably ``x`` is ``(batch, channels, vector_dim, ...)``; confirm
        with callers.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor with the same shape as ``x``.
        """
        # Learned direction(s): a bias-free linear map over the channel dim.
        d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
        dotprod = (x * d).sum(2, keepdim=True)
        # Cast to the input dtype rather than `.float()`. A float32 mask
        # would silently promote half/bfloat16 inputs to a float32 output.
        mask = (dotprod >= 0).to(x.dtype)
        d_norm_sq = (d * d).sum(2, keepdim=True)
        # x minus its component along d, i.e. the projection of x onto the
        # plane orthogonal to d.
        x_orth = x - (dotprod / (d_norm_sq + EPS)) * d
        x_out = self.negative_slope * x + (1 - self.negative_slope) * (
            mask * x + (1 - mask) * x_orth
        )
        return x_out