Transformer-based auto-regressive models such as GPT have shown remarkable performance in generative tasks but struggle when it comes to complex decision-making and reasoning tasks requiring search.
Two main factors account for this weakness: (1) the snowballing of errors, where a single mistake compounds and leads to increasingly poor performance in subsequent steps, and (2) difficulty with ‘lookahead tasks’, where the model must predict the consequences of its actions several steps ahead. Both issues can be attributed to a limited ability to search and backtrack. Even supplementing language models with search at inference time has not shown much improvement.
What if language models could learn to search during training itself? Doing so could let them discover more flexible search strategies through self-improvement and leave them better equipped to handle error compounding and lookahead tasks.
The Stream of Search (SoS) framework pursues this idea with a unified language for search that captures an array of symbolic search strategies such as exploration, backtracking, and pruning. To understand SoS, consider a simple Countdown problem with a set of input numbers and a target number. The goal is to combine the input numbers with simple arithmetic operations to reach the target. First, an SoS dataset is created containing search trajectories generated by diverse search strategies, including exploration and backtracking. Then a transformer-based language model is pretrained on this dataset. Finally, the model is finetuned with two policy improvement methods: Advantage-Induced Policy Alignment (APA), which optimizes for correctness, and Self-Taught Reasoner (STaR), which performs expert iteration.
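To make the idea of a search trajectory concrete, here is a minimal sketch of how one could be generated for a Countdown problem. It is an illustration under stated assumptions, not the authors' implementation: plain depth-first search stands in for the "diverse search strategies", every input number must be used exactly once, division is allowed only when it is exact, and the function name `countdown_dfs` and the text format of the trace are hypothetical.

```python
from itertools import combinations


def countdown_dfs(numbers, target, trace, depth=0):
    """Depth-first search over Countdown states.

    Every attempted step, dead end, and success is appended to `trace`
    as a line of text, so the failed branches are recorded alongside
    the final solution. The trace format is made up for illustration.
    """
    indent = "  " * depth
    if len(numbers) == 1:
        if numbers[0] == target:
            trace.append(f"{indent}Goal reached: {target}")
            return True
        trace.append(f"{indent}Dead end at {numbers[0]}, backtracking")
        return False

    # Pick every unordered pair of remaining numbers and combine them.
    for i, j in combinations(range(len(numbers)), 2):
        a = max(numbers[i], numbers[j])
        b = min(numbers[i], numbers[j])
        rest = [n for k, n in enumerate(numbers) if k not in (i, j)]

        candidates = [("+", a + b), ("*", a * b), ("-", a - b)]
        if b != 0 and a % b == 0:  # assumption: only exact division
            candidates.append(("/", a // b))

        for op, result in candidates:
            state = rest + [result]
            trace.append(f"{indent}Try {a} {op} {b} = {result}, remaining: {state}")
            if countdown_dfs(state, target, trace, depth + 1):
                return True
    return False


trace = []
solved = countdown_dfs([1, 2, 3], 7, trace)
print("\n".join(trace))
```

The printed trace contains the explored branches and backtracking steps before the successful path (here, 3 * 2 = 6 followed by 6 + 1 = 7). A serialized trajectory of this kind, mistakes included, is the sort of text a model would be pretrained on in the setup described above.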
During evaluation, SoS models solved approximately 36% of the previously unsolved problems and about 4% of the difficult problems, including problems that none of the heuristic solvers can solve. The SoS + APA and SoS + STaR models also have better models of the environment: they make fewer errors while searching and find solutions more quickly.