Over the last few years, Transformers have displaced Recurrent Neural Networks (RNNs) in sequence modeling tasks, owing to their highly parallel training and unmatched scaling performance. This training efficiency, however, comes at a price: a transformer's per-token inference cost scales linearly with the number of tokens in the context, whereas an RNN performs inference at a fixed cost per token.
To overcome this limitation, researchers have proposed various novel methods, such as Linear Transformers, which combine the best of both worlds: they can be trained with sequence parallelism (i.e., as transformers) but operate as RNNs at inference time. State-space models (SSMs) such as Mamba show impressive performance at smaller scales, matching or exceeding softmax transformers. However, a gap remains on long-context NLU tasks, where softmax attention retains a persistent advantage.
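To make the "train in parallel, run as an RNN" idea concrete, here is a minimal sketch of causal linear attention. It is not the formulation of any specific paper: the feature map, shapes, and function names are illustrative assumptions. The point is only that the parallel form and the recurrent form, which carries a fixed-size state, produce the same outputs.

```python
# Minimal sketch of linear attention (hypothetical shapes/names): the parallel
# "transformer-style" form and the recurrent "RNN-style" form give the same result.
import torch

def feature_map(x):
    # A simple positive feature map phi(.); e.g. elu(x) + 1.
    return torch.nn.functional.elu(x) + 1

def linear_attention_parallel(q, k, v):
    # q, k, v: (seq_len, dim). Causal linear attention computed in parallel.
    q, k = feature_map(q), feature_map(k)
    scores = q @ k.T                              # (seq, seq) unnormalized "attention"
    scores = scores * torch.tril(torch.ones_like(scores))  # causal masking
    out = scores @ v                              # numerator
    denom = scores.sum(dim=-1, keepdim=True)      # normalizer: sum_i phi(q_t)·phi(k_i)
    return out / (denom + 1e-6)

def linear_attention_recurrent(q, k, v):
    # Same computation, but as an RNN with a fixed-size state:
    # S accumulates phi(k_i) v_i^T, z accumulates phi(k_i).
    q, k = feature_map(q), feature_map(k)
    S = torch.zeros(q.shape[-1], v.shape[-1])
    z = torch.zeros(q.shape[-1])
    outs = []
    for t in range(q.shape[0]):
        S = S + torch.outer(k[t], v[t])
        z = z + k[t]
        outs.append((q[t] @ S) / (q[t] @ z + 1e-6))
    return torch.stack(outs)

q, k, v = (torch.randn(8, 16) for _ in range(3))
print(torch.allclose(linear_attention_parallel(q, k, v),
                     linear_attention_recurrent(q, k, v), atol=1e-5))  # True
```

The recurrent form only ever stores the matrix S and vector z, so inference cost per token stays constant regardless of context length.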
Rather than pre-training linear models from scratch, can we convert an existing transformer into an RNN?
That is exactly what researchers have done with Scalable UPtraining for Recurrent Attention (SUPRA), a linearization strategy that uptrains state-of-the-art LLMs into performant RNNs, enabling the study of the strengths and limitations of recurrent models at scale with minimal compute cost. SUPRA replaces the softmax normalization with GroupNorm (GN) and introduces a small MLP to project the queries and keys, converting a pre-trained attention block into a linear attention block. The model can be trained in parallel as a transformer and used recurrently at inference time via a mathematically equivalent reformulation. This allows it to leverage the strong pre-training data and performance of existing transformer LLMs, while requiring only about 5% of the original training cost.
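The sketch below illustrates the kind of modification described above: a small MLP applied to the query and key projections and GroupNorm applied to the attention output in place of the softmax normalizer. Module names, hidden sizes, and the feature map are assumptions for illustration; this is not the authors' released implementation.

```python
# Hedged sketch of a SUPRA-style linear attention block: MLP-projected queries/keys,
# GroupNorm (one group per head) instead of softmax normalization. Sizes are illustrative.
import torch
import torch.nn as nn

class SupraStyleLinearAttention(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_hidden=256):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)
        # Small MLP applied to queries and keys (an assumed architecture).
        self.qk_mlp = nn.Sequential(
            nn.Linear(self.head_dim, mlp_hidden), nn.SiLU(),
            nn.Linear(mlp_hidden, self.head_dim),
        )
        # GroupNorm with one group per head replaces the softmax normalizer.
        self.norm = nn.GroupNorm(num_heads, dim)

    def forward(self, x):
        # x: (batch, seq, dim); causal linear attention in parallel (training) form.
        b, s, d = x.shape
        def split(t):  # (b, s, d) -> (b, heads, s, head_dim)
            return t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        q = self.qk_mlp(split(self.q_proj(x)))
        k = self.qk_mlp(split(self.k_proj(x)))
        v = split(self.v_proj(x))
        q, k = nn.functional.elu(q) + 1, nn.functional.elu(k) + 1  # positive feature map
        scores = torch.tril(q @ k.transpose(-1, -2))               # causal, no softmax
        out = scores @ v                                           # (b, heads, s, head_dim)
        out = out.transpose(1, 2).reshape(b, s, d)
        out = self.norm(out.reshape(b * s, d)).view(b, s, d)       # GN instead of softmax norm
        return self.o_proj(out)

x = torch.randn(2, 10, 512)
print(SupraStyleLinearAttention(512)(x).shape)  # torch.Size([2, 10, 512])
```

Because the block above is linear in the keys and values, it admits the same recurrent rewrite shown earlier, which is what makes constant-cost inference possible after uptraining.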
Compared to pre-training linear models from scratch, the SUPRA strategy produces models that are competitive with the best available recurrent LLMs (RWKV and Mamba) at the 7B scale. The study identifies not only the strengths of linear models on standard NLU benchmarks, but also their enduring limitations on in-context learning (e.g., MMLU) and long-context (NarrativeQA, Qasper) tasks, showing that linearized models do not inherit these capabilities from their base softmax transformers.
Paper: https://arxiv.org/pdf/2405.06640