3.2 Detailed Walk-Through of Efficient Transformer Models
This section delves into the details of several key efficient Transformer models, discussing their pros, cons, and unique talking points. The goal here is not to exhaustively detail all such models, but rather to cover a representative sample of models.
Structure of This Section. We begin by discussing local and fixed patterns models such as the Memory Compressed Transformer [
48] and Image Transformer [
55]. We then discuss the Set Transformers [
43], an early approach for utilizing global model memory. Following which, we move on to models that utilize combinations of patterns such as Sparse Transformers [
9], CCNet [
30], and Axial Transformers [
28]. Next, we discuss Longformer [
5] and ETC [
3], as examples of memory-based Sparse Transformer approaches. Our detailed walkthrough then moves on to models that incorporate learnable patterns (LP) such as Routing Transformers [
62], Reformer [
37] and Sinkhorn Transformers [
74]. After which, we introduce Linformer [
87] and Synthesizers [
73], models that can be considered low-rank factorization approaches. We then discuss models based on kernel approaches such as Performer [
10] and Linear Transformers [
36]. Following which, we discuss the models that are based on segment-based recurrence such as Transformer-XL [
16] and Compressive Transformers [
59]. Finally, we discuss the family of Sparse models which primarily leverage Mixture-of-Experts (MoE) type architectures and conditional computation to achieve computational efficiency. The logical flow of this section is aimed to be loosely chronological instead of categorically organized (with the exception of certain buckets like recurrence or sparsity that are more orthogonal approaches). We believe this is pedagogically helpful.
3.2.1 Memory Compressed Transformer.
Memory Compressed Transformer [
48] is one of the early attempts at modifying Transformers to better handle longer sequences. Memory Compressed Transformer introduces two modifications: localizing the attention span and using memory-compressed attention.
Local Attention Span. A straightforward solution for dealing with long sequences in Transformers is to limit the attention span to a local neighborhood. Liu et al. [
48] proposed dividing the input sequence into blocks of similar length so that self-attention can be computed within each block independently. This keeps the cost of attention per block constant, so the number of activations scales linearly with the input length.
Memory-Compressed Attention. The idea behind memory-compressed attention is to reduce the number of keys and values using a strided convolution, while the queries remain unchanged. This reduces the size of the attention matrix as well as the attention computation by a compression factor that depends on the kernel size and stride of the convolution. Unlike local attention, memory-compressed attention lets the model exchange information globally across the input sequence.
Computation and Memory Complexity. For a block size of \(b\), the computational and memory cost of self-attention in each block is \(\mathcal {O}(b^2)\). Given there are \(n/b\) blocks, the computational and memory cost of local attention is \(\mathcal {O}(b \cdot n)\). For memory-compressed attention, applying a convolution with kernel size and strides of \(k\), the computational and memory cost of the attention mechanism reduces to \(\mathcal {O}(n \cdot n/k)\).
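To make these two modifications concrete, the following illustrative numpy sketch (not the original implementation, which uses a learned strided convolution rather than the mean pooling assumed here) contrasts block-local attention with memory-compressed attention:

```python
# Minimal numpy sketch of the two ideas in Memory Compressed Transformer:
# block-local attention and memory-compressed attention with a stride-k reduction
# of the keys/values (a strided convolution in the actual model).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def local_block_attention(Q, K, V, b):
    """Self-attention computed independently inside blocks of length b."""
    n, d = Q.shape
    out = np.zeros_like(V)
    for start in range(0, n, b):
        s = slice(start, min(start + b, n))
        scores = Q[s] @ K[s].T / np.sqrt(d)        # (b, b) per block
        out[s] = softmax(scores) @ V[s]
    return out

def memory_compressed_attention(Q, K, V, k):
    """Queries are unchanged; keys/values are compressed by a stride-k reduction."""
    n, d = Q.shape
    m = n // k
    K_c = K[:m * k].reshape(m, k, d).mean(axis=1)  # stand-in for a strided conv
    V_c = V[:m * k].reshape(m, k, d).mean(axis=1)
    scores = Q @ K_c.T / np.sqrt(d)                # (n, n/k) instead of (n, n)
    return softmax(scores) @ V_c

n, d, b, k = 16, 8, 4, 4
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, n, d))
print(local_block_attention(Q, K, V, b).shape)        # (16, 8)
print(memory_compressed_attention(Q, K, V, k).shape)  # (16, 8)
```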
3.2.2 Image Transformer.
Image Transformer [
55], inspired by convolutional neural networks, restricts the receptive field of self-attention to local neighborhoods. This helps the model scale up to process larger batch sizes while keeping the likelihood loss tractable. Besides efficiency, the notion of locality can be a desirable inductive bias for processing images. Image Transformer adopts an encoder-decoder architecture, where the encoder generates a contextualized representation for every pixel-channel in the inputs and the decoder autoregressively generates one channel per pixel at each time step.
Localized Attention Span. Limiting the receptive field to a local neighborhood [
54,
55] addresses the issues with the computational and memory costs of running global self-attention on large inputs, but changing the neighborhood per query position would prohibit packing the computations of the self-attention into two matrix multiplications. To avoid that, Image Transformer proposes partitioning the inputs into “query blocks” and their associated “memory blocks”, where for all queries from a single query block, the model attends to the same memory block. There are two different schemes for choosing query blocks and their associated memory block neighborhoods:
1-dimensional local attention and
2-dimensional local attention. Here we briefly explain these schemes in the decoder case.
For the 1-dimensional local attention, the image is flattened in the raster order
4 and partitioned into non-overlapping query blocks
\(Q\) of length
\(l_q\), and for each query block, a memory block
\(M\) is built from the same pixels in \(Q\) as well as a fixed number of pixels, \(l_m\), generated before the query pixel. For the 2-dimensional local attention, the image is partitioned into multiple non-overlapping rectangular query blocks of length \(l_q = w_q \times h_q\). The memory block extends the query block to the top by \(h_m\) pixels, to the left by \(w_m\) pixels, and to the right by \(w_m\) pixels, so \(l_m = (w_q \times h_q) + 2 \times (h_m + w_m)\). Each query pixel can then attend to all pixels in its memory block. In the 2-dimensional local attention, pixels in the image are generated one query block after another; the blocks are generated in raster order, as are the pixels inside each block. Figure
1 illustrates an overview of the Transformer architecture.
Computational and Memory Complexity. In Image Transformer, the attention matrix has shape \(l_q \times m\), where \(l_q\) is the chosen length of the query blocks and \(m\) is the length of the memory block (which is in fact \(l_q + l_m\)). Given that query blocks do not overlap, we have to compute \(n / l_q\) such attention matrices. Thus, the memory and computational complexity of Image Transformer is \(\mathcal {O}(n\cdot m)\).
Restrictions. Image Transformer, and more generally restricting the attention context to a local neighborhood, decreases memory and computation costs at the price of losing the global receptive field. This can be an issue when global information is required to solve the task. Moreover, local attention has quadratic complexity with respect to the region length, thereby introducing an extra hyper-parameter that trades off performance against computational complexity.
3.2.3 Set Transformer.
The Set Transformer [
43] adapts the Transformer model for
set-input problems—that is, problems wherein the input is a set of features and the output is some function of this set (and is thereby invariant to the permutation, or ordering, of the input features). The Set Transformer leverages attention to capture interactions between elements of the input set. Furthermore, it applies the idea of
inducing points from the sparse Gaussian process literature to reduce the complexity of attention from quadratic to linear in the size of the input set.
Problems involving sets of objects often have a
permutation invariance property: the target value for the set is the same regardless of the order of the objects in the set. Zaheer et al. [
94] proved that all permutation-invariant functions can be represented by the following functional form:
\[ f(\lbrace x_1, \ldots , x_n\rbrace) = \rho \left(\text{pool}\left(\lbrace \phi (x_1), \ldots , \phi (x_n)\rbrace \right)\right), \]
where the pooling function
\(\text{pool}\) is a simple summation and
\(\phi\) and
\(\rho\) are continuous functions. This form can be interpreted as the composition of an
encoder \(\phi\) and
decoder \(\rho \left(\text{pool}(\cdot)\right)\). While this form is a universal approximator in the space of permutation-invariant functions, it is unclear how well such models fit tasks in practice. The Set Transformer proposes a solution that can be viewed as an encoder and pooled decoder, but where, unlike the form given above, the encoder and decoder can attend to input elements individually and the pooling function is parameterized.
Attention Blocks. The model introduces the following constructs:
Multihead Attention Block (MAB), Set Attention Block (SAB), Induced Set Attention Block (ISAB), and
Pooling by Multihead Attention (PMA). They are defined as follows:
\[ \begin{aligned} \text{MAB}(X, Y) &= \text{LayerNorm}(H + \text{rFF}(H)), \quad \text{where } H = \text{LayerNorm}(X + \text{Multihead}(X, Y, Y)),\\ \text{SAB}(X) &= \text{MAB}(X, X),\\ \text{ISAB}_m(X) &= \text{MAB}(X, \text{MAB}(I_m, X)),\\ \text{PMA}_k(X) &= \text{MAB}(S_k, \text{rFF}(X)). \end{aligned} \]
Here,
\(X \in \mathbb {R}^{N \times d}\) represents
\(N\) \(d\)-dimensional input/outputs stacked row-wise and
\(\text{rFF}\) is a parameterized feed-forward layer that operates on each row of its input matrix separately.
\(I_m \in \mathbb {R}^{m \times d}\) represents
\(m\) trainable \(d\)-dimensional “inducing points” while
\(S_k \in \mathbb {R}^{k \times d}\) represent
\(k\) trainable
\(d\)-dimensional “seed vectors” (with
\(k\) set to 1 except when
\(k \gt 1\) correlated outputs are needed). The Set Transformer’s encoder is just
\(N\) layers of either SAB or ISAB (with
\(N\) often set to 2 in practice) while its decoder is given by:
\[ \text{Decoder}(Z) = \text{rFF}\left(\text{SAB}\left(\text{PMA}_k(Z)\right)\right), \]
where \(Z\) is the encoder output.
It is straightforward to see that both ISAB and SAB are
permutation equivariant —in other words, if the input is permuted in some way then the corresponding output of the block is permuted in exactly the same way. Meanwhile, the pooling layer PMA is permutation invariant. Since functional composition, i.e., layering, preserves these properties, the Set Transformer encoder-decoder combination is permutation invariant.
Efficiency. We can understand the
\(m\) inducing points
\(I_m\) learned in each ISAB layer as a form of static model memory. In addition to reducing the
\(\mathcal {O}(N n^2)\) complexity of the self-attending SAB layer to
\(\mathcal {O}(N m n)\), a reduction particularly valuable when the input set is large, the inducing points effectively encode some global structure that helps explain the inputs. For example, in the problem of
amortized clustering, where one attempts to learn to map an input set of points to the centers of clusters of points inside the set, the inducing points learned could be appropriately distributed so that the encoder can effectively compare query elements with each other implicitly via their proximity to the inducing points.
The trainable \(k\) seeds \(S_k\) used in the pooling layer \(\text{PMA}_k\) can be viewed as static model memory in a similar light, reducing the memory and runtime complexity of the architecture.
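As an illustrative sketch of how inducing points reduce the attention cost, the following simplified numpy snippet implements only the two attention steps of an ISAB, omitting the multi-head projections, rFF, and layer normalization of the actual block:

```python
# Simplified numpy sketch of an Induced Set Attention Block (ISAB). The real block
# wraps attention with rFF and layer normalization; here we keep only the two
# attention steps to show how m inducing points give O(m*n) instead of O(n^2) cost.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(queries, keys_values):
    # Simplification: the same tensor is used for keys and values.
    d = queries.shape[-1]
    scores = queries @ keys_values.T / np.sqrt(d)
    return softmax(scores) @ keys_values

def isab(X, I):
    """X: (n, d) input set; I: (m, d) trainable inducing points (random here)."""
    H = attention(I, X)     # (m, d): inducing points summarize the set, O(m*n)
    return attention(X, H)  # (n, d): elements read the summary back, O(n*m)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 16))   # a set of 100 elements
I = rng.normal(size=(8, 16))     # m = 8 inducing points
print(isab(X, I).shape)          # (100, 16)
```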
3.2.4 Sparse Transformer.
The Sparse Transformer [
9] presents a simple initial attempt to reduce the quadratic complexity of the standard self-attention mechanism. The key idea is to reduce the dense attention matrix to a sparse version by only computing attention on a sparse number of
\(q_i,k_j\) pairs. Sparse Transformer employs fixed attention patterns which are defined by strides and local neighborhoods. Computation is
factorized, wherein local and stride patterns are split amongst the heads.
Local Attention Heads. Half of the heads in the Sparse Transformer are dedicated to local attention:
\[ A_{ij} = \begin{cases} q_i k_j^{\top }, & \text{if } \lfloor j/N \rfloor = \lfloor i/N \rfloor \\ 0, & \text{otherwise,} \end{cases} \]
where \(A_{ij}\) is the attention weight of \(q_i,k_j\) and \(\lfloor \: \rfloor\) denotes the floor operation. In this case, we only compute the attention if \(\lfloor {\,{j}/{N}}\rfloor = \lfloor {i/{N}}\rfloor\) (within the same block).
Strided Attention Heads. The other half of the heads are dedicated to fixed strided patterns. Concretely, attention is only computed for positions that satisfy \((i - j) \bmod N = 0\), i.e., each query attends to keys at a fixed stride of \(N\).
The final result of the factorized sparse attention is visualized in Figure
4. We refer interested readers to [
92] for some additional theoretical analysis about the expressiveness of the Sparse attention mechanism.
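The following illustrative numpy snippet (masks only, not the custom block-sparse kernels used in the original implementation) constructs the two fixed patterns that are split amongst the heads:

```python
# Illustrative numpy construction of the two fixed attention patterns in Sparse
# Transformer: a local (blocked) pattern and a strided pattern, each assigned to
# half of the heads; a causal mask is intersected for autoregressive decoding.
import numpy as np

def local_mask(n, block):
    i = np.arange(n)
    return (i[:, None] // block) == (i[None, :] // block)   # same block

def strided_mask(n, stride):
    i, j = np.arange(n)[:, None], np.arange(n)[None, :]
    return (i - j) % stride == 0                            # every stride-th position

n, block = 16, 4
causal = np.tril(np.ones((n, n), dtype=bool))
print((local_mask(n, block) & causal).sum(), "local entries")
print((strided_mask(n, block) & causal).sum(), "strided entries")
```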
Parameter and Memory Complexity. The modification in the self-attention mechanism does not alter the parameter costs of the model since the model still retains the
\(Q,K,V\) transforms from the original Transformer model. The memory complexity of the attention layer is reduced from
\(\mathcal {O}(n^2)\) to
\(\mathcal {O}(n\log n)\).
Restrictions. The Sparse Transformer implementation requires custom GPU kernels to implement a specific block-sparse variant of matrix-matrix-multiplication and cannot be easily implemented on other hardware such as TPUs.
3.2.5 Axial Transformer.
Axial Transformer [
28,
89] uses factorization in a simple yet effective setup for the self-attention mechanism to process large inputs that are organized as multidimensional tensors. Instead of applying attention to the flattened version of the input, Axial Transformer simply applies multiple attentions, each along a single axis of the input tensor. Each attention, in fact, mixes information along a particular axis, while keeping information along other axes independent. Since the length of any single axis is typically much smaller than the total number of elements, Axial Transformer significantly saves computation and memory. Figure
3 provides an illustration of attention applied to 2D input in Image Transformer.
Axial Transformer offers an encoder-decoder architecture. For the decoding, to be able to implement the causal mask, Axial Transformer combines axial attentions with shift operations. For instance, for a model on 2-dimensional tensors, pixels are generated in raster order and to do that, first, the model encodes all pixels through an unmasked row and unmasked column attention. Then, for each row, the model applies an unmasked row and masked column attention to integrate the previously sampled rows. Finally, the model shifts the encoded representation up to make sure the conditioning information satisfies causality, and runs a masked row-attention to sample a new row in the image.
An advantage of Axial Transformer over similar methods like Sparse Transformer is that while it provides the global receptive field, it is straightforward to implement and does not require a custom kernel for an efficient implementation.
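As a minimal sketch of the idea (single head, unmasked, no projections), the following numpy snippet applies attention along the rows and then along the columns of a 2D input:

```python
# Minimal numpy sketch of axial (unmasked) attention on a 2D input: one attention
# pass mixes information along rows, a second mixes along columns, so no position
# ever attends over the full flattened h*w sequence at once.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(x):
    """Self-attention over the second-to-last axis of x (..., length, d)."""
    d = x.shape[-1]
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(d)
    return softmax(scores) @ x

def axial_attention(x):
    """x: (h, w, d) image features."""
    x = attend(x)                                          # row attention (length w)
    x = np.swapaxes(attend(np.swapaxes(x, 0, 1)), 0, 1)    # column attention (length h)
    return x

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 32, 16))
print(axial_attention(x).shape)   # (32, 32, 16); cost O(h*w*(h+w)*d)
```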
Computational and Memory Complexity. In terms of memory and computational complexity, on a square image with \(N\) pixels, Axial Transformer performs the attention computation in \(\mathcal {O}(N \sqrt {N})\), which saves a factor of \(\mathcal {O}(\sqrt {N})\) over standard self-attention. For instance, for a square image with \(N\) pixels organized in a \(b\times b\) grid (so \(N = b^2\)), Axial Transformer runs \(b\) attention sequences of length \(b\), which has complexity \(\mathcal {O}(b \cdot b^2)\). In the more general case of a \(d\)-dimensional tensor of shape \(N = N^{1/d}\times \cdots \times N^{1/d}\), Axial Transformer saves a factor of \(\mathcal {O}(N^{(d-1)/d})\) in resources over standard self-attention.
3.2.6 Longformer.
Longformer [
5] is a variant of Sparse Transformer. Its key distinction compared to Sparse Transformer is “Dilated Sliding Windows”, which can enable better long-range coverage without sacrificing sparsity. This is achieved by increasing the receptive fields by having gaps in the attention patterns. The Longformer also gradually increases the receptive field as the model goes deeper, dedicating lower levels for modeling local patterns and upper levels for modeling global patterns.
Global Attention. For classification tasks, Longformer adopts global memory tokens that have access to the entire input sequence.
Parameter and Memory Complexity. The complexity of the model is reduced from \(\mathcal {O}(n^2)\) to \(\mathcal {O}(nk)\) where \(k\) is the size of the window. When using global attention, the Longformer creates another set of query-key-value projections for this global attention, doubling the cost of the parameters at the attention layer.
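The following illustrative numpy snippet (masks only, assuming a symmetric window; the actual model relies on custom banded-matrix kernels) sketches the sliding window, its dilated variant, and the addition of global tokens:

```python
# Illustrative numpy masks for Longformer-style attention: a sliding window of
# size k, an optional dilation that inserts gaps to widen the receptive field,
# and a few global positions that attend to (and are attended by) everyone.
import numpy as np

def sliding_window_mask(n, k, dilation=1):
    i, j = np.arange(n)[:, None], np.arange(n)[None, :]
    offset = i - j
    return (np.abs(offset) <= (k // 2) * dilation) & (offset % dilation == 0)

def add_global(mask, global_idx):
    mask = mask.copy()
    mask[global_idx, :] = True   # global tokens attend everywhere
    mask[:, global_idx] = True   # every token attends to global tokens
    return mask

n = 16
mask = sliding_window_mask(n, k=4, dilation=2)
mask = add_global(mask, global_idx=[0])
print(mask.sum(), "allowed query-key pairs out of", n * n)
```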
3.2.7 Extended Transformer Construction (ETC).
The ETC model [
3] is another variation in the Sparse Transformer family. It introduces a new global-local attention mechanism, which has four components, namely (1) global-to-global (g2g), (2) global-to-local (g2l), (3) local-to-global (l2g), and (4) local-to-local (l2l) attention. Aside from the original input to the model, ETC introduces
\(n_g\) auxiliary tokens as a prefix to the original input sequence. These tokens are regarded as global tokens and take part in global-to-
\(*\) and
\(*\)-to-global attention. The local-to-local component acts as the local attention with a fixed radius of
\(k\). Overall, ETC is quite similar to Longformer in the way it introduces global auxiliary tokens. These tokens are trainable parameters and can be interpreted as a form of model memory that pools across the sequence to collect global sequence information.
Memory and Parameter Complexity. The memory complexity of the ETC model is \(\mathcal {O}(n_{g}^2 + n_{g}N)\), where \(n_g\) is the number of global tokens and \(N\) is the input sequence length.
Restrictions. Intuitively, it is easy to observe that ETC cannot be used for auto-regressive decoding, since the global attention makes it impossible to construct a proper causal mask.
3.2.8 Big Bird.
The Big Bird model [
93] is another Transformer for modeling longer sequences and is primarily built on top of ETC [
3]. The Big Bird model is comprised of several key components, namely (1) global tokens, (2) random attention (queries attend to random keys), and (3) fixed patterns (local sliding windows).
Global Attention. Fundamentally, the idea of using global model memory can be traced all the way back to the Longformer/ETC and Set Transformer models. Notably, the global model memory in Big Bird is extended to contain tokens within the sequence, instead of simply parameterized model memory. The authors call this the “internal transformer construction (ITC)”, in which a subset of indices is selected as global tokens. This can be interpreted as a model-memory-based approach.
Sliding Window Attention. Windowed attention was first proposed in early local attention models (Image Transformer, Compressed Attention, and/or Sparse Transformer). In Big Bird, each query attends to \(w/2\) tokens to the left and \(w/2\) tokens to the right. This corresponds to a fixed pattern (FP) approach.
Random Attention. Finally, each query attends to \(r\) random keys. This pattern is fixed.
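A minimal numpy sketch of how the three components combine into one attention mask (illustrative only; the released model computes attention in a blocked fashion rather than materializing an \(n \times n\) mask) is given below:

```python
# Illustrative numpy sketch of the Big Bird attention pattern: global tokens chosen
# from the sequence (ITC), a sliding window of width w, and r random keys per query.
import numpy as np

def bigbird_mask(n, w, r, global_idx, seed=0):
    rng = np.random.default_rng(seed)
    i, j = np.arange(n)[:, None], np.arange(n)[None, :]
    mask = np.abs(i - j) <= w // 2                 # sliding window
    mask[global_idx, :] = True                     # global tokens attend everywhere
    mask[:, global_idx] = True                     # everyone attends to global tokens
    for q in range(n):                             # r random keys per query
        mask[q, rng.choice(n, size=r, replace=False)] = True
    return mask

mask = bigbird_mask(n=64, w=6, r=3, global_idx=[0, 1])
print(mask.sum(), "nonzero attention entries; linear in n for fixed w, r, #globals")
```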
Memory and Parameter Complexity. The memory complexity of the self-attention is linear, i.e., \(\mathcal {O}(n)\). The Big Bird model does not introduce new parameters beyond the Transformer model.
Restrictions. Similar to ETC, the Big Bird model cannot be used for autoregressive decoding, hence qualifying it as an encoder-only model.
3.2.9 Routing Transformer.
The Routing Transformer [
62] is a content-based sparse attention mechanism. It proposes a clustering-based attention mechanism that learns the attention sparsity in a data-driven fashion. The first step is to project
\(Q\) and
\(K\) into a routing matrix
\(R\) of dimensions
\(n \times d\), where \(W_R\) is a \(d \times d\) orthonormal projection matrix.
\(k\)-Means Clustering. The
\(R\) matrix undergoes
\(k\)-means clustering with a series of parameterized cluster centroids
\(u_1, u_2, \ldots , u_k\). The
\(k\)-means in Routing Transformer is trained in an online fashion. To ensure a similar number of tokens in each cluster, the model initializes
\(\sqrt {n}\) clusters, computes each token’s distance against each cluster centroid, and takes an equal top-
\(k\) for each centroid. Since the cluster centroids are trainable parameters, this is also reminiscent of the
all-attention layer proposed by [
71].
Routing Strategy. The routing strategy is then defined as:
where
\(C_i\) is the cluster that vector
\(R_i\) is assigned to. In other words, the token at
\(i\) only attends to tokens in the same cluster.
Memory and Parameter Complexity. The Routing Transformer introduces additional parameters in the clustering mechanism, namely \(k \times d\) centroid vectors and the \(d \times d\) projection matrix \(W_R\). The memory complexity is \(\mathcal {O}(n^{1.5})\).
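The following simplified numpy sketch conveys the routing idea; the learned routing projection, the online centroid updates, and the equal-size top-\(k\) assignment of the actual model are replaced here by a simple averaged routing representation, fixed random centroids, and a nearest-centroid assignment:

```python
# Simplified numpy sketch of content-based routing: tokens are assigned to the
# nearest of k centroids in a routing space, and attention is computed only among
# tokens that share a cluster.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def routing_attention(Q, K, V, centroids):
    n, d = Q.shape
    R = (Q + K) / 2.0                              # stand-in routing representation
    dists = ((R[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    assign = np.argmin(dists, axis=1)              # nearest centroid per token
    out = np.zeros_like(V)
    for c in range(centroids.shape[0]):
        idx = np.where(assign == c)[0]
        if idx.size == 0:
            continue
        scores = Q[idx] @ K[idx].T / np.sqrt(d)    # attention within the cluster
        out[idx] = softmax(scores) @ V[idx]
    return out

rng = np.random.default_rng(0)
n, d, k = 64, 16, 8                                # k is roughly sqrt(n) clusters
Q, K, V = rng.normal(size=(3, n, d))
centroids = rng.normal(size=(k, d))
print(routing_attention(Q, K, V, centroids).shape)  # (64, 16)
```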
3.2.10 Reformer.
Reformer [
37] is another efficient attention model based on
locality sensitive hashing (LSH). Reformer also introduces
reversible Transformer layers, which contribute to further reducing its memory footprint.
LSH Attention. The LSH attention introduces parameter-sharing between query and keys. It hashes the query-keys into buckets using a random-projection based hashing function. The key idea is that nearby vectors should obtain a similar hash while distant vectors should not, hence being termed as
“locality sensitive”. To perform hashing, a random matrix
\(R \in \mathbb {R}^{k \times b/2}\) is first introduced. Next, the hashing function is defined as:
\[ h(x) = \arg \max \left([xR; -xR]\right), \]
where
\([;]\) is the concatenation of two vectors. For all queries, attention is computed if and only if the query and key hashes match, i.e.,
\(h(q_i)=h(k_j)\). In other words, attention is computed amongst queries and keys if they fall in the same hash bucket. In order to maintain causal masking, Reformer assigns and maintains a position index for every query and key, and can therefore check whether each query-key comparison is auto-regressively valid.
Memory Efficiency with LSH Attention. The key idea behind LSH attention is to classify tokens into buckets and then process them bucket by bucket in a chunked fashion. To this end, queries are first sorted by bucket number and then by sequence order within the same bucket. During computation, tokens only attend to tokens in the same bucket within their own chunk and the previous chunk. The chunking and sorted bucketing techniques help to improve the overall efficiency of the Reformer model.
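A minimal numpy sketch of the random-projection hash is shown below (single hash round; the full model uses multiple rounds, shared query-key vectors, and chunked processing):

```python
# Minimal numpy sketch of the random-projection hash used for LSH attention:
# vectors are projected with a random matrix R and hashed to the index of the
# largest coordinate of [xR; -xR], so nearby vectors tend to share a bucket.
import numpy as np

def lsh_hash(x, n_buckets, seed=0):
    """x: (n, d) shared query/key vectors; returns a bucket id per position."""
    rng = np.random.default_rng(seed)
    R = rng.normal(size=(x.shape[-1], n_buckets // 2))
    proj = x @ R                                   # (n, n_buckets/2)
    return np.argmax(np.concatenate([proj, -proj], axis=-1), axis=-1)

rng = np.random.default_rng(1)
qk = rng.normal(size=(128, 32))
buckets = lsh_hash(qk, n_buckets=16)
# Attention would then be restricted to pairs with matching buckets,
# processed chunk by chunk after sorting by (bucket, position).
print(np.bincount(buckets, minlength=16))
```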
Parameter and Memory Complexity. The memory complexity of Reformer is \(\mathcal {O}(n \log n)\). In terms of parameter costs, Reformer shares queries and keys, which reduces the cost of the QKV transforms by a third. The random projections are not trainable parameters and hence do not incur parameter costs. Overall, Reformer has fewer parameters than vanilla Transformers. The reversible layers in Reformer also reduce the memory consumption during training by enabling activations to be reconstructed from those of the next layer. This reduces memory cost since it eliminates the need to store activations for all layers during backpropagation.
3.2.11 Sinkhorn Transformers.
This section introduces the Sparse Sinkhorn Transformer [
74]. The Sinkhorn Transformer belongs to the family of
learned patterns. This model is a chunked/blocked model that learns sparse patterns by re-sorting the input keys and values in a block-wise fashion and then applying local block-based attention.
where
\(\psi _S\) applies a sorting operator on the sequence length dimension.
Sorting Network. The sorting operator is parameterized by a meta sorting network. Let
\(X\) be the input sequence of dimension
\(N \times d\):
where
\(F_S(.)\) is a parameterized function such as a two-layer feed-forward network with ReLU activation. The output of
\(F_S(.)\) is a tensor of
\(n_B \times n_B\). The BlockSum function learns the sum embeddings of local blocks. The BlockShape function reshapes the input tensor into
\(\mathbb {R}^{N \times d} \rightarrow \mathbb {R}^{n_B \times b \times d}\). Here, we note that
\(N = n_B \times b\), where
\(b\) is the size of the block and
\(n_B\) is the number of total blocks.
Sinkhorn Sorting.
\(\phi\) is the Sinkhorn balancing operator [
1,
67] which converts the
\(n_B \times n_B\) matrix into a soft permutation matrix. Specifically, a series of row- and column-wise normalizations are applied on the matrix output of
\(F_S(\text{BlockSum}(X))\). For the sake of brevity, we do not delve into the details of this operation. Further details can be found in Adams and Zemel [
1], and Tay et al. [
74].
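As an illustrative sketch, the following numpy snippet applies the alternating row and column normalizations (in log space, which is a common implementation choice assumed here) to turn block-level scores into a soft permutation matrix:

```python
# Minimal numpy sketch of Sinkhorn balancing: starting from block-level scores,
# alternating row-wise and column-wise normalizations push the matrix toward a
# doubly-stochastic (soft permutation) matrix used to re-sort blocks.
import numpy as np

def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def sinkhorn(scores, n_iters=10):
    """scores: (n_B, n_B) output of the meta sorting network."""
    log_p = scores
    for _ in range(n_iters):
        log_p = log_softmax(log_p, axis=1)   # row normalization
        log_p = log_softmax(log_p, axis=0)   # column normalization
    return np.exp(log_p)

rng = np.random.default_rng(0)
P = sinkhorn(rng.normal(size=(4, 4)))
print(P.sum(axis=0), P.sum(axis=1))          # both approximately all-ones
```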
Parameter and Memory Complexity. The memory complexity of the Sinkhorn Transformer is \(\mathcal {O}(b^2)\), where \(b\) is the block size and \(b=\frac{N}{n_B}\). Additional parameter costs are incurred from the meta sorting network \(F_S(.)\). The number of additional parameters is therefore \(2d^2\) when a two-layer ReLU network is used as the sorting network.
3.2.12 Linformer.
Linformer [
87] is an efficient Transformer based on the idea of low-rank self-attention.
Low-Rank Projections on Length Dimensions. Linformer projects the
\(N \times d\) dimensional keys and values to
\(k \times d\) dimensions using additional projection layers. Note that this is a reduction on the length dimension instead of the key and value dimensions. Given the newly projected keys (
\(K^{\prime }\)) and values (
\(V^{\prime }\)), the
\(QK^{\prime }\) matrix is now
\((N \times k)\) dimensions instead of
\((N \times N)\). The attention matrix
\(\text{Softmax}(QK^{\prime })\) multiplies with
\(V^{\prime } \in \mathbb {R}^{k \times d}\) to result in an output tensor of dimensions
\(N \times d\). To some extent, Linformer is reminiscent of depth-wise convolutions [
35]. A projection on the length dimension causes mixing of sequence information (dimension-wise) in a single transformation. Hence, it is non-trivial to maintain causal masking and/or prevent mixing of past and future information when computing attention scores. The formulation of Linformer (for each attention head) can be expressed as:
\[ \text{head}_i = \text{Softmax}\left(\frac{Q W_i^{Q} \left(E_i K W_i^{K}\right)^{\top }}{\sqrt {d}}\right) F_i V W_i^{V}, \]
where
\(W^{Q,K,V}\) are the default linear transformation of
\(X\) into queries (as per vanilla Transformer) and
\(E_{i}, F_i\) are additional
\(k \times N\) projections of the keys and values into
\(k \times d\) tensors.
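A minimal numpy sketch of the length-dimension projection (single head, bidirectional, with random projections standing in for the learned \(E_i, F_i\)) is given below:

```python
# Minimal numpy sketch of Linformer's length-dimension projection: keys and values
# of length N are linearly projected down to k "pseudo tokens" before standard
# attention, so the attention matrix is N x k instead of N x N.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def linformer_attention(Q, K, V, E, F):
    """Q, K, V: (N, d); E, F: (k, N) projections on the length dimension."""
    d = Q.shape[-1]
    K_proj, V_proj = E @ K, F @ V                  # (k, d)
    scores = Q @ K_proj.T / np.sqrt(d)             # (N, k)
    return softmax(scores) @ V_proj                # (N, d)

rng = np.random.default_rng(0)
N, d, k = 512, 64, 32
Q, K, V = rng.normal(size=(3, N, d))
E, F = rng.normal(size=(2, k, N)) / np.sqrt(N)
print(linformer_attention(Q, K, V, E, F).shape)    # (512, 64)
```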
Parameter and Memory Complexity. The memory complexity of Linformer is \(\mathcal {O}(n)\). There is only a minimal parameter cost for Linformer due to the extra \(N \times k\) length projections. If \(k\) is sufficiently small, the parameter cost incurred is negligible.
3.2.13 Performer.
The Performer [
10,
11] model is characterized by its Generalized Attention mechanism and its usage of random kernels.
Generalized Attention. The generalized attention entangles
\(Q_i,K_j\) with a kernel function
\(K\). The attention matrix in Performer is computed via:
where \(K(.)\) is a kernel function that maps a pair of \(d\)-dimensional vectors to a scalar value in \(\mathbb {R}\), and \(g,h\) are functions that map a \(d\)-dimensional vector to a scalar value in \(\mathbb {R}\).
Fast Attention via Orthogonal Random Features (FAVOR). The above computation is still quadratic in complexity. Hence, the Performer leverages approximation tricks to avoid storing and computing the
\(N \times N\) attention matrix. It leverages
orthogonal random features (ORF) for doing so. The final attention output
\(Y\) of the Performer is described as follows:
\[ Y = \hat{D}^{-1}\left(Q^{\prime }\left((K^{\prime })^{\top } V\right)\right), \]
where
\(\hat{D}=\text{diag}(Q^{\prime }((K^{\prime })^\top 1_N))\),
\(Q^{\prime }=D_Q\phi (Q^\top)^\top\), and
\(K^{\prime }=D_K\phi (K^\top)^\top\). Note that
\(D_Q=g(Q_i^\top),D_K=h(K_i^\top)\). The function
\(\phi (x)\) is defined as:
where
\(c \gt 0\) is a constant,
\(W \in \mathbb {R}^{M \times d}\) is a random feature matrix and
\(M\) is the dimensionality of this matrix that controls the number of random features. We are able to see that we do not explicitly compute
\(A=QK^\top\) and hence avoid paying the
\(N^2\) cost. For rigorous theoretical analysis and further details, we refer interested readers to [
10].
Parameter/Memory Complexity and Compute Costs. The complexity of the bi-directional FAVOR algorithm is \(\mathcal {O}(Md + Nd + MN)\), where \(M\) is the dimensionality of the random features. It is worth noting that the unidirectional (causal) variant cannot be causally masked efficiently while retaining parallelism across the time dimension. As such, during training, running a unidirectional (causal) implementation of kernel-based attention on an autoregressive task can be several times slower than the vanilla Transformer under parallelized training, due to the need to do a left-to-right pass (i.e., a scan operation) in a similar spirit to recurrent neural networks. Since many autoregressive tasks are trained via parallelization and teacher forcing, this makes training Performer on a generative task prohibitively slow. In order for KV to be causally masked efficiently, one would have to manifest the \(d \times d\) KV matrix at every time step, recovering a quadratic complexity model. We feel this is one of the intricate points that highlight how efficient memory complexity might not equate to a faster or more efficient model in practice. We highlight that this only happens during autoregressive training. Inference-time incremental decoding, however, would benefit from a speed-up.
3.2.14 Linear Transformer.
The Linear Transformer [
36] improves the complexity of self-attention from quadratic to linear by using a kernel-based formulation of self-attention and the associative property of matrix products. Furthermore, it reduces attention with causal masking (which is used in auto-regressive decoding) to a linear-time, constant memory
recurrent neural network (RNN). The model has been shown to improve inference speeds up to
three orders of magnitude without much loss in predictive performance. Linear Transformers are similar to Performers with the exception of the kernel function and therefore also suffer from the same drawbacks (they cannot be parallelized across the time dimension during training in an autoregressive, teacher-forced setting).
The method rests on the simple but powerful observation that the accumulated value
\(V_i^{\prime }\) for the query
\(Q_i\) in position
\(i\) can be written as:
\[ V_i^{\prime } = \frac{\sum _{j=1}^{p} \text{sim}(Q_i, K_j)\, V_j}{\sum _{j=1}^{p} \text{sim}(Q_i, K_j)}. \]
Here,
\(p = N\) in full, unmasked attention and
\(p = i\) in the case of causal masking. Now, in usual softmax attention,
\(\text{sim}(q, k) = \exp (\frac{q^T k}{\sqrt {d}})\). Linear Transformer, however, expresses the similarity as a kernel function. That is,
\(\text{sim}(q, k) := \phi (q)^T \phi (k)\), where
\(\phi\) is a, possibly high-dimensional, feature map. With this choice, we can rewrite
\(V_i^{\prime }\) as:
\[ V_i^{\prime } = \frac{\phi (Q_i)^{\top } S_p}{\phi (Q_i)^{\top } Z_p}, \quad \text{where } S_p = \sum _{j=1}^{p} \phi (K_j)\, V_j^{\top } \text{ and } Z_p = \sum _{j=1}^{p} \phi (K_j). \]
For unmasked attention, since
\(p = N\) we only need to compute
\(S_N\) and
\(Z_N\) once and we reuse them for the computation at every position
\(0 \le i \le N\). For causal attention, the
\(S_i\)’s and
\(Z_i\)’s can be viewed as states of an RNN that are updated by the following recurrence relations:
\[ S_i = S_{i-1} + \phi (K_i)\, V_i^{\top }, \qquad Z_i = Z_{i-1} + \phi (K_i), \]
with initial condition
\(S_0 = Z_0 = 0\). If the dimension of the key, query, and values are all
\(d\) and the cost to compute
\(\phi\) is
\(\mathcal {O}(c)\), then the overall run-time complexity of Linear Transformer is
\(\mathcal {O}{(N c d)}\). The authors choose
\[ \phi (x) = \text{elu}(x) + 1, \]
where
\(\text{elu}(\cdot)\) denotes the exponential linear unit [
13]. With this choice of feature map,
\(c = d\) and the end-to-end complexity of the model is
\(\mathcal {O}(N d^2)\).
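The causal case can be sketched directly from the recurrences above; the following illustrative numpy snippet processes positions left to right, which is exactly the sequential behavior discussed for autoregressive training:

```python
# Minimal numpy sketch of causal linear attention with the feature map
# phi(x) = elu(x) + 1, maintaining the running states S (c x d) and Z (c)
# from the recurrence relations. Linear in sequence length, but positions
# must be processed left to right.
import numpy as np

def phi(x):
    return np.where(x > 0, x, np.exp(x) - 1.0) + 1.0   # elu(x) + 1

def causal_linear_attention(Q, K, V):
    n, d = Q.shape
    S = np.zeros((d, V.shape[-1]))
    Z = np.zeros(d)
    out = np.zeros_like(V)
    for i in range(n):
        S = S + np.outer(phi(K[i]), V[i])               # S_i = S_{i-1} + phi(K_i) V_i^T
        Z = Z + phi(K[i])                               # Z_i = Z_{i-1} + phi(K_i)
        out[i] = (phi(Q[i]) @ S) / (phi(Q[i]) @ Z)
    return out

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 64, 16))
print(causal_linear_attention(Q, K, V).shape)           # (64, 16)
```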
3.2.15 Synthesizers.
Synthesizer models [
73] are an attempt to study and investigate the true importance of conditioning within the self-attention mechanism and are also the first attempts at unconditional token-mixing. In Tay et al. [
73], the authors study a synthetic self-attention module in which attention weights are approximated instead of being computed by pairwise dot products. Synthesizers are only implicitly related to efficient Transformers and can be considered more as an MLP-Mixer [
81]. However, the factorized variants can be considered a low-rank efficient Transformer model.
Dense Synthesizers. In the Dense Synthesizer, each token
\(x_i\) is projected to a vector of length
\(N\) using a two-layered non-linear feed-forward network. The computation of the attention matrix
\(A\) is described as:
\[ A = \sigma _R(X W_1)\, W_2, \]
where
\(X \in \mathbb {R}^{N \times d}\) is the input sequence,
\(W_2 \in \mathbb {R}^{d \times N}, W_1 \in \mathbb {R}^{d \times d}\), and
\(\sigma _R\) is the ReLU activation function. Given
\(A\), the output of the Synthetic Dense function is computed as:
\[ Y = \text{Softmax}(A)\, G(X), \]
where
\(G(X)\) is another parameterized function
\(\mathbb {R}^{N \times d} \rightarrow \mathbb {R}^{N \times d}\).
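The following minimal numpy sketch (single head, omitting biases, and assuming a simple linear map for \(G(X)\)) illustrates how the Dense Synthesizer replaces query-key dot products with per-token synthesized attention logits:

```python
# Minimal numpy sketch of the Dense Synthesizer: attention weights are produced
# per token by a two-layer feed-forward network instead of query-key dot products,
# then applied to a value-like transformation G(X).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dense_synthesizer(X, W1, W2, G):
    """X: (N, d); W1: (d, d); W2: (d, N); G: (d, d) stand-in for G(X)."""
    A = np.maximum(X @ W1, 0.0) @ W2      # (N, N) synthesized attention logits
    return softmax(A) @ (X @ G)           # (N, d)

rng = np.random.default_rng(0)
N, d = 32, 16
X = rng.normal(size=(N, d))
W1, G = rng.normal(size=(2, d, d)) / np.sqrt(d)
W2 = rng.normal(size=(d, N)) / np.sqrt(d)
print(dense_synthesizer(X, W1, W2, G).shape)   # (32, 16)
```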
Random Synthesizers. Another variant of the Synthesizer model uses random matrices for
\(A\). In this case, the output can be expressed by:
\[ Y = \text{Softmax}(R)\, G(X), \]
where
\(R \in \mathbb {R}^{N \times N}\) is a trainable and/or non-trainable matrix. In Tay et al. [
73], the authors show that Random Synthesizers achieve competitive performance.
Factorized Variants. The Dense and Random Synthesizers also come with factorized variants that consider a low-rank structure of the attention matrix. For the factorized Random Synthesizer, the output can be written as:
\[ Y = \text{Softmax}(R_1 R_2^{\top })\, G(X), \]
where
\(R_{1},R_{2} \in \mathbb {R}^{N \times k}\). On the other hand, the Dense Synthesizer can be factorized as follows:
\[ A = H_B(F_B(X)) * H_C(F_C(X)), \]
where
\(F_B(.)\) projects onto
\(b\) dimensions and
\(F_C(.)\) projects
\(X_i\) onto
\(c\) dimensions with
\(c \times b=N\).
\(H_B,H_C\) are tile and repeat functions, respectively.
Parameter and Memory Complexity. For Random Synthesizers that adopt a non-trainable
\(R\), there is no need to store
\(N^2\) activations at this layer. For the trainable Random Synthesizer, the memory complexity and parameter complexity remains as
\(N^2\). However, there is no need to compute
\(N^2\) dot products, reducing the computational costs significantly. The Factorized Random Synthesizers reduce the parameter costs to
\(2(N \times k)\).
3.2.16 Transformer-XL.
The Transformer-XL model [
16] relies on segment-based recurrence. Segment-based recurrence can be considered an orthogonal approach to the other techniques discussed since it does not explicitly sparsify the dense self-attention matrix. Instead, it connects adjacent blocks with a recurrent mechanism.
Segment Recurrence. The recurrent mechanism in Transformer-XL is described as:
\[ \begin{aligned} \tilde{\mathbf {h}}^{n-1}_{\tau +1} &= \left[\text{SG}(\mathbf {h}^{n-1}_{\tau }) \odot \mathbf {h}^{n-1}_{\tau +1}\right],\\ \mathbf {q}^{n}_{\tau +1},\, \mathbf {k}^{n}_{\tau +1},\, \mathbf {v}^{n}_{\tau +1} &= \mathbf {h}^{n-1}_{\tau +1}\mathbf {W}_q^{\top },\; \tilde{\mathbf {h}}^{n-1}_{\tau +1}\mathbf {W}_k^{\top },\; \tilde{\mathbf {h}}^{n-1}_{\tau +1}\mathbf {W}_v^{\top },\\ \mathbf {h}^{n}_{\tau +1} &= \text{Transformer-Layer}\left(\mathbf {q}^{n}_{\tau +1}, \mathbf {k}^{n}_{\tau +1}, \mathbf {v}^{n}_{\tau +1}\right), \end{aligned} \]
where SG() is the stop gradient function,
\(\odot\) is the concatenation of two sequences along the length dimension. Notably, the keys and values are conditioned on the extended context \(\tilde{\mathbf {h}}^{n-1}_{\tau +1}\), which includes the cached previous segment, instead of on \(\mathbf {h}^{n-1}_{\tau +1}\) alone.
Relative Positional Encodings. Transformer-XL introduces novel relative position encodings. In this scheme, absolute positional encodings are not added to the content embeddings. Instead, they are only considered while computing attention weights, where they can be replaced with relative position encodings. Since the relative position encodings are not directly relevant to the efficiency of the model, we refer interested readers to Dai et al. [
16] for more details.
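A simplified numpy sketch of segment-level recurrence is given below (single head, no relative position terms; the cached segment is treated as a constant, which stands in for the stop-gradient in the actual model):

```python
# Minimal numpy sketch of segment-level recurrence: the hidden states of the
# previous segment are concatenated to the current segment along the length
# dimension before computing keys and values, while queries come only from the
# current segment.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def segment_attention(h_curr, h_prev, Wq, Wk, Wv):
    """h_curr, h_prev: (L, d) hidden states of the current / cached segment."""
    h_tilde = np.concatenate([h_prev, h_curr], axis=0)   # cache treated as constant
    Q, K, V = h_curr @ Wq, h_tilde @ Wk, h_tilde @ Wv
    L, M = h_curr.shape[0], h_tilde.shape[0]
    causal = np.tril(np.ones((L, M), dtype=bool), k=M - L)  # full access to cache
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = np.where(causal, scores, -1e9)
    return softmax(scores) @ V

rng = np.random.default_rng(0)
d, L = 16, 8
Wq, Wk, Wv = rng.normal(size=(3, d, d)) / np.sqrt(d)
h_prev, h_curr = rng.normal(size=(2, L, d))
print(segment_attention(h_curr, h_prev, Wq, Wk, Wv).shape)  # (8, 16)
```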
3.2.17 Compressive Transformers.
Compressive Transformers [
59] are a natural extension of the Transformer-XL model. The key idea behind the Compressive Transformer is to maintain a fine-grained memory of past segment activations. This is unlike Transformer-XL, which discards past activations as it moves across segments.
Model Memory. The Compressive Transformer is characterized by a dual model memory system—a primary model memory and a secondary compressed model memory. It maintains a model memory with
\(n_m\) memory slots and
\(n_{cm}\) compressive memory slots. Whenever the model accepts a new input segment, the oldest
\(n_s\) activations in the primary model memory are moved to the compressed model memory where a compression function is applied.
Compression. These memories are compressed with a variety of compression functions such as (1) mean/max pooling, (2) 1D convolutions, (3) dilated convolutions, and (4) most used (e.g., sorted by usage of attention).
Memory Reconstruction. In order to better retain memories over long sequences, the Compressive Transformer implements an auto-encoding loss that learns to reconstruct the original memory from its compressed version, i.e., \(L^{ae}=|| \text{old}\_\text{mem} - g(\text{new}\_\text{cm}^{(i)})||\), where \(g(.) : \mathbb {R}^{\frac{n_s}{c} \times d} \rightarrow \mathbb {R}^{n_s \times d}\) is a parameterized function. A second variant, attention reconstruction, is a lossy reconstruction that attempts to reconstruct the attention over the model memory instead of losslessly reconstructing the model memory itself.
3.2.18 Sparse Models.
In this section, we describe the family of Sparse models. Sparse models typically achieve a high parameter-to-FLOP ratio by sparsely activating a subset of parameters or activations. It is good to note that while most of the works within the scope of this survey deal with efficient attention, the scope of sparse models goes beyond the attention module; sparsity is generally applied more frequently to the feed-forward layers [
23,
45]. In this section, we discuss the prime variant for Sparse models, i.e., the Mixture-of-Experts–based Sparse models which includes models such as GShard [
45], Switch Transformer [
23], and GLaM [
21].
Mixture-of-Experts. The key idea behind MoE is to route token
\(x_{i}\) to a set of selected experts determined by a routing function. The routing function typically computes a linear combination over experts using the softmax function and can be interpreted as a form of gating mechanism. The top-\(k\) gate values are then selected for each token \(x_{i}\), and the final output of that layer is determined by a linear combination of the selected top-\(k\) experts. This MoE layer remains foundational to many MoE architectures, which differ mainly in implementation details. For example, Switch uses a top-1 routing strategy while GShard uses group-level top-2 gating.
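The following illustrative numpy sketch of a token-level top-\(k\) MoE layer (with simple linear experts and without the capacity limits and auxiliary load-balancing losses used by GShard and Switch) shows the route-and-combine computation:

```python
# Minimal numpy sketch of a token-level MoE layer: a softmax router scores the
# experts for each token, the top-k experts are selected, and the layer output is
# the gate-weighted combination of the selected experts' outputs.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moe_layer(X, Wr, experts, k=2):
    """X: (n, d); Wr: (d, E) router weights; experts: list of (d, d) matrices."""
    gates = softmax(X @ Wr)                         # (n, E) routing probabilities
    out = np.zeros_like(X)
    for t in range(X.shape[0]):
        top = np.argsort(gates[t])[-k:]             # indices of the top-k experts
        weights = gates[t, top] / gates[t, top].sum()
        for e, w in zip(top, weights):
            out[t] += w * (X[t] @ experts[e])       # only k experts are evaluated
    return out

rng = np.random.default_rng(0)
n, d, E = 16, 8, 4
X = rng.normal(size=(n, d))
Wr = rng.normal(size=(d, E))
experts = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(E)]
print(moe_layer(X, Wr, experts, k=2).shape)         # (16, 8)
```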