Source code for gperc.configs

r"""
Configs
=======

``PerceiverConfig`` is the final config object that is fed to the model, but it requires knowing
exactly what you need to know about the data and the architecture. For this very purpose, there are
some simpler configs that are more convenient to use in some cases. They are:

* ``TextConfig``: A config that is used for text classification tasks.
* ``ImageConfig``: A config that is used for image tasks, supports ``classification`` and ``segmentation``.

Documentation
-------------
"""

import json
from pprint import pformat
from typing import Callable, Tuple


[docs]class PerceiverConfig: def __init__( self, input_len: int = 64, input_dim: int = 8, latent_len: int = 4, latent_dim: int = 16, output_len: int = 1, output_dim: int = 10, ffw_latent: int = 32, pos_init_std: float = 0.02, num_heads: int = 2, num_layers: int = 2, dropout: float = 0.1, decoder_cross_attention: bool = False, decoder_residual: bool = False, decoder_projection: bool = True, output_pos_enc: bool = False, seed: int = 4, pre_processing: Callable = None, post_processing: Callable = None, **kwargs ): r"""Since perciever is such a powerful and versatile model, we need a good config for this. Different application we will simply define different configurations and wrap them in some model registry-kinda thing. There are many attributes in the config file and the user must understand what they are doing. I highly recommend reading `examples <stories.1.html>`__ before you start working with this. Args: input_len (int, optional): (``m``) The length of the input space input_dim (int, optional): (``c``) The dimension of the input space latent_len (int, optional): (``n``) The length of the latent space latent_dim (int, optional): (``d``) The dimension of the latent space output_len (int, optional): (``o``) The length of the output space output_dim (int, optional): (``e``) The dimension of the output space ffw_latent (int, optional): The dimension of the latent space in the feed-forward pos_init_std (float, optional): The standard deviation of the position encoding num_heads (int, optional): The number of heads in the multi-head attention num_layers (int, optional): The number of layers in the encoder and decoder dropout (float, optional): The dropout rate decoder_cross_attention (bool, optional): Whether to use cross attention in the decoder. If true the output shape will be ``[batch_size, o, e]`` otherwise it will take ``mean`` over the input ``latent_array`` and return ``[batch_size, e]``. decoder_residual (bool, optional): Whether ``output_array`` combines with ``latent_array`` decoder_projection (bool, optional): Whether to use a projection layer in the decoder, used for classification output_pos_enc (bool, optional): Whether to use position encoding in the decoder seed (int, optional): The seed for the random number generator pre_processing (Callable, optional): A function that takes processes the ``input_array`` tensor post_processing (Callable, optional): A function that takes processes the ``output_array`` tensor **kwargs: Any other arguments to be stored in the config """ self.input_len = input_len self.input_dim = input_dim self.latent_len = latent_len self.latent_dim = latent_dim self.output_len = output_len self.output_dim = output_dim self.ffw_latent = ffw_latent self.pos_init_std = pos_init_std self.num_heads = num_heads self.num_layers = num_layers self.dropout = dropout self.decoder_cross_attention = decoder_cross_attention self.decoder_residual = decoder_residual self.decoder_projection = decoder_projection self.output_pos_enc = output_pos_enc self.seed = seed self.pre_processing = pre_processing self.post_processing = post_processing for k, v in kwargs.items(): setattr(self, k, v) def __repr__(self) -> str: return pformat(self.__dict__, indent=2, sort_dicts=True)
[docs] def to_json(self, path): with open(path, "w") as f: json.dump(self.__dict__, f, indent=2, sort_keys=True)
[docs] def from_json(self, path): with open(path, "r") as f: self.__dict__ = json.load(f)
[docs]class TextConfig(PerceiverConfig): def __init__(self, latent_dim: int, vocab_size: int, max_len: int, latent_frac: float, **kwargs): r"""Config class to specially deal with the text modality cases Args: latent_dim (int): The dimension of the latent space vocab_size (int): The size of the vocabulary max_len (int): The maximum length of the input sequence latent_frac (float): ``latent_len`` will be this multiplied by ``max_len`` """ super().__init__(**kwargs) self.vocab_size = vocab_size self.max_len = max_len self.input_len = max_len self.input_dim = latent_dim self.latent_len = int(latent_frac * max_len) self.latent_dim = latent_dim self.output_len = max_len self.output_dim = latent_dim self.decoder_cross_attention = True self.decoder_residual = True self.decoder_projection = True self.n_classes = vocab_size
[docs]class ImageConfig(PerceiverConfig): def __init__(self, image_shape: Tuple, latent_len: int, latent_dim: int, n_classes: int, task: str = "classification", **kwargs): r"""Config class to specially deal with the image modality cases Args: image_shape (Tuple): The shape of the image in [H, W, C] latent_len (int): The length of the latent space latent_dim (int): The dimension of the latent space n_classes (int): The number of classes after the output space task (str, optional): The task to be performed, can be one of ``classification``, and ``segmentation`` """ assert task in ["classification", "segmentation"], "task must be one of 'classification' or 'segmentation'" super().__init__(**kwargs) self.image_shape = image_shape self.input_len = image_shape[0] * image_shape[1] # image if flattened to a fix shape self.input_dim = image_shape[2] self.latent_len = latent_len self.latent_dim = latent_dim self.output_len = 1 self.output_dim = latent_dim self.n_classes = n_classes self.task = task if task == "classification": """When performing a classification task, we do not need to query from the output_array meaning that there is no need for cross_attention or residual connection, but there needs to be a projection layer to the number of classes.""" self.decoder_cross_attention = False self.decoder_residual = False self.decoder_projection = True elif task == "segmentation": """When performing segmentation task, the output_array will query the latent but we should not use the residual connection, and we should use a projection layer to the number of classes. Avoiding residual connection is recommended in the paper.""" self.decoder_cross_attention = True self.decoder_residual = False self.decoder_projection = True self.output_len = image_shape[0] * image_shape[1]