Researchers find Transformer models can learn implicit reasoning through "grokking" - but not for all types of reasoning. They also show how this might change with adjustments to the architecture.
In a new study, researchers from Ohio State University and Carnegie Mellon University have investigated whether Transformer models are capable of implicit reasoning. Their results show that the models can acquire this ability through so-called "grokking" - but not equally well for all types of reasoning.
The scientists examined two representative types of reasoning: composition and comparison. Composition requires linking multiple facts together, for example "Barack's wife is called Michelle" and "Michelle was born in 1964", to complete the sentence "Barack's wife was born in [1964]". Comparison involves comparing attribute values of entities, such as the age of two people.
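To make the two task types concrete, here is a minimal Python sketch. The fact names and the dictionary format are illustrative only and not the paper's actual data representation; the point is simply that composition chains two stored facts, while comparison contrasts an attribute of two entities.

```python
# Illustrative sketch (not the paper's data format): atomic facts as
# (subject, relation) -> object lookups, plus the two query types studied.

atomic_facts = {
    ("Barack", "wife"): "Michelle",
    ("Michelle", "birth_year"): "1964",
    ("Joe", "birth_year"): "1942",
}

def composition_query(subject, r1, r2):
    """Chain two atomic facts, e.g. 'Barack's wife was born in ...'."""
    bridge = atomic_facts[(subject, r1)]   # Barack -> Michelle
    return atomic_facts[(bridge, r2)]      # Michelle -> 1964

def comparison_query(entity_a, entity_b, attribute):
    """Compare an attribute of two entities, e.g. who is older."""
    a = int(atomic_facts[(entity_a, attribute)])
    b = int(atomic_facts[(entity_b, attribute)])
    return entity_a if a < b else entity_b  # earlier birth year = older

print(composition_query("Barack", "wife", "birth_year"))  # 1964
print(comparison_query("Michelle", "Joe", "birth_year"))  # Joe
```

The Transformer in the study has to answer such queries purely from knowledge stored in its weights, without an explicit lookup table like the one above.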
The results show that Transformer models acquire the ability for implicit reasoning in both task types when training continues well beyond the point of overfitting. After 14,000 training steps, the network achieved saturated training performance but no generalization. As training continued, however, generalization performance improved until near-perfect accuracy was reached after roughly 50 times as many training steps.
But what exactly does generalization mean here? The team divided the training data into "atomic" facts and "inferred" facts - similar to the example above. The atomic facts were further split into "in-distribution" (ID) and "out-of-distribution" (OOD) subsets. Despite what the name suggests, the Transformer model was trained on all atomic facts, including the OOD ones. From these atomic facts, the team constructed inferred facts, which the network was also trained on - but only those inferred facts that were not built from OOD atomic facts. The remaining group (inferred OOD facts) was held back for testing. This setup lets the team check whether the networks learn to generalize systematically - that is, in the case of composition tasks, whether they can combine known facts (learned OOD atomic facts) into new inferred facts (unseen OOD inferred facts).
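A rough Python sketch of this split for the composition case follows. The vocabulary sizes, the OOD fraction, and the dictionary format are assumptions chosen for illustration, not the paper's actual configuration.

```python
import random

random.seed(0)

# Hypothetical toy vocabularies; the paper's synthetic datasets are much larger.
entities  = [f"e{i}" for i in range(20)]
relations = [f"r{i}" for i in range(5)]

# Atomic facts: (subject, relation) -> object. The model is trained on ALL of them.
atomic = {(s, r): random.choice(entities) for s in entities for r in relations}

# Mark a fraction of atomic facts as "OOD" (the 5% fraction here is an assumption).
keys = list(atomic)
random.shuffle(keys)
ood_atomic = set(keys[: len(keys) // 20])

# Two-hop inferred facts for composition: (s, r1, r2) -> result of the chained lookup.
inferred_id, inferred_ood = {}, {}
for (s, r1), bridge in atomic.items():
    for r2 in relations:
        target = atomic[(bridge, r2)]
        if (s, r1) in ood_atomic or (bridge, r2) in ood_atomic:
            inferred_ood[(s, r1, r2)] = target   # held back for OOD testing
        else:
            inferred_id[(s, r1, r2)] = target    # included in training

# Training set: all atomic facts plus ID inferred facts; the inferred OOD facts test
# whether the model can compose facts it has only ever seen atomically.
train_facts = list(atomic.items()) + list(inferred_id.items())
test_facts  = list(inferred_ood.items())
print(len(train_facts), len(test_facts))
```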
In the tests, the team found a clear difference in generalization behavior: while the models can generalize to unseen examples (out-of-distribution, OOD) in comparison tasks, they fail to do so in composition tasks. This is consistent with findings from many other studies. The team also discovered that the speed of generalization correlates with the ratio of inferred to atomic facts in the training data, not with the absolute size of the training set.
Why transformers don't generalize - and how the problem might be solved
What sets this work apart is that the team offers an explanation for the different behavior: the researchers attribute it to the internal structure of the learned circuits. In comparison tasks, the relevant facts are processed in parallel, which allows systematic generalization. In composition tasks, on the other hand, the facts are processed sequentially in different layers of the network, which hinders OOD generalization.
Conversely, this means that a connection between the layers could enable OOD generalization for composition tasks, the team speculates. The authors therefore suggest adaptations to the Transformer architecture. In particular, they show that a parameter-sharing scheme similar to the Universal Transformer does indeed enable OOD generalization in composition tasks, albeit more slowly than in-distribution generalization.
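The paper's exact architecture changes are not reproduced here; the following PyTorch sketch only illustrates the general idea of Universal-Transformer-style parameter sharing, where a single set of layer weights is reused across depth, so the same circuit is in principle available at every layer. All names and hyperparameters are placeholders.

```python
import torch
import torch.nn as nn

class SharedLayerEncoder(nn.Module):
    """Toy encoder that reuses ONE Transformer layer across depth,
    in the spirit of Universal-Transformer-style parameter sharing."""

    def __init__(self, d_model=128, nhead=4, depth=6):
        super().__init__()
        # A single layer whose weights are shared across all depth iterations.
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.depth = depth

    def forward(self, x):
        # Apply the same parameters repeatedly instead of stacking distinct layers.
        for _ in range(self.depth):
            x = self.shared_layer(x)
        return x

# Usage: a batch of 2 sequences, length 10, hidden size 128.
model = SharedLayerEncoder()
tokens = torch.randn(2, 10, 128)
print(model(tokens).shape)  # torch.Size([2, 10, 128])
```

Because every "layer" uses the same weights, knowledge that would otherwise be locked into one stage of a sequential circuit can be applied at any depth, which is the property the authors connect to improved OOD generalization on composition tasks.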
To further illustrate the central importance of "parametric" knowledge - i.e., the implicit knowledge stored in the network's weights - for generalization, the researchers devised a challenging task with a large search space. They trained a model to the point of grokking on the relevant facts and compared its performance with models that could only access this knowledge via retrieval-augmented generation (RAG) or within the context window. While GPT-4-Turbo and Gemini-Pro-1.5 using such "non-parametric" memory failed at this task, a fully "grokked" Transformer solved it with near-perfect accuracy.
As with all research, there are some limitations: The team used synthetic datasets for all experiments, so it is not yet clear whether the results fully translate to real-world scenarios. Moreover, only two specific forms of reasoning were studied.
The team recommends further research into suitable cross-layer mechanisms in Transformers to improve generalization, as well as into how parametric and non-parametric (context window / RAG) memory can be balanced in language models. Further research should also focus on extending the analysis to a broader range of reasoning tasks and more complex language models.