`Fast Autoregressive Transformers with Linear Attention
`
Angelos Katharopoulos 1 2   Apoorv Vyas 1 2   Nikolaos Pappas 3   François Fleuret 2 4 *

1 Idiap Research Institute, Switzerland  2 EPFL, Switzerland  3 University of Washington, Seattle, USA  4 University of Geneva, Switzerland.  * Work done at Idiap. Correspondence to: Angelos Katharopoulos <firstname.lastname@idiap.ch>.

Proceedings of the 37th International Conference on Machine Learning, Online, PMLR 119, 2020. Copyright 2020 by the author(s).

Abstract

Transformers achieve remarkable performance in several tasks but, due to their quadratic complexity with respect to the input's length, they are prohibitively slow for very long sequences. To address this limitation, we express the self-attention as a linear dot-product of kernel feature maps and make use of the associativity property of matrix products to reduce the complexity from O(N²) to O(N), where N is the sequence length. We show that this formulation permits an iterative implementation that dramatically accelerates autoregressive transformers and reveals their relationship to recurrent neural networks. Our linear transformers achieve similar performance to vanilla transformers and they are up to 4000× faster on autoregressive prediction of very long sequences.

1. Introduction

Transformer models were originally introduced by Vaswani et al. (2017) in the context of neural machine translation (Sutskever et al., 2014; Bahdanau et al., 2015) and have demonstrated impressive results on a variety of tasks dealing with natural language (Devlin et al., 2019), audio (Sperber et al., 2018), and images (Parmar et al., 2019). Apart from tasks with ample supervision, transformers are also effective in transferring knowledge to tasks with limited or no supervision when they are pretrained with autoregressive (Radford et al., 2018; 2019) or masked language modeling objectives (Devlin et al., 2019; Yang et al., 2019; Song et al., 2019; Liu et al., 2020).

However, these benefits often come with a very high computational and memory cost. The bottleneck is mainly caused by the global receptive field of self-attention, which processes contexts of N inputs with a quadratic memory and time complexity O(N²). As a result, in practice transformers are slow to train and their context is limited. This disrupts temporal coherence and hinders the capturing of long-term dependencies. Dai et al. (2019) addressed the latter by attending to memories from previous contexts, albeit at the expense of computational efficiency.

Lately, researchers shifted their attention to approaches that increase the context length without sacrificing efficiency. Towards this end, Child et al. (2019) introduced sparse factorizations of the attention matrix to reduce the self-attention complexity to O(N√N). Kitaev et al. (2020) further reduced the complexity to O(N log N) using locality-sensitive hashing. This made scaling to long sequences possible. Even though the aforementioned models can be efficiently trained on large sequences, they do not speed up autoregressive inference.

In this paper, we introduce the linear transformer model that significantly reduces the memory footprint and scales linearly with respect to the context length. We achieve this by using a kernel-based formulation of self-attention and the associative property of matrix products to calculate the self-attention weights (§ 3.2). Using our linear formulation, we also express causal masking with linear complexity and constant memory (§ 3.3). This reveals the relation between transformers and RNNs, which enables us to perform autoregressive inference orders of magnitude faster (§ 3.4). Our evaluation on image generation and automatic speech recognition demonstrates that linear transformer can reach the performance levels of transformer, while being up to three orders of magnitude faster during inference.

2. Related Work

In this section, we provide an overview of the most relevant works that seek to address the large memory and computational requirements of transformers. Furthermore, we discuss methods that theoretically analyze the core component of the transformer model, namely self-attention. Finally, we present another line of work that seeks to alleviate the softmax bottleneck in the attention computation.
`
`
`
`
`2.1. Efficient Transformers
`
`
`Existing works seek to improve memory efficiency in
`transformers through weight pruning (Michel et al., 2019),
`weight factorization (Lan et al., 2020), weight quantization
`(Zafrir et al., 2019) or knowledge distillation. Clark et al.
`(2020) proposed a new pretraining objective called replaced
`token detection that is more sample efficient and reduces the
`overall computation. Lample et al. (2019) used product-key
`attention to increase the capacity of any layer with negligible
`computational overhead.
`Reducing the memory or computational requirements with
`these methods leads to training or inference time speedups,
`but, fundamentally, the time complexity is still quadratic
`with respect to the sequence length which hinders scaling
`to long sequences. In contrast, we show that our method
`reduces both memory and time complexity of transformers
`both theoretically (§ 3.2) and empirically (§ 4.1).
`Another line of research aims at increasing the “context” of
`self-attention in transformers. Context refers to the maxi-
`mum part of the sequence that is used for computing self-
`attention. Dai et al. (2019) introduced Transformer-XL
`which achieves state-of-the-art in language modeling by
`learning dependencies beyond a fixed length context without
`disrupting the temporal coherence. However, maintaining
`previous contexts in memory introduces significant addi-
`tional computational cost.
`In contrast, Sukhbaatar et al.
`(2019) extended the context length significantly by learning
`the optimal attention span per attention head, while main-
`taining control over the memory footprint and computation
`time. Note that both approaches have the same asymptotic
`complexity as the vanilla model. In contrast, we improve the
`asymptotic complexity of the self-attention, which allows
`us to use significantly larger context.
`More related to our model are the works of Child et al.
`(2019) and Kitaev et al. (2020). The former (Child et al.,
2019) introduced sparse factorizations of the attention matrix, reducing the overall complexity from quadratic to O(N√N) for generative modeling of long sequences.
`More recently, Kitaev et al. (2020) proposed Reformer. This
`method further reduces complexity to O (N log N ) by us-
`ing locality-sensitive hashing (LSH) to perform fewer dot
`products. Note that in order to be able to use LSH, Reformer
`constrains the keys, for the attention, to be identical to the
`queries. As a result this method cannot be used for decoding
`tasks where the keys need to be different from the queries.
`In comparison, linear transformers impose no constraints
`on the queries and keys and scale linearly with respect to the
`sequence length. Furthermore, they can be used to perform
`inference in autoregressive tasks three orders of magnitude
`faster, achieving comparable performance in terms of vali-
`dation perplexity.
`
2.2. Understanding Self-Attention
`
`There have been few efforts to better understand self-
`attention from a theoretical perspective. Tsai et al. (2019)
`proposed a kernel-based formulation of attention in trans-
`formers which considers attention as applying a kernel
`smoother over the inputs with the kernel scores being the
`similarity between inputs. This formulation provides a bet-
`ter way to understand attention components and integrate
`the positional embedding. In contrast, we use the kernel
`formulation to speed up the calculation of self-attention and
`lower its computational complexity. Also, we observe that
`if a kernel with positive similarity scores is applied on the
`queries and keys, linear attention converges normally.
`More recently, Cordonnier et al. (2020) provided theoret-
`ical proofs and empirical evidence that a multi-head self-
attention with a sufficient number of heads can express any
`convolutional layer. Here, we instead show that a self-
`attention layer trained with an autoregressive objective can
`be seen as a recurrent neural network and this observation
`can be used to significantly speed up inference time of au-
`toregressive transformer models.
`
`2.3. Linearized softmax
`
`For many years, softmax has been the bottleneck for train-
`ing classification models with a large number of categories
`(Goodman, 2001; Morin & Bengio, 2005; Mnih & Hinton,
`2009). Recent works (Blanc & Rendle, 2017; Rawat et al.,
`2019), have approximated softmax with a linear dot product
`of feature maps to speed up the training through sampling.
Inspired by these works, we linearize the softmax atten-
`tion in transformers. Concurrently with this work, Shen
`et al. (2020) explored the use of linearized attention for the
task of object detection in images. In comparison, we not only linearize the attention computation but also develop
`an autoregressive transformer model with linear complex-
`ity and constant memory for both inference and training.
`Moreover, we show that through the lens of kernels, every
`transformer can be seen as a recurrent neural network.
`
`3. Linear Transformers
`In this section, we formalize our proposed linear trans-
former. We show that changing the attention from the traditional softmax attention to a feature-map-based dot product attention results in better time and memory complexity, as well as in a causal model that can perform sequence generation in linear time, similar to a recurrent neural network.
Initially, in § 3.1, we introduce a formulation for the transformer architecture of Vaswani et al. (2017).
`Subsequently, in § 3.2 and § 3.3 we present our proposed
`linear transformer and finally, in § 3.4 we rewrite the trans-
`former as a recurrent neural network.
`
`2
`
`
`
`
`3.1. Transformers
Let x ∈ R^{N×F} denote a sequence of N feature vectors of dimensions F. A transformer is a function T : R^{N×F} → R^{N×F} defined by the composition of L transformer layers
`
`(1)
`Tl(x) = fl(Al(x) + x).
`The function fl(·) transforms each feature independently of
`the others and is usually implemented with a small two-layer
`feedforward network. Al(·) is the self attention function and
`is the only part of the transformer that acts across sequences.
`The self attention function Al(·) computes, for every posi-
`tion, a weighted average of the feature representations of
`all other positions with a weight proportional to a similar-
`ity score between the representations. Formally, the input
sequence x is projected by three matrices W_Q ∈ R^{F×D}, W_K ∈ R^{F×D} and W_V ∈ R^{F×M} to corresponding representations Q, K and V. The output for all positions, A_l(x) = V', is computed as follows,
`
\[ Q = xW_Q, \quad K = xW_K, \quad V = xW_V, \]
\[ A_l(x) = V' = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{D}}\right) V. \tag{2} \]
`
Note that in the previous equation, the softmax function is applied rowwise to QK^T. Following common terminology, the Q, K and V are referred to as the “queries”, “keys” and “values” respectively.
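For concreteness, the following is a minimal sketch of equation 2 in PyTorch; the tensor names and shapes are illustrative, and batching as well as multiple heads are omitted.

import torch

def softmax_attention(x, W_Q, W_K, W_V):
    # x: (N, F); W_Q, W_K: (F, D); W_V: (F, M).
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    D = Q.shape[-1]
    # Rowwise softmax over the N x N matrix of similarity scores.
    A = torch.softmax(Q @ K.T / D ** 0.5, dim=-1)
    return A @ V  # V', shape (N, M)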
`Equation 2 implements a specific form of self-attention
`called softmax attention where the similarity score is the
`exponential of the dot product between a query and a key.
`Given that subscripting a matrix with i returns the i-th row
`as a vector, we can write a generalized attention equation
`for any similarity function as follows,
`
\[ V'_i = \frac{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{N} \mathrm{sim}(Q_i, K_j)}. \tag{3} \]

Equation 3 is equivalent to equation 2 if we substitute the similarity function with sim(q, k) = exp(q^T k / √D).
`
3.2. Linearized Attention

The definition of attention in equation 2 is generic and can be used to define several other attention implementations such as polynomial attention or RBF kernel attention (Tsai et al., 2019). Note that the only constraint we need to impose on sim(·), in order for equation 3 to define an attention function, is to be non-negative. This includes all kernels k(x, y) : R^{2×F} → R₊.

Given such a kernel with a feature representation φ(x), we can rewrite equation 2 as follows,

\[ V'_i = \frac{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j)\, V_j}{\sum_{j=1}^{N} \phi(Q_i)^T \phi(K_j)}, \tag{4} \]

and then further simplify it by making use of the associative property of matrix multiplication to

\[ V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j)\, V_j^T}{\phi(Q_i)^T \sum_{j=1}^{N} \phi(K_j)}. \tag{5} \]

The above equation is simpler to follow when the numerator is written in vectorized form as follows,

\[ \left( \phi(Q)\, \phi(K)^T \right) V = \phi(Q) \left( \phi(K)^T V \right). \tag{6} \]

Note that the feature map φ(·) is applied rowwise to the matrices Q and K.

From equation 2, it is evident that the computational cost of softmax attention scales with O(N²), where N represents the sequence length. The same is true for the memory requirements because the full attention matrix must be stored to compute the gradients with respect to the queries, keys and values. In contrast, our proposed linear transformer from equation 5 has time and memory complexity O(N) because we can compute \(\sum_{j=1}^{N} \phi(K_j) V_j^T\) and \(\sum_{j=1}^{N} \phi(K_j)\) once and reuse them for every query.
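The following sketch contrasts the two orderings of the same product for an arbitrary rowwise feature map; the variable names are illustrative, and a small epsilon can be added to the denominator for numerical stability.

import torch

def linear_attention(Q, K, V, feature_map):
    # Q, K: (N, D); V: (N, M); feature_map implements phi, applied rowwise.
    phi_Q, phi_K = feature_map(Q), feature_map(K)        # (N, C)
    # Associativity: phi_Q @ (phi_K^T V) costs O(N C M), whereas
    # (phi_Q @ phi_K^T) @ V would cost O(N^2) in time and memory.
    KV = phi_K.T @ V                          # (C, M), computed once and reused for every query
    Z = phi_K.sum(dim=0)                      # (C,), denominator terms of equation 5
    numerator = phi_Q @ KV                    # (N, M)
    denominator = (phi_Q @ Z).unsqueeze(-1)   # (N, 1)
    return numerator / denominator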
`
`3.2.1. FEATURE MAPS AND COMPUTATIONAL COST
`
For softmax attention, the total cost in terms of multiplications and additions scales as O(N² max(D, M)), where
`
`D is the dimensionality of the queries and keys and M is
`the dimensionality of the values. On the contrary, for linear
`attention, we first compute the feature maps of dimension-
`ality C. Subsequently, computing the new values requires
`O (N CM ) additions and multiplications.
`The previous analysis does not take into account the choice
`of kernel and feature function. Note that the feature func-
`tion that corresponds to the exponential kernel is infinite
`dimensional, which makes the linearization of exact soft-
`max attention infeasible. On the other hand, the polynomial
`kernel, for example, has an exact finite dimensional feature
`map and has been shown to work equally well with the expo-
`nential or RBF kernel (Tsai et al., 2019). The computational
cost for a linearized polynomial transformer of degree 2 is O(N D² M). This makes the computational complexity favorable when N > D². Note that this is true in practice
`since we want to be able to process sequences with tens of
`thousands of elements.
For our experiments, which deal with smaller sequences, we
`employ a feature map that results in a positive similarity
`
`3
`
`
`
`Transformers are RNNs
`
function as defined below,

\[ \phi(x) = \mathrm{elu}(x) + 1, \tag{7} \]
`where elu(·) denotes the exponential linear unit (Clevert
`et al., 2015) activation function. We prefer elu(·) over relu(·)
`to avoid setting the gradients to 0 when x is negative. This
`feature map results in an attention function that requires
`O (N DM ) multiplications and additions. In our experi-
`mental section, we show that the feature map of equation 7
performs on par with the full transformer, while significantly
`reducing the computational and memory requirements.
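A sketch of the feature map of equation 7 in PyTorch is given below; it can be plugged into the linear attention sketch of § 3.2 as feature_map.

import torch

def elu_feature_map(x):
    # phi(x) = elu(x) + 1: elementwise and strictly positive, so the
    # similarity phi(Q_i)^T phi(K_j) is always positive.
    return torch.nn.functional.elu(x) + 1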
`
`3.3. Causal Masking
`
`The transformer architecture can be used to efficiently train
`autoregressive models by masking the attention computa-
tion such that the i-th position can only be influenced by a position j with j ≤ i, namely a position cannot
`be influenced by the subsequent positions. Formally, this
`causal masking changes equation 3 as follows,
`
`.
`
`(8)
`
`j=1 sim (Qi, Kj) Vj
`j=1 sim (Qi, Kj)
`Following the reasoning of § 3.2, we linearize the masked
`attention as described below,
`
\[ V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^{i} \phi(K_j)\, V_j^T}{\phi(Q_i)^T \sum_{j=1}^{i} \phi(K_j)}. \tag{9} \]

By introducing S_i and Z_i as follows,

\[ S_i = \sum_{j=1}^{i} \phi(K_j)\, V_j^T, \tag{10} \]

\[ Z_i = \sum_{j=1}^{i} \phi(K_j), \tag{11} \]

we can simplify equation 9 to

\[ V'_i = \frac{\phi(Q_i)^T S_i}{\phi(Q_i)^T Z_i}. \tag{12} \]
`
Note that S_i and Z_i can be computed from S_{i-1} and Z_{i-1} in constant time, hence making the computational complexity of linear transformers with causal masking linear with respect to the sequence length.
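A naive sketch of equations 9-12 using cumulative sums is shown below (names are illustrative). Note that it materializes every S_i and therefore uses O(N C M) memory; this is exactly the issue addressed by the gradient computation of § 3.3.1.

import torch

def causal_linear_attention(phi_Q, phi_K, V, eps=1e-6):
    # phi_Q, phi_K: (N, C); V: (N, M).
    S = torch.cumsum(torch.einsum('nc,nm->ncm', phi_K, V), dim=0)  # S_i, equation 10
    Z = torch.cumsum(phi_K, dim=0)                                 # Z_i, equation 11
    num = torch.einsum('nc,ncm->nm', phi_Q, S)                     # phi(Q_i)^T S_i
    den = torch.einsum('nc,nc->n', phi_Q, Z).unsqueeze(-1) + eps   # phi(Q_i)^T Z_i
    return num / den                                               # equation 12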
`
`3.3.1. GRADIENT COMPUTATION
`
`A naive implementation of equation 12, in any deep learning
`framework, requires storing all intermediate values Si in
`order to compute the gradients. This increases the mem-
`ory consumption by max (D, M ) times; thus hindering the
`
`applicability of causal linear attention to longer sequences
`or deeper models. To address this, we derive the gradients
`of the numerator in equation 9 as cumulative sums. This
`allows us to compute both the forward and backward pass
`of causal linear attention in linear time and constant mem-
`ory. A detailed derivation is provided in the supplementary
`material.
Given the numerator V̄_i and the gradient of a scalar loss function with respect to the numerator, ∇_{V̄_i}L, we derive ∇_{φ(Q_i)}L, ∇_{φ(K_i)}L and ∇_{V_i}L as follows,

\[ \nabla_{\phi(Q_i)} \mathcal{L} = \nabla_{\bar{V}_i}\mathcal{L} \left( \sum_{j=1}^{i} \phi(K_j)\, V_j^T \right)^{T}, \tag{13} \]

\[ \nabla_{\phi(K_i)} \mathcal{L} = \left( \sum_{j=i}^{N} \phi(Q_j) \left( \nabla_{\bar{V}_j}\mathcal{L} \right)^{T} \right) V_i, \tag{14} \]

\[ \nabla_{V_i} \mathcal{L} = \left( \sum_{j=i}^{N} \phi(Q_j) \left( \nabla_{\bar{V}_j}\mathcal{L} \right)^{T} \right)^{T} \phi(K_i). \tag{15} \]
`
`The cumulative sum terms in equations 9, 13-15 are com-
`puted in linear time and require constant memory with re-
`spect to the sequence length. This results in an algorithm
`with computational complexity O (N CM ) and memory
`O (N max (C, M )) for a given feature map of C dimen-
`sions. A pseudocode implementation of the forward and
`backward pass of the numerator is given in algorithm 1.
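As a brief check of equation 13 (a sketch; the detailed derivation, including equations 14 and 15, is in the supplementary material): since the numerator V̄_i = φ(Q_i)^T S_i is linear in φ(Q_i) and S_i does not depend on Q_i, the chain rule gives

\[ \nabla_{\phi(Q_i)}\mathcal{L} \;=\; \nabla_{\bar{V}_i}\mathcal{L}\; S_i^{T} \;=\; \nabla_{\bar{V}_i}\mathcal{L} \left( \sum_{j=1}^{i} \phi(K_j)\, V_j^{T} \right)^{T}. \]

Equations 14 and 15 follow in the same way because every V̄_j with j ≥ i depends linearly on φ(K_i) and V_i through S_j, which produces the reversed cumulative sums.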
`
`3.3.2. TRAINING AND INFERENCE
`
`When training an autoregressive transformer model the full
`ground truth sequence is available. This makes layerwise
`parallelism possible both for fl(·) of equation 1 and the
`attention computation. As a result, transformers are more
`efficient to train than recurrent neural networks. On the
`other hand, during inference the output for timestep i is the
`input for timestep i + 1. This makes autoregressive models
`impossible to parallelize. Moreover, the cost per timestep
`for transformers is not constant; instead, it scales with the
`square of the current sequence length because attention must
`be computed for all previous timesteps.
`Our proposed linear transformer model combines the best
`of both worlds. When it comes to training, the computations
`can be parallelized and take full advantage of GPUs or other
`accelerators. When it comes to inference, the cost per time
`and memory for one prediction is constant for our model.
`
This means we can simply store the φ(K_j) V_j^T matrix as an
`internal state and update it at every time step like a recurrent
`neural network. This results in inference thousands of
`times faster than other transformer models.
`
`4
`
`
`
`Transformers are RNNs
`
`3.4. Transformers are RNNs
`
In the literature, transformer models are considered to be a fun-
`damentally different approach to recurrent neural networks.
`However, from the causal masking formulation in § 3.3 and
`the discussion in the previous section, it becomes evident
`that any transformer layer with causal masking can be writ-
`ten as a model that, given an input, modifies an internal state
`and then predicts an output, namely a Recurrent Neural
`Network (RNN). Note that, in contrast to Universal Trans-
`formers (Dehghani et al., 2018), we consider the recurrence
`with respect to time and not depth.
`In the following equations, we formalize the transformer
`layer of equation 1 as a recurrent neural network. The
`resulting RNN has two hidden states, namely the attention
`memory s and the normalizer memory z. We use subscripts
`to denote the timestep in the recurrence.
`
`(16)
`(17)
`(18)
`(19)
`
`(20)
`
`(cid:32)
`
`(cid:33)
`
`yi = fl
`
`+ xi
`
`.
`
`s0 = 0,
`z0 = 0,
`si = si−1 + φ (xiWK ) (xiWV )T ,
`zi = zi−1 + φ (xiWK ) ,
`φ (xiWQ)T si
`φ (xiWQ)T zi
`In the above equations, xi denotes the i-th input and yi the
`i-th output for a specific transformer layer. Note that our
`formulation does not impose any constraint on the feature
`function and it can be used for representing any transformer
`model, in theory even the ones using softmax attention. This
`formulation is a first step towards better understanding the
`relationship between transformers and popular recurrent net-
`works (Hochreiter & Schmidhuber, 1997) and the processes
`used for storing and retrieving information.
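A minimal PyTorch sketch of the recurrence in equations 16-20 for a single layer is given below; the parameter names and the feature map are illustrative, f_l is treated as a black box, and the value dimension is assumed to match the model dimension so that the residual addition is well defined.

import torch

def rnn_style_inference(xs, W_Q, W_K, W_V, f_l, feature_map):
    # xs: iterable of per-timestep inputs x_i of shape (F,).
    # With the elementwise feature map of equation 7, C equals the key dimension.
    C = W_K.shape[1]
    M = W_V.shape[1]
    s = torch.zeros(C, M)   # attention memory, equation 16
    z = torch.zeros(C)      # normalizer memory, equation 17
    ys = []
    for x in xs:
        k = feature_map(x @ W_K)            # phi(x_i W_K)
        s = s + torch.outer(k, x @ W_V)     # equation 18
        z = z + k                           # equation 19
        q = feature_map(x @ W_Q)            # phi(x_i W_Q)
        attn = (q @ s) / (q @ z + 1e-6)     # attention term of equation 20
        ys.append(f_l(attn + x))            # y_i = f_l(. + x_i), equation 20
    return ys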
`
`4. Experiments
`In this section, we analyze experimentally the performance
`of the proposed linear transformer. Initially, in § 4.1, we
`evaluate the linearized attention in terms of computational
`cost, memory consumption and convergence on synthetic
`data. To further showcase the effectiveness of linear trans-
`formers, we evaluate our model on two real-world appli-
`cations, image generation in § 4.2 and automatic speech
`recognition in § 4.3. We show that our model achieves
`competitive performance with respect to the state-of-the-art
`transformer architectures, while requiring significantly less
`GPU memory and computation.
`Throughout our experiments, we compare our model with
`two baselines, the full transformer with softmax attention
`and the Reformer (Kitaev et al., 2020), the latter being a
`state-of-the-art accelerated transformer architecture. For the
`Reformer, we use a PyTorch reimplementation of the pub-
`
Algorithm 1 Linear transformers with causal masking
function forward(φ(Q), φ(K), V):
    V̄ ← 0, S ← 0
    for i = 1, . . . , N do
        S ← S + φ(K_i) V_i^T        {equation 10}
        V̄_i ← φ(Q_i) S
    end for
    return V̄
end function

function backward(φ(Q), φ(K), V, G):
    /* G is the gradient of the loss with respect to the output of forward */
    S ← 0, ∇_{φ(Q)}L ← 0
    for i = 1, . . . , N do
        S ← S + φ(K_i) V_i^T
        ∇_{φ(Q_i)}L ← G_i S^T       {equation 13}
    end for
    S ← 0, ∇_{φ(K)}L ← 0, ∇_V L ← 0
    for i = N, . . . , 1 do
        S ← S + φ(Q_i) G_i^T
        ∇_{V_i}L ← S^T φ(K_i)       {equation 15}
        ∇_{φ(K_i)}L ← S V_i         {equation 14}
    end for
    return ∇_{φ(Q)}L, ∇_{φ(K)}L, ∇_V L
end function
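For reference, a plain-PyTorch transcription of the two passes of algorithm 1 is sketched below; the released implementation replaces these Python loops with a CUDA kernel, and the names used here are illustrative.

import torch

def numerator_forward(phi_Q, phi_K, V):
    # phi_Q, phi_K: (N, C); V: (N, M). Returns Vbar with Vbar_i = phi(Q_i)^T S_i.
    N, C = phi_K.shape
    M = V.shape[1]
    S = torch.zeros(C, M)
    Vbar = torch.zeros(N, M)
    for i in range(N):
        S = S + torch.outer(phi_K[i], V[i])   # equation 10
        Vbar[i] = phi_Q[i] @ S
    return Vbar

def numerator_backward(phi_Q, phi_K, V, G):
    # G: (N, M), gradient of the loss with respect to the output of forward.
    N, C = phi_K.shape
    M = V.shape[1]
    S = torch.zeros(C, M)
    grad_phi_Q = torch.zeros(N, C)
    for i in range(N):
        S = S + torch.outer(phi_K[i], V[i])
        grad_phi_Q[i] = S @ G[i]              # equation 13
    S = torch.zeros(C, M)
    grad_phi_K = torch.zeros(N, C)
    grad_V = torch.zeros(N, M)
    for i in range(N - 1, -1, -1):
        S = S + torch.outer(phi_Q[i], G[i])
        grad_V[i] = S.T @ phi_K[i]            # equation 15
        grad_phi_K[i] = S @ V[i]              # equation 14
    return grad_phi_Q, grad_phi_K, grad_V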
`
`lished code and for the full transformer we use the default
`PyTorch implementation. Note that for Reformer, we do
not use the reversible layers; however, this does not affect
`the results as we only measure the memory consumption
`with respect to the self attention layer. In all experiments,
`we use softmax (Vaswani et al., 2017) to refer to the stan-
`dard transformer architecture, linear for our proposed linear
`transformers and lsh-X for Reformer (Kitaev et al., 2020),
`where X denotes the hashing rounds.
`For training the linear transformers, we use the feature map
`of equation 7. Our PyTorch (Paszke et al., 2019) code with
`documentation and examples can be found at https://
`linear-transformers.com/. The constant memory
`gradient computation of equations 13-15 is implemented in
`approximately 200 lines of CUDA code.
`
`4.1. Synthetic Tasks
`
`4.1.1. CONVERGENCE ANALYSIS
`
`To examine the convergence properties of linear transform-
ers, we train on an artificial copy task with causal masking.
`Namely, the transformers have to copy a series of symbols
`similar to the sequence duplication task of Kitaev et al.
`(2020). We use a sequence of maximum length 128 with 10
`
`5
`
`
`
`Transformers are RNNs
`
`Figure 1: Comparison of the computational requirements for a forward/backward pass for Reformer (lsh-X), softmax
`attention and linear attention. Linear and Reformer models scale linearly with the sequence length unlike softmax which
`scales with the square of the sequence length both in memory and time. Full details of the experiment can be found in § 4.1.
`
`Every method is evaluated up to the maximum sequence
`length that fits the GPU memory. For this benchmark we
`use an NVidia GTX 1080 Ti with 11GB of memory. This
`results in a maximum sequence length of 4,096 elements
`for softmax and 16,384 for lsh-4 and lsh-8. As expected,
`softmax scales quadratically with respect to the sequence
`length. Our method is faster and requires less memory than
`the baselines for every configuration, as seen in figure 1.
`We observe that both Reformer and linear attention scale
`linearly with the sequence length. Note that although the
`asymptotic complexity for Reformer is O (N log N ), log N
`is small enough and does not affect the computation time.
`
`4.2. Image Generation
`
`Transformers have shown great results on the task of condi-
`tional or unconditional autoregressive generation (Radford
et al., 2019; Child et al., 2019); however, sampling from
`transformers is slow due to the task being inherently se-
`quential and the memory scaling with the square of the
`sequence length. In this section, we train causally masked
`transformers to predict images pixel by pixel. Our achieved
`performance in terms of bits per dimension is on par with
`softmax attention while being able to generate images more
`than 1,000 times faster and with constant memory per
`image from the first to the last pixel. We refer the reader
`to our supplementary for comparisons in terms of training
`evolution, quality of generated images and time to generate
`a single image. In addition, we also compare with a faster
`softmax transformer that caches the keys and values during
`inference, in contrast to the PyTorch implementation.
`
`4.2.1. MNIST
`
`First, we evaluate our model on image generation with au-
`toregressive transformers on the widely used MNIST dataset
`(LeCun et al., 2010). The architecture for this experiment
`comprises 8 attention layers with 8 attention heads each. We
`
`Figure 2: Convergence comparison of softmax, linear and
`reformer attention on a sequence duplication task. linear
`converges stably and reaches the same final performance as
`softmax. The details of the experiment are in § 4.1.
`
`different symbols separated by a dedicated separator symbol.
`For all three methods, we train a 4 layer transformer with
`8 attention heads using a batch size of 64 and the RAdam
optimizer (Liu et al., 2019) with a learning rate of 10⁻³ which is reduced to 10⁻⁴ after 3000 updates. Figure 2 de-
`picts the loss with respect to the number of gradient steps.
`We observe that linear converges smoothly and reaches a
`lower loss than lsh due to the lack of noise introduced by
`hashing. In particular, it reaches the same loss as softmax.
`
`4.1.2. MEMORY AND COMPUTATIONAL REQUIREMENTS
`
`In this subsection, we compare transformers with respect
`to their computational and memory requirements. We com-
`pute the attention and the gradients for a synthetic input
with varying sequence lengths N ∈ {2⁹, 2¹⁰, . . . , 2¹⁶} and
`measure the peak allocated GPU memory and required time
`for each variation of transformer. We scale the batch size
`inversely with the sequence length and report the time and
`memory per sample in the batch.
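A sketch of how peak memory and time per sample can be measured for one sequence length, assuming a CUDA device; attention_layer and the chosen model dimension are placeholders.

import time
import torch

def benchmark(attention_layer, N, batch_size, d_model=512, device='cuda'):
    x = torch.randn(batch_size, N, d_model, device=device, requires_grad=True)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.time()
    y = attention_layer(x)                 # forward pass
    y.sum().backward()                     # backward pass through a dummy loss
    torch.cuda.synchronize(device)
    time_per_sample = (time.time() - start) / batch_size
    memory_per_sample = torch.cuda.max_memory_allocated(device) / 2**20 / batch_size  # MB
    return time_per_sample, memory_per_sample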
`
`29
`
`210
`
`211
`
`213
`212
`Sequence Length
`
`214
`
`215
`
`216
`
`102
`
`101
`
`100
`
`Time(milliseconds)
`
`linear(ours)
`softmax
`lsh-1
`lsh-4
`lsh-8
`
`29
`
`210
`
`211
`
`213
`212
`Sequence Length
`
`214
`
`215
`
`216
`
`103
`
`102
`
`101
`
`GPUMemory(MB)
`
`linear (ours)
`softmax
`lsh-4
`
`100
`
`10−1
`
`10−2
`
`10−3
`
`10−4
`
`CrossEntropyLoss
`
`0
`
`2000
`
`6000
`4000
`Gradient steps
`
`8000
`
`10000
`
`6
`
`
`
`Transformers are RNNs
`
Table 1 (MNIST):
Method          Bits/dim    Images/sec
Softmax         0.621       0.45    (1×)
LSH-1           0.745       0.68    (1.5×)
LSH-4           0.676       0.27    (0.6×)
Linear (ours)   0.644       142.8   (317×)

Table 2 (CIFAR-10):
Method          Bits/dim    Images/sec
Softmax         3.47        0.004   (1×)
LSH-1           3.39        0.015   (3.75×)
LSH-4           3.51        0.005   (1.25×)
Linear (ours)   3.40        17.85   (4,462×)
`
`Table 1: Comparison of autoregressive image generation of
`MNIST images. Our linear transformers achieve almost the
`same bits/dim as the full softmax attention but more than
`300 times higher throughput in image generation. The full
`details of the experiment are in § 4.2.1.
`
`Table 2: We train autoregressive transformers for 1 week
`on a single GPU to generate CIFAR-10 images. Our linear
`transformer completes 3 times more epochs than softmax,
`which results in better perplexity. Our model generates
`images 4,000× faster than the baselines. The full details of
`the experiment are in § 4.2.2.
`
`set the embedding size to 256 which is 32 dimensions per
`head. Our feed forward dimensions are 4 times larger than
`our embedding size. We model the output with a mixture
`of 10 logistics as introduced by Salimans et al. (2017). We
use the RAdam optimizer with a learning rate of 10⁻⁴ and
`train all models for 250 epochs. For the reformer baseline,
`we use 1 and 4 hashing rounds. Furthermore, as suggested
`in Kitaev et al. (2020), we use 64 buckets and chunks with
approximately 32 elements. In particular, we divide the 783-element input sequence into 27 chunks of 29 elements each. Since the sequence length is relatively small, namely only
`784 pixels, to remove differences due to different batch sizes
`we use a batch size of 10 for all methods.
`Table 1 summarizes the results. We observe that linear
`transformers achieve almost the same performance, in terms
`of final perplexity, as softmax transformers while being
`able to generate images more than 300 times faster. This is
`achieved due to the low memory requirements of our model,
`which is able to simultaneously generate 10,000 MNIST
`images with a single GPU. In particular, the memory is
constant with respect to the sequence length because the only quantities that need to be stored between pixels are the s_i and z_i values described in equations 18 and 19. On
`the other hand, both softmax and Reformer require memory
`that increases with the length of the sequence.
`Image completions and unconditional samples from our
`MNIST model can be seen in figure 3. We observe that
`our linear transformer generates very convincing samples
`with sharp boundaries and no noise. In the case of image
`completion, we also observe that the transformer learns to
`use the same stroke style and width as the original image
`effectively attending over long temporal distances. Note that
`as the achieved perplexity is more or less the same for all
`models, we do not observe qualitative differences between
`the generated samples from different models.
`
`4.2.2. CIFAR-10
`
`The benefits of our linear formulation increase as the se-
`quence length increases. To showcase that, we train 16 layer
`
`transformers to generate CIFAR-10 images (Krizhevsky
`et al., 2009). For each layer we use the same configuration
`as in the previous experiment. For Reformer, we use again
`64 buckets and 83 chunks of 37 elements, which is approx-
`imately 32, as suggested in the paper. Since the sequence
`length is almost 4 times larger than for the previous exper-
`iment, the full transformer can only be used with a batch
`size of 1 in the largest GPU that is available to us, namely
`an NVidia P40 with 24GB of memory. For both the linear
`transformer and reformer, we use a batch size of 4. All
`models are trained for 7 days. We report results in terms of
`bits per dimension and image generation throughput in table
`2. Note that although the main point of this experiment is
`not the final perplexity, it is evident that as the sequence
`length grows, the fast transformer models become increas-
`ingly more efficient per GPU hour, achieving better scores
`than their slower counterparts.
`As the memory and time to generate a single pixel scales
`quadratically with the number of pixels for both Reformer
`and softmax attention, the increase in throughput for our lin-
`ear transformer is even more pronounced. In particular, for
`every image generated by the softmax transformer, our
`method can generate 4,460 images. Image completions
`and unconditional samples from our model can be seen in
`figure 4. We observe that our model generates images with
spatial consistency and can complete images convincingly
`without significantly hindering the recognition of the image
`category. For instance, in figure 4b, all images have success-
`fully completed the dog’s nose (first row) or the windshield
`of the truck (last row).
`
`4.3. Automatic Speech Recognition
`
`To show that our method can also be used for non-
`autoregressive tasks, we evaluate the performance of linear
`transformers in end-to-end automatic speech recognition
`using Connectionist Temporal Classification (CTC) loss
`(Graves et al., 2006). In this setup, we predict a distribu-
`tion over phonemes for each input frame in a non autore-
`
`7
`
`
`
`Unconditional samples
`
`Unconditional samples
`
`Transformers are RNNs
`
`Image completion
`
`Image completion
`
`(a)
`
`(b)
`
`(c)
`
`(a)
`
`(b)
`
`(c)
`
`Figure 3: Unconditional samples and image completions
generated by our method for MNIST. (a) depicts the occluded original images, (b) the completions and (c) the original. Our model achieves comparable bits/dimension to
`softmax, while having more than 300 times higher through-
`put, generating 142 images/second. For details see § 4.2.1.
`
`Figure 4: Unconditional samples and image completions
`generated by our method for CIFAR-10. (a) depicts the
`occlu