"""Source code for equiadapt.images.canonicalization_networks.custom_nonequivariant_networks."""

from typing import List

import torch
import torchvision
from torch import nn


class ConvNetwork(nn.Module):
    """
    This class represents a convolutional neural network.

    The network consists of a sequence of stride-2 convolutional layers, each
    followed by batch normalization and a GELU activation function. The number
    of channels doubles at every third layer (layer indices 2, 5, 8, ...). The
    flattened feature map is then passed through a head made of batch
    normalization, dropout, ReLU and a fully connected layer.

    Attributes:
        enc_network (nn.Sequential): The convolutional encoder.
        final_fc (nn.Sequential): The head mapping the flattened encoder
            output to a vector of size ``out_vector_size``.
        out_vector_size (int): The size of the output vector of the network.

    Methods:
        __init__: Initializes the ConvNetwork instance.
        forward: Performs a forward pass through the network.
    """

    def __init__(
        self,
        in_shape: tuple,
        out_channels: int,
        kernel_size: int,
        num_layers: int = 2,
        out_vector_size: int = 128,
    ):
        """
        Initializes the ConvNetwork instance.

        Args:
            in_shape (tuple): The shape of the input data. It should be a tuple
                of the form (in_channels, height, width).
            out_channels (int): The number of output channels of the first
                convolutional layer.
            kernel_size (int): The size of the kernel of the convolutional layers.
            num_layers (int, optional): The number of convolutional layers.
                Defaults to 2.
            out_vector_size (int, optional): The size of the output vector of
                the network. Defaults to 128.
        """
        super().__init__()

        in_channels = in_shape[0]
        layers: List[nn.Module] = []
        for i in range(num_layers):
            if i == 0:
                layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, 2))
            elif i % 3 == 2:
                # Every third layer doubles the channel count (stride 2, padding 1).
                layers.append(
                    nn.Conv2d(out_channels, 2 * out_channels, kernel_size, 2, 1)
                )
                out_channels *= 2
            else:
                layers.append(nn.Conv2d(out_channels, out_channels, kernel_size, 2))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.GELU())

        self.enc_network = nn.Sequential(*layers)

        # Infer the flattened feature size with a dummy forward pass. The probe
        # runs in eval mode and without autograd so that it neither updates the
        # BatchNorm running statistics with an all-zero image nor fails when the
        # final feature map has a single value per channel.
        self.enc_network.eval()
        with torch.no_grad():
            out_shape = self.enc_network(torch.zeros(1, *in_shape)).shape
        self.enc_network.train()
        out_dim = out_shape[1] * out_shape[2] * out_shape[3]

        # The head receives a 2-D (batch_size, out_dim) tensor. nn.Dropout1d
        # would interpret such an input as an unbatched (channels, length)
        # tensor and zero out entire samples, so element-wise nn.Dropout is
        # used. It has no parameters and keeps the same position in the
        # Sequential, so state_dict keys are unaffected.
        self.final_fc = nn.Sequential(
            nn.BatchNorm1d(out_dim),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(out_dim, out_vector_size),
        )
        self.out_vector_size = out_vector_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs a forward pass through the network.

        Args:
            x (torch.Tensor): The input data. It should have the shape
                (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: The output of the network. It has the shape
                (batch_size, out_vector_size).
        """
        out = self.enc_network(x)
        # Collapse (channels, height, width) into one feature axis per sample.
        out = out.flatten(start_dim=1)
        return self.final_fc(out)
class ResNet18Network(nn.Module):
    """
    This class represents a neural network based on the ResNet-18 architecture.

    The network uses a torchvision ResNet-18 model created with
    ``weights="DEFAULT"``, i.e. torchvision's default pre-trained weights. The
    final fully connected layer of the ResNet-18 model is replaced with a new
    fully connected layer that outputs ``out_vector_size`` features.

    Attributes:
        resnet18 (torchvision.models.ResNet): The ResNet-18 model.
        out_vector_size (int): The size of the output vector of the network.
    """

    def __init__(
        self,
        in_shape: tuple,
        out_channels: int,
        kernel_size: int,
        num_layers: int = 2,
        out_vector_size: int = 128,
    ):
        """
        Initializes the ResNet18Network instance.

        ``in_shape``, ``out_channels``, ``kernel_size`` and ``num_layers`` are
        not used by this class; they are accepted so that all networks in this
        module share the same constructor signature.

        Args:
            in_shape (tuple): The shape of the input data, of the form
                (in_channels, height, width). Unused.
            out_channels (int): Unused.
            kernel_size (int): Unused.
            num_layers (int, optional): Unused. Defaults to 2.
            out_vector_size (int, optional): The size of the output vector of
                the network. Defaults to 128.
        """
        super().__init__()
        self.resnet18 = torchvision.models.resnet18(weights="DEFAULT")
        # Size the new head from the layer it replaces instead of hard-coding
        # the backbone's feature width.
        in_features = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Sequential(
            nn.Linear(in_features, out_vector_size),
        )
        self.out_vector_size = out_vector_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs a forward pass through the network.

        Args:
            x (torch.Tensor): The input data. It should have the shape
                (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: The output of the network. It has the shape
                (batch_size, out_vector_size).
        """
        return self.resnet18(x)
class WideResNet101Network(nn.Module):
    """
    This class represents a neural network based on the Wide ResNet-101-2 architecture.

    The network uses a torchvision ``wide_resnet101_2`` model created with
    ``weights="DEFAULT"``, i.e. torchvision's default pre-trained weights. The
    final fully connected layer of the model is replaced with a new fully
    connected layer that outputs ``out_vector_size`` features.

    Attributes:
        wideresnet (torchvision.models.ResNet): The Wide ResNet-101-2 model.
        out_vector_size (int): The size of the output vector of the network.
    """

    def __init__(
        self,
        in_shape: tuple,
        out_channels: int,
        kernel_size: int,
        num_layers: int = 2,
        out_vector_size: int = 128,
    ):
        """
        Initializes the WideResNet101Network instance.

        ``in_shape``, ``out_channels``, ``kernel_size`` and ``num_layers`` are
        not used by this class; they are accepted so that all networks in this
        module share the same constructor signature.

        Args:
            in_shape (tuple): The shape of the input data, of the form
                (in_channels, height, width). Unused.
            out_channels (int): Unused.
            kernel_size (int): Unused.
            num_layers (int, optional): Unused. Defaults to 2.
            out_vector_size (int, optional): The size of the output vector of
                the network. Defaults to 128.
        """
        super().__init__()
        self.wideresnet = torchvision.models.wide_resnet101_2(weights="DEFAULT")
        # Size the new head from the layer it replaces instead of hard-coding
        # the backbone's feature width.
        in_features = self.wideresnet.fc.in_features
        self.wideresnet.fc = nn.Sequential(
            nn.Linear(in_features, out_vector_size),
        )
        self.out_vector_size = out_vector_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs a forward pass through the network.

        Args:
            x (torch.Tensor): The input data. It should have the shape
                (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: The output of the network. It has the shape
                (batch_size, out_vector_size).
        """
        return self.wideresnet(x)
class WideResNet50Network(nn.Module):
    """
    This class represents a neural network based on the Wide ResNet-50-2 architecture.

    The network uses a torchvision ``wide_resnet50_2`` model created with
    ``weights="DEFAULT"``, i.e. torchvision's default pre-trained weights. The
    final fully connected layer of the model is replaced with a new fully
    connected layer that outputs ``out_vector_size`` features.

    Attributes:
        wideresnet (torchvision.models.ResNet): The Wide ResNet-50-2 model.
        out_vector_size (int): The size of the output vector of the network.
    """

    def __init__(
        self,
        in_shape: tuple,
        out_channels: int,
        kernel_size: int,
        num_layers: int = 2,
        out_vector_size: int = 128,
    ):
        """
        Initializes the WideResNet50Network instance.

        ``in_shape``, ``out_channels``, ``kernel_size`` and ``num_layers`` are
        not used by this class; they are accepted so that all networks in this
        module share the same constructor signature.

        Args:
            in_shape (tuple): The shape of the input data, of the form
                (in_channels, height, width). Unused.
            out_channels (int): Unused.
            kernel_size (int): Unused.
            num_layers (int, optional): Unused. Defaults to 2.
            out_vector_size (int, optional): The size of the output vector of
                the network. Defaults to 128.
        """
        super().__init__()
        self.wideresnet = torchvision.models.wide_resnet50_2(weights="DEFAULT")
        # Size the new head from the layer it replaces instead of hard-coding
        # the backbone's feature width.
        in_features = self.wideresnet.fc.in_features
        self.wideresnet.fc = nn.Sequential(
            nn.Linear(in_features, out_vector_size),
        )
        self.out_vector_size = out_vector_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs a forward pass through the network.

        Args:
            x (torch.Tensor): The input data. It should have the shape
                (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: The output of the network. It has the shape
                (batch_size, out_vector_size).
        """
        return self.wideresnet(x)