r"""
Perceiver Model
===============
This file contains the neural-network code for the Perceiver architecture. ``gperc.models.Perceiver`` sits at
the heart of this project.
"""
import torch
from torch import nn
def build_position_encoding(position_encoding_type, config, num_index_items, emb_dim):
    r"""Get the positional encoding matrix. If ``position_encoding_type == "trainable"`` then a randomly
    initialised (normal) matrix is returned; if it is ``"sinusoid"`` then a fixed sinusoidal matrix is returned.
Args:
position_encoding_type (str): type of embedding, should be one of "trainable", "sinusoid"
config: ``gperc.PerceiverConfig``
num_index_items (int): number of items in the embedding, eg. ``vocab_size``
emb_dim (int): embedding dimension
Returns:
``torch.nn.Parameter``: Item that can be used as a parameter in a ``torch.nn.Embedding``
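    Example (a minimal usage sketch; ``SimpleNamespace`` below only stands in for a
    ``gperc.PerceiverConfig`` that carries a ``pos_init_std`` attribute):

    .. code-block:: python

        from types import SimpleNamespace

        cfg = SimpleNamespace(pos_init_std=0.02)
        pos = build_position_encoding("trainable", cfg, num_index_items=128, emb_dim=32)
        print(pos.shape)  # torch.Size([128, 32])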
"""
if position_encoding_type == "trainable":
# first define the positional encodings
latent_pos_emb = nn.Parameter(torch.normal(mean=0.0, std=config.pos_init_std, size=(num_index_items, emb_dim)))
return latent_pos_emb
    elif position_encoding_type == "sinusoid":
        # fixed sinusoidal positional encodings: sin on even indices, cos on odd indices
        position = torch.arange(num_index_items, dtype=torch.float32).unsqueeze(1)  # [num_index_items, 1]
        div_term = 10000.0 ** (-torch.arange(0, emb_dim, 2, dtype=torch.float32) / emb_dim)  # [ceil(emb_dim / 2)]
        pos_emb = torch.zeros(num_index_items, emb_dim)
        pos_emb[:, 0::2] = torch.sin(position * div_term)
        pos_emb[:, 1::2] = torch.cos(position * div_term[: emb_dim // 2])
        return nn.Parameter(pos_emb, requires_grad=False)
    else:
        raise ValueError(f"Unknown position_encoding_type: {position_encoding_type}")
class Block(nn.Module):
def __init__(self, kv_dim, q_dim, num_heads, ffw_dim, dropout=0.0, add_residual=False):
r"""Generic block with Attention and MLP layers
Args:
kv_dim (int): dimension of the key-value embeddings
q_dim (int): dimension of the query embeddings
num_heads (int): number of heads in the multihead attention
ffw_dim (int): dimension of the feed-forward layer
dropout (float, optional): dropout rate
add_residual (bool, optional): whether to add residual to the query
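        Example (a minimal shape sketch, with illustrative dimensions):

        .. code-block:: python

            block = Block(kv_dim=64, q_dim=32, num_heads=4, ffw_dim=128)
            kv = torch.randn(2, 100, 64)  # [batch, kv_len, kv_dim]
            q = torch.randn(2, 16, 32)    # [batch, q_len, q_dim]
            out, attn = block(kv, q)
            # out.shape  == torch.Size([2, 16, 32])    (query shape is preserved)
            # attn.shape == torch.Size([2, 4, 16, 100])  (one attention map per head)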
"""
super().__init__()
assert q_dim % num_heads == 0, "Latent Dimension must be divisible by number of heads"
self.kv_dim = kv_dim
self.q_dim = q_dim
self.dim = q_dim
self.num_heads = num_heads
self.ffw_dim = ffw_dim
self.add_residual = add_residual
# layer norm the inputs
self.lnkv = nn.LayerNorm(kv_dim)
self.lnq = nn.LayerNorm(q_dim)
# items for attention
self.fv = nn.Linear(kv_dim, q_dim)
self.fk = nn.Linear(kv_dim, q_dim)
self.fq = nn.Linear(q_dim, q_dim)
self.drop_att = nn.Dropout(dropout)
self.fo = nn.Linear(q_dim, q_dim)
# items for mlp
self.lnqkv = nn.LayerNorm(q_dim)
self.mlp = nn.Sequential(
nn.Linear(q_dim, ffw_dim),
nn.GELU(),
nn.Linear(ffw_dim, q_dim),
)
self.drop_mlp = nn.Dropout(dropout)
    def forward(self, kv, q):
        r"""Forward pass of the block: takes in a key-value tensor and a query tensor and applies
        the attention and MLP layers. Since it consumes ``kv`` and ``q`` separately, the block is responsible
        for cross-attention-like behaviour.
Args:
kv (torch.Tensor): tensor to extract information from
q (torch.Tensor): tensor for querying the information
Returns:
Tuple[torch.Tensor, torch.Tensor]: tuple of output Tensor and Attention matrix
"""
        # first layer norm the inputs
        _q = self.lnq(q)
        _kv = self.lnkv(kv)

        # compute the query, key, value projections
        Q, K, V = self.fq(_q), self.fk(_kv), self.fv(_kv)

        # split the heads: [b, n, d] -> [b, h, n, d/h]
        Q, K, V = tuple(map(lambda x: x.view(x.shape[0], -1, self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2), (Q, K, V)))

        # scaled dot-product attention: [b, h, n, e/h] @ [b, h, e/h, m] -> [b, h, n, m]
        A = Q @ K.permute(0, 1, 3, 2) * (self.dim ** -0.5)
        A = self.drop_att(A.softmax(dim=-1))

        # weight the values and merge the heads back: [b, h, n, m] @ [b, h, m, e/h] -> [b, n, e]
        out = (A @ V).transpose(1, 2).reshape((q.shape[0], -1, self.q_dim))
        out = self.fo(out)
# Optionally include a residual to the query.
# Consider omitting the residual if the semantics of query and output
# are different, e.g. if queries are positions and outputs are pixels.
if self.add_residual:
out = out + q
# print(">>>>>>>>", out.shape)
# now we will pass it through the mlp
out = self.mlp(self.lnqkv(out)) + out
out = self.drop_mlp(out)
return out, A
class Encoder(nn.Module):
def __init__(self, config):
r"""Generic Encoder Block of the model which takes in ``input_array`` as key-value and \
``latent_array`` as query.
Args:
config: Config
"""
super().__init__()
self.encoder_block = Block(
kv_dim=config.input_dim,
q_dim=config.latent_dim,
num_heads=config.num_heads,
ffw_dim=config.ffw_latent,
dropout=config.dropout,
add_residual=True,
)
    def forward(self, input_array, latent_query):
r"""Performs the forward pass of the encoder block.
Args:
input_array (torch.Tensor): Input array to the Encoder block, read paper for reference
latent_query (torch.Tensor): Latent query to the Encoder block, read paper for reference
Returns:
Tuple[torch.Tensor, List[torch.Tensor]]: The output of the encoder block and the attention matrices
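        Example (a minimal shape sketch; ``SimpleNamespace`` stands in for the real config and the
        dimensions are illustrative only):

        .. code-block:: python

            from types import SimpleNamespace

            cfg = SimpleNamespace(input_dim=32, latent_dim=16, num_heads=4, ffw_latent=64, dropout=0.0)
            enc = Encoder(cfg)
            out, attns = enc(torch.randn(2, 500, 32), torch.randn(2, 8, 16))
            # out.shape == torch.Size([2, 8, 16]): the output keeps the (much shorter) latent length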
"""
out, A = self.encoder_block(input_array, latent_query)
return out, [A]
class Processor(nn.Module):
def __init__(self, config):
r"""Generic Processor Block of the model which takes in ``latent_array`` as key-value-query
Args:
config: Config
"""
super().__init__()
self.processors = nn.ModuleList(
[
Block(
kv_dim=config.latent_dim,
q_dim=config.latent_dim,
num_heads=config.num_heads,
ffw_dim=config.ffw_latent,
dropout=config.dropout,
add_residual=True,
)
for _ in range(config.num_layers)
]
)
    def forward(self, x):
r"""Performs the forward pass of the processor block.
Args:
x (torch.Tensor): Input array to the Processor block, this should always be the latent_array
Returns:
Tuple[torch.Tensor, List[torch.Tensor]]: The output of the processor block and the attention matrices
"""
attentions = []
for i, processor in enumerate(self.processors):
x, A = processor(x, x)
attentions.append(A)
return x, attentions
class Decoder(nn.Module):
def __init__(self, config):
r"""Generic Decoder Block of the model which takes in ``latent_array`` as key-value and \
``output_array`` as query.
Args:
config: Config
"""
super().__init__()
def __check_conditionals():
assert config.decoder_cross_attention or config.decoder_projection, "Must have either cross attention or projection"
if config.decoder_projection:
assert hasattr(config, "n_classes") and config.n_classes, "Must have n_classes > 0 if using projection"
__check_conditionals()
self.config = config
if config.decoder_cross_attention:
self.decoder_block = Block(
kv_dim=config.latent_dim,
q_dim=config.output_dim,
num_heads=config.num_heads,
ffw_dim=config.ffw_latent,
dropout=config.dropout,
add_residual=config.decoder_residual,
)
if config.decoder_projection:
self.projection = nn.Linear(config.latent_dim, config.n_classes)
    def forward(self, latents, decoder_query):
r"""Performs the forward pass of the decoder block.
Args:
latents (torch.Tensor): Latent array to the Decoder block, read paper for reference
decoder_query (torch.Tensor): Output array to the Decoder block, read paper for reference
Returns:
Tuple[torch.Tensor, List[torch.Tensor]]: The output of the decoder block and the attention matrices
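        Example (a minimal sketch of the projection-only path; ``SimpleNamespace`` stands in for the real config):

        .. code-block:: python

            from types import SimpleNamespace

            cfg = SimpleNamespace(latent_dim=16, decoder_cross_attention=False, decoder_projection=True, n_classes=10)
            dec = Decoder(cfg)
            logits, attns = dec(torch.randn(2, 8, 16), decoder_query=None)
            # latents are mean-pooled over the latent positions and projected: logits.shape == torch.Size([2, 10])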
"""
attentions = []
if self.config.decoder_cross_attention:
x, A = self.decoder_block(latents, decoder_query)
attentions.append(A)
else:
x = latents.mean(dim=1)
if hasattr(self, "projection"):
x = self.projection(x)
return x, attentions
class Perceiver(nn.Module):
def __init__(self, config, input_preprocessing=None, output_postprocessing=None):
r"""Unassuming Perceiver Architecture that sits at the heart of this project.
Args:
config: Config
input_preprocessing (Callable, optional): callable object that takes in ``input_array`` and performs \
operation on it.
output_postprocessing (Callable, optional): callable object that takes in ``output_array`` and performs \
operation on it.
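        Example (a minimal end-to-end sketch; ``SimpleNamespace`` stands in for ``gperc.PerceiverConfig``
        and every attribute value below is illustrative only):

        .. code-block:: python

            from types import SimpleNamespace

            cfg = SimpleNamespace(
                input_dim=32, latent_dim=16, latent_len=8, output_dim=16, output_len=1,
                num_heads=4, ffw_latent=64, num_layers=2, dropout=0.1, pos_init_std=0.02,
                decoder_cross_attention=True, decoder_residual=False,
                decoder_projection=True, n_classes=10,
            )
            model = Perceiver(cfg)
            x = torch.randn(2, 64, 32)  # [batch, input_len, input_dim]
            logits = model(x)           # [batch, output_len, n_classes] -> torch.Size([2, 1, 10])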
"""
super().__init__()
self.config = config
self.input_preprocessing = input_preprocessing
self.output_postprocessing = output_postprocessing
self.pos_emb_latent = build_position_encoding("trainable", config, config.latent_len, config.latent_dim)
self.pos_emb_decoder = build_position_encoding("trainable", config, config.output_len, config.output_dim)
self.encoder = Encoder(config)
self.processor = Processor(config)
self.decoder = Decoder(config)
    def num_parameters(self, include_non_trainable: bool = True):
        r"""Returns the number of parameters in the model
Args:
include_non_trainable (bool, optional): If true includes tensors that have ``requires_grad=False`` as well
Returns:
int: number of parameters in the model
"""
if include_non_trainable:
return sum(p.numel() for p in self.parameters())
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad)
    def forward(self, input_array: torch.Tensor, output_array: torch.Tensor = None, return_attentions: bool = False):
r"""Performs the forward pass of the Perceiver.
Args:
input_array (torch.Tensor): Input array to the Perceiver, read paper for reference
output_array (torch.Tensor, optional): Output array to the Perceiver, read paper for reference
return_attentions (bool, optional): If true returns the attention matrices
Returns:
Tuple[torch.Tensor, List[torch.Tensor]] if ``return_attentions`` is True else torch.Tensor: \
The output of the Perceiver and the attention matrices
"""
if self.input_preprocessing:
input_array = self.input_preprocessing(input_array)
assert len(input_array.shape) == 3, "Input array must be of shape (batch_size, input_len, input_dim)"
# enc -> proc -> decode
latent_array = torch.cat([self.pos_emb_latent[None, ...] for _ in range(input_array.shape[0])], dim=0)
latents, enc_att = self.encoder(input_array, latent_array)
latents, proc_att = self.processor(latents)
if output_array is None:
decoder_query = torch.cat([self.pos_emb_decoder[None, ...] for _ in range(latents.shape[0])], dim=0)
else:
decoder_query = output_array + self.pos_emb_decoder[None, ...] # add the positional embedding to output array
out, dec_att = self.decoder(latents, decoder_query)
if self.output_postprocessing:
out = self.output_postprocessing(out)
if return_attentions:
return out, [*enc_att, *proc_att, *dec_att]
else:
return out
# ====== use case specific models ====== #
class PerceiverMLM(torch.nn.Module):
    r"""Perceiver wrapped for masked-language-modelling style inputs: token ids are embedded, a trainable
    positional embedding is added, and the same embedded array is used as both the input and the decoder query."""

    def __init__(self, config):
super().__init__()
self.emb = torch.nn.Embedding(config.vocab_size, config.input_dim)
self.pos_emb = torch.nn.Parameter(torch.normal(mean=0, std=0.02, size=(config.input_len, config.input_dim)))
self.perc = Perceiver(config)
    def num_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
    def forward(self, x):
pos = torch.cat([self.pos_emb[None, ...] for _ in range(x.shape[0])], dim=0)
x = self.emb(x) + pos
logits = self.perc(x, x)
return logits
class PerceiverImage(torch.nn.Module):
    r"""Perceiver wrapped for flattened image inputs: a trainable positional encoding over 1024 positions
    (e.g. a 32x32 grid) with 3 channels is added to the input before it is passed to the Perceiver."""

    def __init__(self, config):
super().__init__()
self.config = config
self.emb = build_position_encoding("trainable", config, 1024, 3)
self.perceiver = Perceiver(config)
    def num_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
    def forward(self, x):
pos_emb = torch.cat([self.emb[None, ...] for _ in range(x.shape[0])], dim=0)
out = x + pos_emb
return self.perceiver(out)