# Source code for equiadapt.images.canonicalization_networks.custom_equivariant_networks

from typing import Tuple

import torch
import torch.nn as nn

from .custom_group_equivariant_layers import (
    RotationEquivariantConv,
    RotationEquivariantConvLift,
    RotoReflectionEquivariantConv,
    RotoReflectionEquivariantConvLift,
)


class CustomEquivariantNetwork(nn.Module):
    """
    A small group-equivariant convolutional network.

    The network is equivariant to a specified group, which can be either the
    rotation group or the roto-reflection group (selected by ``group_type``).
    It consists of one lifting convolution followed by ``num_layers - 1``
    blocks of ``ReLU -> group convolution (kernel size 1)``.

    Methods:
        __init__: Initializes the CustomEquivariantNetwork instance.
        forward: Performs a forward pass through the network.
    """

    def __init__(
        self,
        in_shape: Tuple[int, ...],
        out_channels: int,
        kernel_size: int,
        group_type: str = "rotation",
        num_rotations: int = 4,
        num_layers: int = 1,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        Initializes the CustomEquivariantNetwork instance.

        Args:
            in_shape (Tuple[int, ...]): The shape of a single input. Only
                ``in_shape[0]`` is read; it is used as the number of input
                channels (e.g. ``(channels, height, width)``).
            out_channels (int): The number of output channels of every
                equivariant layer.
            kernel_size (int): The kernel size of the first (lifting)
                convolution. All subsequent layers use a kernel size of 1.
            group_type (str, optional): The group the network is equivariant
                to, either ``"rotation"`` or ``"roto-reflection"``.
                Defaults to ``"rotation"``.
            num_rotations (int, optional): The number of rotations in the
                group; passed to every equivariant layer. Defaults to 4.
            num_layers (int, optional): The total number of equivariant
                convolution layers. Values below 1 behave like 1 (only the
                lifting layer is built). Defaults to 1.
            device (str, optional): The device passed to every equivariant
                layer. Defaults to "cuda" if available, otherwise "cpu".

        Raises:
            ValueError: If ``group_type`` is neither ``"rotation"`` nor
                ``"roto-reflection"``.
        """
        super().__init__()

        # Select the (lifting, group) convolution pair for the requested group.
        # Validation happens here, before any layer is constructed.
        if group_type == "rotation":
            lift_cls, conv_cls = RotationEquivariantConvLift, RotationEquivariantConv
        elif group_type == "roto-reflection":
            lift_cls, conv_cls = (  # type: ignore
                RotoReflectionEquivariantConvLift,
                RotoReflectionEquivariantConv,
            )
        else:
            raise ValueError("group_type must be rotation or roto-reflection for now.")

        in_channels = in_shape[0]
        layers: "list[nn.Module]" = [
            lift_cls(in_channels, out_channels, kernel_size, num_rotations, device=device)
        ]
        # Every layer after the lift maps out_channels -> out_channels with a
        # 1x1 kernel, preceded by a ReLU.
        for _ in range(num_layers - 1):
            layers.append(nn.ReLU())
            layers.append(
                conv_cls(out_channels, out_channels, 1, num_rotations, device=device)
            )
        self.eqv_network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs a forward pass through the network.

        Args:
            x (torch.Tensor): The input data. It should have the shape
                (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: The feature map averaged over dimensions 1, 3 and 4.
            It is expected to have the shape (batch_size, group_size).
        """
        feature_map = self.eqv_network(x)
        # Reduce dims 1, 3 and 4, keeping dims 0 and 2. Presumably the layers
        # emit (batch, channels, group, height, width), so one activation per
        # group element remains -- TODO confirm against the layer definitions.
        group_activations = torch.mean(feature_map, dim=(1, 3, 4))
        return group_activations