# Axial Positional Embedding

A type of positional embedding that factorizes the embedding across axes, making it very effective for attention networks on multi-dimensional data, and useful for language models in general.
## Install

```bash
$ pip install axial-positional-embedding
```
## Usage
```python
import torch
from axial_positional_embedding import AxialPositionalEmbedding

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    axial_shape = (64, 64),    # the axial shape multiplies up to the maximum sequence length (64 * 64 = 4096)
    axial_dims = (256, 256)    # per-axis embedding dimensions, concatenated so they sum to `dim` (256 + 256 = 512)
)

tokens = torch.randn(1, 1024, 512)   # token embeddings
tokens = pos_emb(tokens) + tokens    # add the axial positional embedding to the token embeddings
```
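To see why the factorization saves parameters, here is a minimal sketch of the underlying idea (not the library's internal implementation): two small per-axis tables are broadcast against each other and concatenated along the feature dimension, so the full 4096 x 512 table is never stored as a learnable parameter.

```python
import torch
import torch.nn as nn

axial_shape = (64, 64)      # factorized sequence length: 64 * 64 = 4096
axial_dims  = (256, 256)    # per-axis dims, concatenated to the model dim 512

# learnable parameters: 64 * 256 + 64 * 256, instead of 4096 * 512
emb_rows = nn.Parameter(torch.randn(axial_shape[0], 1, axial_dims[0]))
emb_cols = nn.Parameter(torch.randn(1, axial_shape[1], axial_dims[1]))

# broadcast each axis embedding over the other axis, concatenate the features,
# then flatten the two axes back into a single sequence dimension
rows = emb_rows.expand(axial_shape[0], axial_shape[1], axial_dims[0])
cols = emb_cols.expand(axial_shape[0], axial_shape[1], axial_dims[1])
full_pos_emb = torch.cat((rows, cols), dim = -1).reshape(1, 4096, 512)

tokens = torch.randn(1, 1024, 512)
tokens = full_pos_emb[:, :1024] + tokens   # slice to the actual sequence length
```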
A continuous version with better extrapolation ability (each axis is parameterized by a 2-layer MLP):
```python
import torch
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

pos_emb = ContinuousAxialPositionalEmbedding(
    dim = 512,
    num_axial_dims = 3    # number of axes (here an 8 x 16 x 32 grid)
)

tokens = torch.randn(1, 8, 16, 32, 512)   # token embeddings laid out on a 3D grid

axial_pos_emb = pos_emb((8, 16, 32))      # generate the positional embedding for the given axial shape
tokens = axial_pos_emb + tokens           # add to the tokens (broadcast over the batch dimension)
```
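Because each axis embedding is generated by an MLP over positions rather than looked up from a fixed-size table, the same weights can in principle produce embeddings for grid shapes not seen before. A minimal usage sketch, assuming the forward call accepts any axial shape with the matching number of axes:

```python
# assumption: the module accepts arbitrary shapes with the same number of axes,
# so the same weights can be reused for a larger grid at inference time
larger_pos_emb = pos_emb((16, 32, 64))    # shape (16, 32, 64, 512)
larger_tokens  = torch.randn(1, 16, 32, 64, 512)
larger_tokens  = larger_pos_emb + larger_tokens
```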
## Citations
```bibtex
@inproceedings{kitaev2020reformer,
    title     = {Reformer: The Efficient Transformer},
    author    = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle = {International Conference on Learning Representations},
    year      = {2020},
    url       = {https://openreview.net/forum?id=rkgNKkHtvB}
}
```

```bibtex
@misc{ho2019axial,
    title         = {Axial Attention in Multidimensional Transformers},
    author        = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year          = {2019},
    archivePrefix = {arXiv}
}
```