import math
from typing import Any, Dict, List, Optional, Tuple, Union
import kornia as K
import torch
from omegaconf import DictConfig
from torch.nn import functional as F
from torchvision import transforms
from equiadapt.common.basecanonicalization import DiscreteGroupCanonicalization
from equiadapt.images.utils import (
flip_boxes,
flip_masks,
get_action_on_image_features,
rotate_boxes,
rotate_masks,
)
class DiscreteGroupImageCanonicalization(DiscreteGroupCanonicalization):
    """
    This class represents a discrete group image canonicalization model.

    The model is designed to be equivariant under a discrete group of transformations,
    which can include rotations and reflections.
    Other discrete group canonicalizers can be derived from this class.

    Subclasses are expected to set ``group_type``, ``num_rotations`` and
    ``group_info_dict`` and to implement ``get_group_activations``.

    Methods:
        __init__: Initializes the DiscreteGroupImageCanonicalization instance.
        groupactivations_to_groupelement: Takes the activations for each group element as input and returns the group element.
        get_group_activations: Gets the group activations for the input images (implemented by subclasses).
        get_groupelement: Maps the input image to a group element.
        transformations_before_canonicalization_network_forward: Applies transformations to the input images before passing it through the canonicalization network.
        canonicalize: Canonicalizes the input images.
        invert_canonicalization: Inverts the canonicalization of the output of the canonicalized image.
    """

    def __init__(
        self,
        canonicalization_network: torch.nn.Module,
        canonicalization_hyperparams: DictConfig,
        in_shape: tuple,
    ):
        """
        Initializes the DiscreteGroupImageCanonicalization instance.

        Args:
            canonicalization_network (torch.nn.Module): The canonicalization network.
            canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization
                process. This class reads ``beta``, ``input_crop_ratio`` and ``resize_shape``
                (the latter two only for multi-channel inputs).
            in_shape (tuple): The shape of the input images, as (channels, height, width).
        """
        super().__init__(canonicalization_network)

        self.beta = canonicalization_hyperparams.beta

        assert (
            len(in_shape) == 3
        ), "Input shape should be in the format (channels, height, width)"

        # Define all the image transformations here which are used during canonicalization.
        # Single-channel inputs (e.g. rotated MNIST) are left untouched. Multi-channel inputs
        # are edge-padded before being rotated and center-cropped back to the input size
        # afterwards, so the corners of the rotated image are filled from the edges.
        is_grayscale = in_shape[0] == 1

        self.pad = (
            torch.nn.Identity()
            if is_grayscale
            else transforms.Pad(math.ceil(in_shape[-1] * 0.5), padding_mode="edge")
        )
        self.crop = (
            torch.nn.Identity()
            if is_grayscale
            else transforms.CenterCrop((in_shape[-2], in_shape[-1]))
        )

        # Transformations applied only to the copy of the image fed to the
        # canonicalization network (see transformations_before_canonicalization_network_forward).
        self.crop_canonization = (
            torch.nn.Identity()
            if is_grayscale
            else transforms.CenterCrop(
                (
                    math.ceil(
                        in_shape[-2] * canonicalization_hyperparams.input_crop_ratio
                    ),
                    math.ceil(
                        in_shape[-1] * canonicalization_hyperparams.input_crop_ratio
                    ),
                )
            )
        )
        self.resize_canonization = (
            torch.nn.Identity()
            if is_grayscale
            else transforms.Resize(size=canonicalization_hyperparams.resize_shape)
        )

    def groupactivations_to_groupelement(self, group_activations: torch.Tensor) -> dict:
        """
        This method takes the activations for each group element as input and returns the group element.

        The rotation component is expressed in degrees: the group's rotations are the
        ``num_rotations`` evenly spaced angles in [0, 360). For ``"roto-reflection"`` groups the
        first ``num_rotations`` elements are the pure rotations and the last ``num_rotations``
        are the same rotations combined with a reflection.

        Args:
            group_activations (torch.Tensor): activations for each group element (last dimension).

        Returns:
            dict: group element. It contains ``"rotation"`` (angle in degrees) and, for
            ``"roto-reflection"`` groups, ``"reflection"`` (1 for reflected elements, else 0),
            each reduced over the last dimension of ``group_activations``.
        """
        # convert the group activations to one hot encoding of group element
        # this conversion is differentiable and will be used to select the group element
        group_elements_one_hot = self.groupactivations_to_groupelementonehot(
            group_activations
        )

        # num_rotations evenly spaced angles in [0, 360): drop the duplicate 360 endpoint
        angles = torch.linspace(0.0, 360.0, self.num_rotations + 1)[
            : self.num_rotations
        ].to(self.device)
        group_elements_rot_comp = (
            torch.cat([angles, angles], dim=0)
            if self.group_type == "roto-reflection"
            else angles
        )

        group_element_dict = {}

        # dot product of the one-hot selection with the angle table picks the angle
        group_element_rot_comp = torch.sum(
            group_elements_one_hot * group_elements_rot_comp, dim=-1
        )
        group_element_dict["rotation"] = group_element_rot_comp

        if self.group_type == "roto-reflection":
            # second half of the group elements are the reflected ones
            reflect_identifier_vector = torch.cat(
                [torch.zeros(self.num_rotations), torch.ones(self.num_rotations)], dim=0
            ).to(self.device)
            group_element_reflect_comp = torch.sum(
                group_elements_one_hot * reflect_identifier_vector, dim=-1
            )
            group_element_dict["reflection"] = group_element_reflect_comp

        return group_element_dict

    def get_group_activations(self, x: torch.Tensor) -> torch.Tensor:
        """
        Gets the group activations for the input images.

        Args:
            x (torch.Tensor): The input images.

        Returns:
            torch.Tensor: The group activations.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(
            "get_group_activations is not implemented for "
            "the DiscreteGroupImageCanonicalization class"
        )

    def get_groupelement(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Maps the input image to a group element.

        The group element and the raw group activations are also stored in
        ``self.canonicalization_info_dict`` under ``"group_element"`` and
        ``"group_activations"``.

        Args:
            x (torch.Tensor): The input images.

        Returns:
            dict[str, torch.Tensor]: The corresponding group elements.
        """
        group_activations = self.get_group_activations(x)
        group_element_dict = self.groupactivations_to_groupelement(group_activations)

        # Check whether canonicalization_info_dict is already defined
        # (subclasses may have created it while computing the activations)
        if not hasattr(self, "canonicalization_info_dict"):
            self.canonicalization_info_dict = {}

        self.canonicalization_info_dict["group_element"] = group_element_dict  # type: ignore
        self.canonicalization_info_dict["group_activations"] = group_activations

        return group_element_dict

    def transformations_before_canonicalization_network_forward(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        """
        Applies transformations to the input images before passing them through the canonicalization network.

        The images are center-cropped by ``input_crop_ratio`` and resized to ``resize_shape``;
        both steps are identities for single-channel inputs.

        Args:
            x (torch.Tensor): The input images.

        Returns:
            torch.Tensor: The transformed images to be fed to the canonicalization network.
        """
        x = self.crop_canonization(x)
        x = self.resize_canonization(x)
        return x

    def canonicalize(
        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        """
        Canonicalizes the input images.

        Each image is (optionally) horizontally flipped and then rotated by the negated
        predicted angle. When ``targets`` is given, the ``"boxes"`` and ``"masks"`` of each
        target dict are transformed the same way; the dicts are modified in place.

        Args:
            x (torch.Tensor): The input images.
            targets (Optional[List], optional): The targets for instance segmentation. Defaults to None.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, List]]: The canonicalized image, and the
            targets as well when a non-empty ``targets`` list is passed.
        """
        self.device = x.device
        group_element_dict = self.get_groupelement(x)

        x = self.pad(x)

        if "reflection" in group_element_dict:
            # reflect_indicator is 1 for images whose group element is reflected and 0 otherwise;
            # blending (instead of indexing) lets gradients reach the indicator.
            reflect_indicator = group_element_dict["reflection"][:, None, None, None]
            x = (1 - reflect_indicator) * x + reflect_indicator * K.geometry.hflip(x)

        x = K.geometry.rotate(x, -group_element_dict["rotation"])

        x = self.crop(x)

        if targets:
            # canonicalize the targets (for instance segmentation, masks and boxes)
            image_width = x.shape[-1]
            for t, target in enumerate(targets):
                # Only flip the targets whose image was actually reflected above.
                if (
                    "reflection" in group_element_dict
                    and group_element_dict["reflection"][t].item() > 0.5
                ):
                    target["boxes"] = flip_boxes(target["boxes"], image_width)
                    target["masks"] = flip_masks(target["masks"])

                rotation = group_element_dict["rotation"][t]
                target["boxes"] = rotate_boxes(target["boxes"], rotation, image_width)
                target["masks"] = rotate_masks(
                    target["masks"], -rotation.item()  # type: ignore
                )

            return x, targets

        return x

    def invert_canonicalization(
        self, x_canonicalized_out: torch.Tensor, **kwargs: Any
    ) -> torch.Tensor:
        """
        Inverts the canonicalization of the output of the canonicalized image.

        Uses the group element stored by the most recent call to ``get_groupelement``
        (i.e. ``canonicalize`` must have been called first).

        Args:
            x_canonicalized_out (torch.Tensor): The output of the canonicalized image.
            **kwargs (Any): Additional keyword arguments. ``induced_rep_type`` selects the
                representation passed to ``get_action_on_image_features`` (default ``"regular"``).

        Returns:
            torch.Tensor: The output corresponding to the original image.
        """
        induced_rep_type = kwargs.get("induced_rep_type", "regular")
        return get_action_on_image_features(
            feature_map=x_canonicalized_out,
            group_info_dict=self.group_info_dict,
            group_element_dict=self.canonicalization_info_dict["group_element"],  # type: ignore
            induced_rep_type=induced_rep_type,
        )
class GroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization):
    """
    This class represents a discrete group equivariant image canonicalization model.

    The model is designed to be equivariant under a discrete group of transformations,
    which can include rotations and reflections. The group (``group_type`` and
    ``num_rotations``) is taken from the canonicalization network, whose output is used
    directly as the group activations.

    Methods:
        __init__: Initializes the GroupEquivariantImageCanonicalization instance.
        get_group_activations: Gets the group activations for the input images.
    """

    def __init__(
        self,
        canonicalization_network: torch.nn.Module,
        canonicalization_hyperparams: DictConfig,
        in_shape: tuple,
    ):
        """
        Initializes the GroupEquivariantImageCanonicalization instance.

        Args:
            canonicalization_network (torch.nn.Module): The canonicalization network. It must
                expose ``group_type`` and ``num_rotations`` attributes.
            canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization process.
            in_shape (tuple): The shape of the input images.

        Raises:
            ValueError: If the network's ``group_type`` is neither ``"rotation"`` nor
                ``"roto-reflection"``.
        """
        super().__init__(
            canonicalization_network, canonicalization_hyperparams, in_shape
        )
        self.group_type = canonicalization_network.group_type
        self.num_rotations = canonicalization_network.num_rotations

        # Fail early: any other value would otherwise be sized as a roto-reflection group
        # here but treated as a pure rotation group when decoding group elements.
        if self.group_type not in ("rotation", "roto-reflection"):
            raise ValueError(
                f"Unsupported group_type {self.group_type!r}; "
                "expected 'rotation' or 'roto-reflection'"
            )

        # a roto-reflection group contains every rotation with and without a reflection
        self.num_group = (
            self.num_rotations
            if self.group_type == "rotation"
            else 2 * self.num_rotations
        )
        self.group_info_dict = {
            "num_rotations": self.num_rotations,
            "num_group": self.num_group,
        }

    def get_group_activations(self, x: torch.Tensor) -> torch.Tensor:
        """
        Gets the group activations for the input image.

        This method takes an image as input, applies transformations before forwarding it through the canonicalization network,
        and then returns the group activations.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The group activations (the raw output of the canonicalization network).
        """
        x = self.transformations_before_canonicalization_network_forward(x)
        group_activations = self.canonicalization_network(x)
        return group_activations
class OptimizedGroupEquivariantImageCanonicalization(
    DiscreteGroupImageCanonicalization
):
    """
    This class represents an optimized (discrete) group equivariant image canonicalization model.

    The model is designed to be equivariant under a discrete group of transformations,
    which can include rotations and reflections. The canonicalization network is applied to
    every group-transformed copy of the input, and each copy is scored by the cosine
    similarity between the network output and a reference vector.

    Methods:
        __init__: Initializes the OptimizedGroupEquivariantImageCanonicalization instance.
        rotate_and_maybe_reflect: Rotate and maybe reflect the input images.
        group_augment: Augment the input images by applying group transformations (rotations and reflections).
        get_group_activations: Gets the group activations for the input images.
        get_optimization_specific_loss: Gets the loss specific to the optimization process.
    """

    def __init__(
        self,
        canonicalization_network: torch.nn.Module,
        canonicalization_hyperparams: DictConfig,
        in_shape: tuple,
    ):
        """
        Initializes the OptimizedGroupEquivariantImageCanonicalization instance.

        Args:
            canonicalization_network (torch.nn.Module): The canonicalization network. It must
                expose an ``out_vector_size`` attribute.
            canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization
                process. This class reads ``group_type``, ``num_rotations``, ``artifact_err_wt``,
                ``resize_shape`` (an int or a (height, width) pair) and ``learn_ref_vec``.
            in_shape (tuple): The shape of the input images.

        Raises:
            ValueError: If ``group_type`` is neither ``"rotation"`` nor ``"roto-reflection"``.
        """
        super().__init__(
            canonicalization_network, canonicalization_hyperparams, in_shape
        )
        self.group_type = canonicalization_hyperparams.group_type
        self.num_rotations = canonicalization_hyperparams.num_rotations
        self.artifact_err_wt = canonicalization_hyperparams.artifact_err_wt

        # Fail early: any other value would be sized as a roto-reflection group here but
        # group_augment() would only produce the rotated copies.
        if self.group_type not in ("rotation", "roto-reflection"):
            raise ValueError(
                f"Unsupported group_type {self.group_type!r}; "
                "expected 'rotation' or 'roto-reflection'"
            )

        self.num_group = (
            self.num_rotations
            if self.group_type == "rotation"
            else 2 * self.num_rotations
        )
        self.out_vector_size = canonicalization_network.out_vector_size

        # group optimization specific cropping and padding (required for group_augment())
        group_augment_in_shape = canonicalization_hyperparams.resize_shape
        if in_shape[0] == 1:
            self.crop_group_augment = torch.nn.Identity()
            self.pad_group_augment = torch.nn.Identity()
        else:
            self.crop_group_augment = transforms.CenterCrop(group_augment_in_shape)
            self.pad_group_augment = transforms.Pad(
                self._group_augment_pad_size(group_augment_in_shape),
                padding_mode="edge",
            )

        # Network outputs are scored against this vector; it only receives gradients
        # when learn_ref_vec is true.
        self.reference_vector = torch.nn.Parameter(
            torch.randn(1, self.out_vector_size),
            requires_grad=canonicalization_hyperparams.learn_ref_vec,
        )
        self.group_info_dict = {
            "num_rotations": self.num_rotations,
            "num_group": self.num_group,
        }

    @staticmethod
    def _group_augment_pad_size(size: Any) -> int:
        """
        Returns the per-side padding used before rotating images in ``group_augment``.

        Args:
            size (Any): ``resize_shape``, either a single number or an (height, width) sequence.

        Returns:
            int: Half of the (largest) side, rounded up.
        """
        if isinstance(size, (int, float)):
            return math.ceil(size * 0.5)
        # (height, width): pad every side by half of the largest side
        return math.ceil(max(size) * 0.5)

    def rotate_and_maybe_reflect(
        self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False
    ) -> List[torch.Tensor]:
        """
        Rotate and maybe reflect the input images.

        For every angle, the images are padded, rotated by the negated angle, optionally
        horizontally flipped, and center-cropped.

        Args:
            x (torch.Tensor): The input image.
            degrees (torch.Tensor): The degrees of rotation.
            reflect (bool, optional): Whether to reflect the image. Defaults to False.

        Returns:
            List[torch.Tensor]: The list of rotated and maybe reflected images, one entry per angle.
        """
        x_augmented_list = []
        for degree in degrees:
            x_rot = self.pad_group_augment(x)
            x_rot = K.geometry.rotate(x_rot, -degree)
            if reflect:
                x_rot = K.geometry.hflip(x_rot)
            x_rot = self.crop_group_augment(x_rot)
            x_augmented_list.append(x_rot)
        return x_augmented_list

    def group_augment(self, x: torch.Tensor) -> torch.Tensor:
        """
        Augment the input images by applying group transformations (rotations and reflections).

        The copies are concatenated along the batch dimension in group order: all rotations
        first, followed (for ``"roto-reflection"``) by all reflected rotations.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The augmented images, of batch size ``num_group * x.shape[0]``.
        """
        degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device)
        x_augmented_list = self.rotate_and_maybe_reflect(x, degrees)
        if self.group_type == "roto-reflection":
            x_augmented_list += self.rotate_and_maybe_reflect(x, degrees, reflect=True)
        return torch.cat(x_augmented_list, dim=0)

    def get_group_activations(self, x: torch.Tensor) -> torch.Tensor:
        """
        Gets the group activations for the input image.

        The network outputs are stored in ``self.canonicalization_info_dict`` under
        ``"vector_out"`` (and ``"vector_out_dummy"`` when ``artifact_err_wt`` is truthy)
        for use by ``get_optimization_specific_loss``.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The group activations, of shape (batch_size, num_group).
        """
        x = self.transformations_before_canonicalization_network_forward(x)
        x_augmented = self.group_augment(
            x
        )  # size (group_size * batch_size, in_channels, height, width), group-major
        vector_out = self.canonicalization_network(
            x_augmented
        )  # size (group_size * batch_size, reference_vector_size)
        self.canonicalization_info_dict = {"vector_out": vector_out}

        if self.artifact_err_wt:
            # select a random rotation for each image in the batch
            rotation_indices = torch.randint(
                0, self.num_rotations, (x_augmented.shape[0],)
            ).to(self.device)

            # apply the rotation degree to the images
            x_dummy = self.pad_group_augment(x_augmented)
            x_dummy = K.geometry.rotate(
                x_dummy, -rotation_indices * 360 / self.num_rotations
            )
            x_dummy = self.crop_group_augment(x_dummy)

            # invert the image back to the original orientation
            x_dummy = self.pad_group_augment(x_dummy)
            x_dummy = K.geometry.rotate(
                x_dummy, rotation_indices * 360 / self.num_rotations
            )
            x_dummy = self.crop_group_augment(x_dummy)

            # Same images, but passed through a rotate/unrotate round trip; the difference
            # in network output is penalized in get_optimization_specific_loss.
            vector_out_dummy = self.canonicalization_network(
                x_dummy
            )  # size (group_size * batch_size, reference_vector_size)
            self.canonicalization_info_dict.update(
                {"vector_out_dummy": vector_out_dummy}
            )

        scalar_out = F.cosine_similarity(
            self.reference_vector.repeat(vector_out.shape[0], 1), vector_out
        )  # size (group_size * batch_size,)
        group_activations = scalar_out.reshape(
            self.num_group, -1
        ).T  # size (batch_size, group_size)
        return group_activations

    def get_optimization_specific_loss(self) -> torch.Tensor:
        """
        Gets the loss specific to the optimization process.

        The loss is the mean absolute off-diagonal dot product between the output vectors
        of the group-transformed copies of each image. When ``artifact_err_wt`` is truthy,
        ``artifact_err_wt`` times the MSE between the outputs for the original and the
        rotate/unrotate round-tripped copies is added.

        Returns:
            torch.Tensor: The loss.
        """
        vectors = self.canonicalization_info_dict["vector_out"]

        # error to ensure that the vectors are (as much as possible) orthogonal
        grouped_vectors = vectors.reshape(
            self.num_group, -1, self.out_vector_size
        ).permute(
            (1, 0, 2)
        )  # (batch_size, group_size, vector_out_size)
        distances = grouped_vectors @ grouped_vectors.permute((0, 2, 1))
        mask = 1.0 - torch.eye(self.num_group).to(
            self.device
        )  # (group_size, group_size); zero on the diagonal
        loss = torch.abs(distances * mask).mean()

        # compute error to reduce rotation artifacts (only when enabled, so that a
        # None/0 weight never participates in arithmetic)
        if self.artifact_err_wt:
            vectors_dummy = self.canonicalization_info_dict["vector_out_dummy"]
            rotation_artifact_error = torch.nn.functional.mse_loss(
                vectors_dummy, vectors
            )
            loss = loss + self.artifact_err_wt * rotation_artifact_error

        return loss