Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

llm
research paper
Author

Santosh Sawant

Published

January 22, 2024

 

@article{cai2024medusa,
  title   = {Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads},
  author  = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Jason D. Lee and Deming Chen and Tri Dao},
  year    = {2024},
  journal = {arXiv preprint arXiv:2401.10774}
}

 

Why is it hard to run inference for large transformer models? Besides the increasing size of SoTA models, there are two main factors contributing to the inference challenge: the large memory footprint of the model parameters and KV cache, and the sequential, memory-bandwidth-bound nature of auto-regressive decoding, which produces only one token per forward pass and leaves the accelerator's compute underutilized.

This paper introduces Medusa, a method for accelerating inference in Large Language Models (LLMs) by adding extra decoding heads that predict multiple tokens in parallel. Medusa achieves a significant speedup without compromising generation quality.

Medusa adds extra “heads” on top of an LLM to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched; only the new heads are fine-tuned during training. During generation, each head produces several likely tokens for its corresponding future position. These candidates are then combined into a tree and processed in a single forward pass using a tree-based attention mechanism. Finally, a typical acceptance scheme picks the longest plausible prefix from the candidates for further decoding. Both the head design and the tree attention mask are sketched below.
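To make the head design concrete, here is a minimal PyTorch sketch of a Medusa decoding head, following the paper's description of a single residual SiLU block on top of the backbone's last hidden state. The layer sizes and variable names are illustrative assumptions, not the reference implementation.

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One Medusa decoding head: a residual SiLU block followed by a
    vocabulary projection. Names and sizes are illustrative assumptions."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        # The paper initializes this projection to zero and the vocabulary
        # projection from the base model's lm_head, so each head starts out
        # predicting like the original model.
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        h = h + self.act(self.proj(h))   # residual block on the hidden state
        return self.lm_head(h)           # logits for a token k steps ahead

# K heads share the backbone's last hidden state, one per lookahead offset.
hidden_size, vocab_size, K = 4096, 32000, 4   # assumed Vicuna-7B-like sizes
heads = nn.ModuleList(MedusaHead(hidden_size, vocab_size) for _ in range(K))
h_last = torch.randn(1, hidden_size)          # stand-in for the LLM output
logits_per_offset = [head(h_last) for head in heads]
```

The tree attention step can likewise be sketched as a custom attention mask: the heads' candidate continuations are flattened into one sequence, and each candidate token attends only to itself and its ancestors in the tree, so every root-to-leaf path is verified in a single forward pass. The `parents` encoding below is an assumed representation for illustration.

```python
import torch

def tree_attention_mask(parents: list[int]) -> torch.Tensor:
    """Boolean attention mask for a candidate tree flattened into one
    sequence: position i may attend to position j only if j is i itself
    or an ancestor of i. parents[i] is the parent index of node i
    (-1 for a root)."""
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:        # walk up to the root, marking ancestors
            mask[i, j] = True
            j = parents[j]
    return mask

# Example: nodes 0 and 1 are head-1 candidates (roots); nodes 2 and 3
# extend node 0, and node 4 extends node 1.
mask = tree_attention_mask([-1, -1, 0, 0, 1])
```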

So how does Medusa solve the challenges associated with speculative decoding? Classic speculative decoding requires a separate draft model, which is non-trivial to obtain and complicates a serving system, and its rejection sampling step must exactly preserve the target model's distribution. Medusa sidesteps both issues: the drafts come from lightweight heads attached to the same backbone, so there is no second model to train, host, or keep in sync, and verification uses the relaxed typical acceptance criterion, which accepts a candidate token whenever the original model assigns it a probability above an entropy-adaptive threshold.
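A minimal sketch of that criterion, assuming illustrative hyperparameter values: a candidate token is accepted if its probability under the original model exceeds min(ε, δ·exp(−H)), where H is the entropy of the model's next-token distribution.

```python
import math
import torch

def typical_accept(probs: torch.Tensor, token: int,
                   eps: float = 0.09, delta: float = 0.3) -> bool:
    """Typical acceptance test: keep a candidate token if the original
    model assigns it probability above min(eps, delta * exp(-entropy)),
    so the bar drops when the model itself is uncertain. The eps/delta
    defaults are assumed for illustration."""
    entropy = -(probs * probs.clamp_min(1e-10).log()).sum().item()
    threshold = min(eps, delta * math.exp(-entropy))
    return probs[token].item() > threshold
```

In the paper, the first candidate token is always accepted via greedy decoding, which guarantees at least one token of progress per step; the accepted prefix is then the longest candidate whose tokens all pass this test.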

In experiments, Medusa delivers approximately a 2x speedup (1.94x) across a range of Vicuna models. It will be interesting to see how Medusa performs with other open-source foundation models.

Paper: https://arxiv.org/pdf/2401.10774.pdf