Source code for equiadapt.images.canonicalization_networks.custom_group_equivariant_layers

import math

import kornia as K
import torch
import torch.nn as nn
import torch.nn.functional as F


class RotationEquivariantConvLift(nn.Module):
    """Lifting convolution from plain images to a discrete rotation group.

    A single learnable kernel bank of shape
    ``(out_channels, in_channels, kernel_size, kernel_size)`` is rotated by
    ``num_rotations`` evenly spaced angles in ``[0, 360)`` degrees.  All rotated
    copies are applied with one ``F.conv2d`` call, and the result is reshaped so
    that the rotation index becomes its own axis.

    The kernel bank is initialised with Kaiming-uniform (``a=sqrt(5)``); the
    optional bias starts at zero and is shared across the rotation axis.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        num_rotations: int = 4,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
        device: str = "cuda",
    ) -> None:
        """Create the layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels (per rotation).
            kernel_size (int): Side length of the square kernel.
            num_rotations (int, optional): Number of rotations in the group.
                Defaults to 4.
            stride (int, optional): Stride passed to ``F.conv2d``. Defaults to 1.
            padding (int, optional): Padding passed to ``F.conv2d``. Defaults to 0.
            bias (bool, optional): Whether to learn a per-output-channel bias.
                Defaults to True.
            device (str, optional): Device the parameters are created on.
                Defaults to "cuda".
        """
        super().__init__()

        # Allocate directly on the target device instead of CPU-then-copy.
        self.weights = nn.Parameter(
            torch.empty(
                out_channels, in_channels, kernel_size, kernel_size, device=device
            )
        )
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, device=device))
        else:
            # Registers the name so ``self.bias`` is ``None`` and is skipped by
            # ``parameters()`` / ``state_dict()``.
            self.register_parameter("bias", None)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.num_rotations = num_rotations
        self.kernel_size = kernel_size

    def get_rotated_weights(
        self, weights: torch.Tensor, num_rotations: int = 4
    ) -> torch.Tensor:
        """Build the bank of rotated kernels.

        All dimensions are taken from the arguments, so the method works for any
        ``num_rotations`` and not only the value the layer was configured with.

        Args:
            weights (torch.Tensor): Kernel bank of shape
                ``(out_channels, in_channels, kernel_h, kernel_w)``.
            num_rotations (int, optional): Number of rotated copies to produce.
                Defaults to 4.

        Returns:
            torch.Tensor: Tensor of shape
            ``(out_channels * num_rotations, in_channels, kernel_h, kernel_w)``.
            Along the first axis the rotation index varies fastest within each
            output channel.
        """
        out_channels, in_channels, kernel_h, kernel_w = weights.shape

        # Evenly spaced angles in [0, 360): drop the duplicate 360-degree end.
        angles = torch.linspace(
            0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32
        )[:num_rotations].to(weights.device)

        # One copy of every (out, in) kernel per angle; the leading axis is the
        # batch axis that ``K.geometry.rotate`` pairs with ``angles``.
        stacked = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1)
        rotated = K.geometry.rotate(stacked, angles)

        # (rot, out, in, kh, kw) -> (out, rot, in, kh, kw) -> (out * rot, in, kh, kw)
        rotated = rotated.reshape(
            num_rotations, out_channels, in_channels, kernel_h, kernel_w
        ).transpose(0, 1)
        return rotated.flatten(0, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the lifting convolution.

        Args:
            x (torch.Tensor): Input of shape
                ``(batch_size, in_channels, height, width)``.

        Returns:
            torch.Tensor: Output of shape
            ``(batch_size, out_channels, num_rotations, out_height, out_width)``,
            where the spatial size is whatever ``F.conv2d`` yields for the
            configured kernel size, stride and padding.
        """
        batch_size = x.shape[0]

        # shape (out_channels * num_rotations, in_channels, kernel_size, kernel_size)
        rotated_weights = self.get_rotated_weights(self.weights, self.num_rotations)

        x = F.conv2d(x, rotated_weights, stride=self.stride, padding=self.padding)
        x = x.reshape(
            batch_size, self.out_channels, self.num_rotations, x.shape[2], x.shape[3]
        )

        if self.bias is not None:
            # Same bias for every rotation and spatial position of a channel.
            x = x + self.bias[None, :, None, None, None]
        return x
class RotoReflectionEquivariantConvLift(nn.Module):
    """Lifting convolution from plain images to a rotation + reflection group.

    A learnable kernel bank of shape
    ``(out_channels, in_channels, kernel_size, kernel_size)`` is rotated by
    ``num_rotations`` evenly spaced angles in ``[0, 360)`` degrees, and every
    rotated copy is additionally flipped horizontally.  This gives
    ``num_group_elements = 2 * num_rotations`` kernels per output channel.  All
    of them are applied with one ``F.conv2d`` call, and the result is reshaped
    so that the group index becomes its own axis.

    The kernel bank is initialised with Kaiming-uniform (``a=sqrt(5)``); the
    optional bias starts at zero and is shared across the group axis.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        num_rotations: int = 4,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
        device: str = "cuda",
    ) -> None:
        """Create the layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels (per group element).
            kernel_size (int): Side length of the square kernel.
            num_rotations (int, optional): Number of rotations in the group.
                Defaults to 4.
            stride (int, optional): Stride passed to ``F.conv2d``. Defaults to 1.
            padding (int, optional): Padding passed to ``F.conv2d``. Defaults to 0.
            bias (bool, optional): Whether to learn a per-output-channel bias.
                Defaults to True.
            device (str, optional): Device the parameters are created on.
                Defaults to "cuda".
        """
        super().__init__()

        # Allocate directly on the target device instead of CPU-then-copy.
        self.weights = nn.Parameter(
            torch.empty(
                out_channels, in_channels, kernel_size, kernel_size, device=device
            )
        )
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, device=device))
        else:
            # Registers the name so ``self.bias`` is ``None`` and is skipped by
            # ``parameters()`` / ``state_dict()``.
            self.register_parameter("bias", None)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.num_rotations = num_rotations
        self.kernel_size = kernel_size
        # Every rotation appears twice: once as is, once mirrored.
        self.num_group_elements = 2 * num_rotations

    def get_rotoreflected_weights(
        self, weights: torch.Tensor, num_rotations: int = 4
    ) -> torch.Tensor:
        """Build the bank of rotated and mirrored kernels.

        All dimensions are taken from the arguments, so the method works for any
        ``num_rotations`` and not only the value the layer was configured with.

        Args:
            weights (torch.Tensor): Kernel bank of shape
                ``(out_channels, in_channels, kernel_h, kernel_w)``.
            num_rotations (int, optional): Number of rotations in the group.
                Defaults to 4.

        Returns:
            torch.Tensor: Tensor of shape
            ``(out_channels * 2 * num_rotations, in_channels, kernel_h, kernel_w)``.
            Within each output channel the ``num_rotations`` rotated kernels come
            first, followed by their horizontally flipped versions.
        """
        out_channels, in_channels, kernel_h, kernel_w = weights.shape
        num_group_elements = 2 * num_rotations

        # Evenly spaced angles in [0, 360): drop the duplicate 360-degree end.
        angles = torch.linspace(
            0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32
        )[:num_rotations].to(weights.device)

        # One copy of every (out, in) kernel per angle; the leading axis is the
        # batch axis that ``K.geometry.rotate`` pairs with ``angles``.
        stacked = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1)
        rotated = K.geometry.rotate(stacked, angles)
        mirrored = K.geometry.hflip(rotated)

        # Group axis layout: [rotations..., mirrored rotations...]
        group_bank = torch.cat([rotated, mirrored], dim=0)

        # (grp, out, in, kh, kw) -> (out, grp, in, kh, kw) -> (out * grp, in, kh, kw)
        group_bank = group_bank.reshape(
            num_group_elements, out_channels, in_channels, kernel_h, kernel_w
        ).transpose(0, 1)
        return group_bank.flatten(0, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the lifting convolution.

        Args:
            x (torch.Tensor): Input of shape
                ``(batch_size, in_channels, height, width)``.

        Returns:
            torch.Tensor: Output of shape
            ``(batch_size, out_channels, num_group_elements, out_height, out_width)``,
            where the spatial size is whatever ``F.conv2d`` yields for the
            configured kernel size, stride and padding.
        """
        batch_size = x.shape[0]

        # shape (out_channels * num_group_elements, in_channels, kernel_size, kernel_size)
        rotoreflected_weights = self.get_rotoreflected_weights(
            self.weights, self.num_rotations
        )

        x = F.conv2d(
            x, rotoreflected_weights, stride=self.stride, padding=self.padding
        )
        x = x.reshape(
            batch_size,
            self.out_channels,
            self.num_group_elements,
            x.shape[2],
            x.shape[3],
        )

        if self.bias is not None:
            # Same bias for every group element and spatial position of a channel.
            x = x + self.bias[None, :, None, None, None]
        return x
class RotationEquivariantConv(nn.Module):
    """Group convolution acting on features that already carry a rotation axis.

    The learnable kernel bank has shape
    ``(out_channels, in_channels, num_rotations, kernel_size, kernel_size)``.
    For each of the ``num_rotations`` output group elements, the bank is
    cyclically shifted along its rotation axis and spatially rotated by the
    matching angle.  All resulting kernels are applied with one ``F.conv2d``
    call.

    The kernel bank is initialised with Kaiming-uniform (``a=sqrt(5)``); the
    optional bias starts at zero and is shared across the rotation axis.

    The cached shift indices and rotation angles are registered as
    non-persistent buffers.  They follow the module through
    ``.to()`` / ``.cuda()`` / ``.cpu()`` and do not appear in ``state_dict()``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        num_rotations: int = 4,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
        device: str = "cuda",
    ) -> None:
        """Create the layer.

        Args:
            in_channels (int): Number of input channels (per rotation).
            out_channels (int): Number of output channels (per rotation).
            kernel_size (int): Side length of the square kernel.
            num_rotations (int, optional): Number of rotations in the group.
                Defaults to 4.
            stride (int, optional): Stride passed to ``F.conv2d``. Defaults to 1.
            padding (int, optional): Padding passed to ``F.conv2d``. Defaults to 0.
            bias (bool, optional): Whether to learn a per-output-channel bias.
                Defaults to True.
            device (str, optional): Device the parameters and buffers are
                created on. Defaults to "cuda".
        """
        super().__init__()

        # Allocate directly on the target device instead of CPU-then-copy.
        self.weights = nn.Parameter(
            torch.empty(
                out_channels,
                in_channels,
                num_rotations,
                kernel_size,
                kernel_size,
                device=device,
            )
        )
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, device=device))
        else:
            # Registers the name so ``self.bias`` is ``None`` and is skipped by
            # ``parameters()`` / ``state_dict()``.
            self.register_parameter("bias", None)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.num_rotations = num_rotations
        self.kernel_size = kernel_size

        # Gather indices along the rotation axis.  Entry [r, c, g, :, :] equals
        # (g - r) mod num_rotations, i.e. a cyclic shift by r.  ``torch.gather``
        # needs an index as large as its output, hence the full-size repeat.
        rotation_ids = torch.arange(num_rotations)
        indices = rotation_ids.view(1, 1, num_rotations, 1, 1).repeat(
            num_rotations, out_channels * in_channels, 1, kernel_size, kernel_size
        )
        shifted = (indices - rotation_ids[:, None, None, None, None]) % num_rotations

        # Evenly spaced angles in [0, 360): drop the duplicate 360-degree end.
        angles = torch.linspace(
            0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32
        )[:num_rotations]

        # Buffers, not plain attributes: otherwise they stay on the construction
        # device when the module is moved, and gather / rotate then fail with a
        # device mismatch.  persistent=False keeps checkpoints unchanged.
        self.register_buffer(
            "permute_indices_along_group", shifted.to(device), persistent=False
        )
        self.register_buffer("angle_list", angles.to(device), persistent=False)

    def get_rotated_permuted_weights(
        self, weights: torch.Tensor, num_rotations: int = 4
    ) -> torch.Tensor:
        """Build the full group-convolution kernel.

        Args:
            weights (torch.Tensor): Kernel bank with the same shape as
                ``self.weights``, i.e.
                ``(out_channels, in_channels, num_rotations, kernel_size, kernel_size)``.
            num_rotations (int, optional): Kept for backward compatibility and
                not used.  The cached shift indices and angles are built for
                ``self.num_rotations``, so that value is always the group order.
                Defaults to 4.

        Returns:
            torch.Tensor: Tensor of shape
            ``(out_channels * num_rotations, in_channels * num_rotations,
            kernel_size, kernel_size)``.
        """
        del num_rotations  # group order is fixed by the cached buffers
        group_order = self.num_rotations

        # (out * in, rot, k, k) replicated once per output rotation.
        stacked = (
            weights.flatten(0, 1).unsqueeze(0).repeat(group_order, 1, 1, 1, 1)
        )
        # Cyclically shift the input-rotation axis by the output rotation index.
        permuted = torch.gather(stacked, 2, self.permute_indices_along_group)

        # Fold (out * in, rot) into one channel axis so that each output rotation
        # is one batch item paired with its angle.
        rotated = K.geometry.rotate(permuted.flatten(1, 2), self.angle_list)

        # (rot_out, out, in, rot_in, k, k) -> (out, rot_out, in, rot_in, k, k)
        rotated = rotated.reshape(
            group_order,
            self.out_channels,
            self.in_channels,
            group_order,
            self.kernel_size,
            self.kernel_size,
        ).transpose(0, 1)
        return rotated.reshape(
            self.out_channels * group_order,
            self.in_channels * group_order,
            self.kernel_size,
            self.kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the group convolution.

        Args:
            x (torch.Tensor): Input of shape
                ``(batch_size, in_channels, num_rotations, height, width)``.

        Returns:
            torch.Tensor: Output of shape
            ``(batch_size, out_channels, num_rotations, out_height, out_width)``,
            where the spatial size is whatever ``F.conv2d`` yields for the
            configured kernel size, stride and padding.
        """
        batch_size = x.shape[0]

        # shape (batch_size, in_channels * num_rotations, height, width)
        x = x.flatten(1, 2)

        # shape (out_channels * num_rotations, in_channels * num_rotations,
        #        kernel_size, kernel_size)
        rotated_permuted_weights = self.get_rotated_permuted_weights(
            self.weights, self.num_rotations
        )

        x = F.conv2d(
            x, rotated_permuted_weights, stride=self.stride, padding=self.padding
        )
        x = x.reshape(
            batch_size, self.out_channels, self.num_rotations, x.shape[2], x.shape[3]
        )

        if self.bias is not None:
            # Same bias for every rotation and spatial position of a channel.
            x = x + self.bias[None, :, None, None, None]
        return x
class RotoReflectionEquivariantConv(nn.Module):
    """Group convolution acting on features that carry a roto-reflection axis.

    The group has ``num_group_elements = 2 * num_rotations`` elements.  The
    learnable kernel bank has shape
    ``(out_channels, in_channels, num_group_elements, kernel_size, kernel_size)``.
    For every output group element, the bank is permuted along its group axis
    and spatially rotated by the matching angle.  The second half of the output
    group elements is additionally flipped horizontally.  All resulting kernels
    are applied with one ``F.conv2d`` call.

    The kernel bank is initialised with Kaiming-uniform (``a=sqrt(5)``); the
    optional bias starts at zero and is shared across the group axis.

    The permutation indices and rotation angles used in the forward pass are
    registered as non-persistent buffers.  They follow the module through
    ``.to()`` / ``.cuda()`` / ``.cpu()`` and do not appear in ``state_dict()``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        num_rotations: int = 4,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
        device: str = "cuda",
    ) -> None:
        """Create the layer.

        Args:
            in_channels (int): Number of input channels (per group element).
            out_channels (int): Number of output channels (per group element).
            kernel_size (int): Side length of the square kernel.
            num_rotations (int, optional): Number of rotations in the group.
                Defaults to 4.
            stride (int, optional): Stride passed to ``F.conv2d``. Defaults to 1.
            padding (int, optional): Padding passed to ``F.conv2d``. Defaults to 0.
            bias (bool, optional): Whether to learn a per-output-channel bias.
                Defaults to True.
            device (str, optional): Device the parameters and buffers are
                created on. Defaults to "cuda".
        """
        super().__init__()
        num_group_elements: int = 2 * num_rotations

        # Allocate directly on the target device instead of CPU-then-copy.
        self.weights = nn.Parameter(
            torch.empty(
                out_channels,
                in_channels,
                num_group_elements,
                kernel_size,
                kernel_size,
                device=device,
            )
        )
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, device=device))
        else:
            # Registers the name so ``self.bias`` is ``None`` and is skipped by
            # ``parameters()`` / ``state_dict()``.
            self.register_parameter("bias", None)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.num_rotations = num_rotations
        self.kernel_size = kernel_size
        self.num_group_elements = num_group_elements

        # Base index grid: entry [r, c, g, :, :] == g for g in [0, num_rotations).
        rotation_ids = torch.arange(num_rotations)
        indices = rotation_ids.view(1, 1, num_rotations, 1, 1).repeat(
            num_rotations, out_channels * in_channels, 1, kernel_size, kernel_size
        )
        shifts = rotation_ids[:, None, None, None, None]

        # Intermediate index tables.  They are not read by the forward pass and
        # stay plain CPU attributes, exactly as before, for backward
        # compatibility.
        self.permute_indices_along_group = (indices - shifts) % num_rotations
        self.permute_indices_along_group_inverse = (indices + shifts) % num_rotations
        # Offsetting by num_rotations addresses the second half of the group axis.
        self.permute_indices_upper_half = torch.cat(
            [
                self.permute_indices_along_group,
                self.permute_indices_along_group_inverse + num_rotations,
            ],
            dim=2,
        )
        self.permute_indices_lower_half = torch.cat(
            [
                self.permute_indices_along_group_inverse + num_rotations,
                self.permute_indices_along_group,
            ],
            dim=2,
        )

        # Full table: rows [0, num_rotations) serve the first half of the output
        # group elements, rows [num_rotations, 2 * num_rotations) the second.
        permute_indices = torch.cat(
            [self.permute_indices_upper_half, self.permute_indices_lower_half], dim=0
        )

        # Evenly spaced angles in [0, 360), once per half of the group.
        rotation_angles = torch.linspace(
            0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32
        )[:num_rotations]

        # Buffers, not plain attributes: otherwise they stay on the construction
        # device when the module is moved, and gather / rotate then fail with a
        # device mismatch.  persistent=False keeps checkpoints unchanged.
        self.register_buffer(
            "permute_indices", permute_indices.to(device), persistent=False
        )
        self.register_buffer(
            "angle_list", rotation_angles.repeat(2).to(device), persistent=False
        )

    def get_rotoreflected_permuted_weights(
        self, weights: torch.Tensor, num_rotations: int = 4
    ) -> torch.Tensor:
        """Build the full group-convolution kernel.

        Args:
            weights (torch.Tensor): Kernel bank with the same shape as
                ``self.weights``, i.e.
                ``(out_channels, in_channels, num_group_elements, kernel_size,
                kernel_size)``.
            num_rotations (int, optional): Kept for backward compatibility and
                not used.  The cached permutation indices and angles are built
                for ``self.num_rotations``, so that value is always used.
                Defaults to 4.

        Returns:
            torch.Tensor: Tensor of shape
            ``(out_channels * num_group_elements, in_channels * num_group_elements,
            kernel_size, kernel_size)``.
        """
        del num_rotations  # group order is fixed by the cached buffers
        group_size = self.num_group_elements

        # shape (num_group_elements, out_channels * in_channels,
        #        num_group_elements, kernel_size, kernel_size)
        stacked = weights.flatten(0, 1).unsqueeze(0).repeat(group_size, 1, 1, 1, 1)
        # Permute the input-group axis separately for every output group element.
        permuted = torch.gather(stacked, 2, self.permute_indices)

        # Fold (out * in, group) into one channel axis so that each output group
        # element is one batch item paired with its angle.
        rotated = K.geometry.rotate(permuted.flatten(1, 2), self.angle_list)

        # First half of the output group elements: rotation only.
        # Second half: rotation followed by a horizontal flip.
        rotoreflected = torch.cat(
            [
                rotated[: self.num_rotations],
                K.geometry.hflip(rotated[self.num_rotations :]),
            ]
        )

        # (grp_out, out, in, grp_in, k, k) -> (out, grp_out, in, grp_in, k, k)
        rotoreflected = rotoreflected.reshape(
            group_size,
            self.out_channels,
            self.in_channels,
            group_size,
            self.kernel_size,
            self.kernel_size,
        ).transpose(0, 1)
        return rotoreflected.reshape(
            self.out_channels * group_size,
            self.in_channels * group_size,
            self.kernel_size,
            self.kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the group convolution.

        Args:
            x (torch.Tensor): Input of shape
                ``(batch_size, in_channels, num_group_elements, height, width)``.

        Returns:
            torch.Tensor: Output of shape
            ``(batch_size, out_channels, num_group_elements, out_height, out_width)``,
            where the spatial size is whatever ``F.conv2d`` yields for the
            configured kernel size, stride and padding.
        """
        batch_size = x.shape[0]

        # shape (batch_size, in_channels * num_group_elements, height, width)
        x = x.flatten(1, 2)

        # shape (out_channels * num_group_elements,
        #        in_channels * num_group_elements, kernel_size, kernel_size)
        rotoreflected_permuted_weights = self.get_rotoreflected_permuted_weights(
            self.weights, self.num_rotations
        )

        x = F.conv2d(
            x,
            rotoreflected_permuted_weights,
            stride=self.stride,
            padding=self.padding,
        )
        x = x.reshape(
            batch_size,
            self.out_channels,
            self.num_group_elements,
            x.shape[2],
            x.shape[3],
        )

        if self.bias is not None:
            # Same bias for every group element and spatial position of a channel.
            x = x + self.bias[None, :, None, None, None]
        return x