Training Large Language Models to Reason in a Continuous Latent Space

llm
research paper
Author

Santosh Sawant

Published

December 10, 2024

Large language models (LLMs) are restricted to reasoning in the “language space”, where they typically express the reasoning process as a chain of thought (CoT) to solve complex problems. However, the language space is not always optimal for reasoning. For example, most word tokens serve mainly for textual coherence and are not essential for reasoning, while a few critical tokens require complex planning and pose major challenges to LLMs.

To explore the potential of LLM reasoning in an unrestricted latent space, researchers from Meta have introduced Coconut (Chain of Continuous Thought). Coconut makes a simple modification to the traditional CoT process: instead of mapping between hidden states and language tokens with the language model head and the embedding layer, it feeds the last hidden state (a “continuous thought”) directly back as the input embedding for the next position. This frees the reasoning from the language space, and because continuous thoughts are fully differentiable, the whole system can be optimized end-to-end by gradient descent. To make latent reasoning easier to learn, Coconut uses a multi-stage training strategy that leverages language reasoning chains to guide the training process.
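To make the difference concrete, here is a minimal sketch in plain Python. It assumes a toy stand-in for the transformer: `toy_last_hidden_state`, `cot_step`, `coconut_step` and `stage_example` are hypothetical names, not code from the paper. `cot_step` follows the language path (hidden state, then LM head, then token, then embedding), while `coconut_step` appends the last hidden state directly. `stage_example` sketches the multi-stage curriculum as the Coconut paper describes it, where stage k replaces the first k language steps with k × c continuous thoughts between `<bot>` and `<eot>` markers; the `<thought>` placeholder string, the loss mask, and treating stage 0 as plain CoT without markers are simplifying assumptions.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def toy_last_hidden_state(inputs: Sequence[Vector]) -> Vector:
    """Hypothetical stand-in for an LLM forward pass: maps the whole input
    embedding sequence to one 'last hidden state' (tanh of a position-weighted sum)."""
    n, dim = len(inputs), len(inputs[0])
    acc = [0.0] * dim
    for pos, vec in enumerate(inputs, start=1):
        for i, x in enumerate(vec):
            acc[i] += x * pos / n  # later positions weigh more
    return [math.tanh(a) for a in acc]


def cot_step(inputs: List[Vector], embedding_table: List[Vector]) -> List[Vector]:
    """Language CoT: hidden state -> LM head -> token id -> embedding layer."""
    hidden = toy_last_hidden_state(inputs)
    # Toy LM head with tied weights: score each vocabulary entry by dot product.
    scores = [sum(h * e for h, e in zip(hidden, emb)) for emb in embedding_table]
    token_id = max(range(len(scores)), key=scores.__getitem__)
    return inputs + [embedding_table[token_id]]  # discrete bottleneck


def coconut_step(inputs: List[Vector]) -> List[Vector]:
    """Coconut: the last hidden state itself becomes the next input embedding."""
    return inputs + [toy_last_hidden_state(inputs)]


def stage_example(question: List[str], steps: List[str], answer: List[str],
                  stage: int, c: int = 1) -> Tuple[List[str], List[bool]]:
    """Simplified curriculum: at stage k the first k language steps are replaced
    by k * c continuous-thought slots. In a real model each '<thought>' slot is
    filled by a hidden state via repeated forward passes, not a vocabulary token.
    Returns the training sequence and a mask of positions that receive loss."""
    k = min(stage, len(steps))
    latent = (["<bot>"] + ["<thought>"] * (k * c) + ["<eot>"]) if k > 0 else []
    prefix = question + latent  # question and latent slots are not supervised
    target = steps[k:] + answer  # remaining language steps and the answer are
    return prefix + target, [False] * len(prefix) + [True] * len(target)
```

For example, `stage_example(["Q"], ["s1", "s2"], ["A"], stage=1, c=2)` yields `["Q", "<bot>", "<thought>", "<thought>", "<eot>", "s2", "A"]`, with loss only on `"s2"` and `"A"`. Calling `coconut_step` repeatedly produces a chain of continuous thoughts; in an actual LLM the gradient flows back through every appended hidden state, which is what makes end-to-end training possible.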

Unlike language-based reasoning, continuous thoughts in Coconut can encode multiple potential next steps simultaneously, allowing a reasoning process akin to breadth-first search (BFS). The model may not make the correct decision at first, but it can keep many possible options within the continuous thoughts and progressively eliminate incorrect paths, guided by implicit value functions. This advanced reasoning mechanism surpasses traditional CoT even though, unlike in previous works, the model is not explicitly trained or instructed to search in this manner.
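The BFS intuition can be illustrated with a toy analogy, which is not the model's actual computation: treat a continuous thought as a weighted set of frontier nodes in a graph-search problem (the kind of planning that ProsQA requires), expand every kept node at once, and rescore the children with a value function. The function name `latent_bfs`, the graph, the `value` heuristic and the pruning threshold are all hypothetical; in Coconut any such value function is implicit in the learned hidden states.

```python
from typing import Callable, Dict, List

Graph = Dict[str, List[str]]


def latent_bfs(graph: Graph, start: str, value: Callable[[str], float],
               steps: int, prune_below: float = 0.05) -> Dict[str, float]:
    """Toy analogy of a continuous thought as a weighted frontier: each step
    expands all kept nodes at once (breadth-first), rescores children with a
    value function, drops low-weight branches, and renormalizes."""
    frontier = {start: 1.0}
    for _ in range(steps):
        nxt: Dict[str, float] = {}
        for node, w in frontier.items():
            for child in graph.get(node, []):
                # Mass from several parents can merge on the same child.
                nxt[child] = nxt.get(child, 0.0) + w * value(child)
        total = sum(nxt.values())
        if total == 0:
            break  # nothing left to expand
        kept = {n: w for n, w in nxt.items() if w / total >= prune_below}
        if not kept:
            break
        norm = sum(kept.values())
        frontier = {n: w / norm for n, w in kept.items()}
    return frontier
```

With `graph = {"A": ["B", "C"], "B": ["D"], "C": ["E"]}` and a hypothetical `value = lambda n: 1.0 if n in {"B", "D"} else 0.2`, one step keeps both `B` and `C` alive (weights of about 0.83 and 0.17), so the wrong branch is not discarded immediately; after the second step the `E` branch falls below the threshold and all weight concentrates on `D`.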

Experimentally, Coconut successfully enhances the reasoning capabilities of LLMs. On math reasoning (GSM8k), using continuous thoughts improves accuracy, mirroring the effect of language reasoning chains; this suggests that chaining more continuous thoughts could help solve increasingly complex problems. On logical reasoning tasks (ProntoQA and ProsQA), Coconut and its variants outperform language-based reasoning chains while generating fewer tokens during inference. These results highlight the potential of latent reasoning and offer insights for future research.

Paper: Coconut: Training Large Language Models to Reason in a Continuous Latent Space