Recently Noumenal Labs announced themselves and I read their white paper. Although it is pretty light on specifics, it seems clear that their issue with LLMs, and with neural networks more generally, is that these models do not properly reflect in their structure the true underlying generative process of reality: they do not learn (at least not explicitly) an ‘object oriented’ representation the way humans do, where we represent the world in terms of objects and their properties arranged in some kind of nested hierarchy. They also take issue with the fact that LLMs do not really learn physical causality; as they put it, LLMs are ‘word models’ rather than ‘world models’.
I think these criticisms of LLMs are somewhat fair, but that the implied remedy (training models that directly recapitulate the structure of the generative process via some kind of sparse graph learning) is worse than the disease, and that there is a good reason the current paradigm of large, dense neural networks works. My general hope in this post is to share my model of why a whole range of approaches in this sparse-graph-learning vein do not work and, conversely, why neural networks do.
So let’s begin. Generally, it appears that the ‘true’ causal structure of reality is that of a highly sparse and hierarchical graph. That is, although you receive very rich and complex sensory inputs, these inputs can usually be adequately ‘explained’ by only a few simple nodes which interact in knowable causal ways. For instance, when you look at a video there are a huge number of pixels taking different values across time, but most of the variance in the scene can be explained by a few concepts, such as the specific objects present and their causal interactions, e.g. motion according to the laws of physics.
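To make this concrete, here is a minimal toy sketch (entirely illustrative, not from the white paper): thousands of observed pixel values per frame are generated by just two latent causes plus a simple law of motion.

```python
import numpy as np

def render_frame(position, size=64, radius=3.0):
    """Render a size x size image of a ball at a given horizontal position."""
    ys, xs = np.mgrid[0:size, 0:size]
    dist = np.sqrt((xs - position) ** 2 + (ys - size / 2) ** 2)
    return (dist < radius).astype(np.float32)  # size*size pixels from one latent

# The entire 'causal state' of the scene is two numbers.
position, velocity = 5.0, 2.0
video = [render_frame(position + velocity * t) for t in range(25)]
# 25 frames x 4096 pixels of observations, explained by 2 latents and one rule.
```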
Many, many real-world datasets and environments appear to have this structure, and indeed it is arguable that the fact that the world has this kind of local, sparse, and hierarchical structure is what makes it learnable at all. If for every sense datum there were billions of causes, all interacting in incredibly complex and potentially nonlocal ways, learning anything at all would be incredibly challenging. People have discussed this point in various ways: Bengio, for instance, proposes this kind of sparse factor graph as the ‘consciousness prior’, and work on abstraction similarly argues that this structure is fundamental to learning.
As a side note, it is important to reconcile this view with that of power-law scaling laws, which imply a continuous, fractal, effectively infinitely detailed underlying structure to naturalistic datasets. The simple synthesis of the two views is that the underlying causal DAG of naturalistic datasets resembles a scale-free network, which has a power-law distribution of connections. The discrete, sparse abstraction viewpoint focuses primarily on the most connected components: it really is true that some small, finite number of components accounts for the majority of edges in the graph. At the same time, farther out in the tail there are almost always more intricate details to be learned, which is what gives us the smooth scaling laws.
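As a rough illustration of what such a structure looks like (my sketch, using a standard preferential-attachment model rather than anything from the scaling-laws literature), a scale-free graph has a few heavily connected hubs and a long tail of sparsely connected nodes:

```python
import networkx as nx

# Preferential attachment yields a scale-free graph with a power-law degree distribution.
g = nx.barabasi_albert_graph(n=10_000, m=2, seed=0)
degrees = sorted((d for _, d in g.degree()), reverse=True)

print("top hub degrees:", degrees[:10])              # a few hubs own a large share of edges
print("median degree:", degrees[len(degrees) // 2])  # the long tail is barely connected
```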
This gives us the first major issue with the hierarchical sparse graph approach: while a large proportion of the dataset variance can be explained by such factors, a significant amount cannot, and depends instead on the detailed structure of the world; this will inevitably limit the fidelity of any such simple abstraction. To truly represent the actual structure of the data, the model should recapitulate a finely connected, scale-free, power-law-structured graph.
Okay, so let’s suppose you want to learn a scale-free network topology to model your data. Clearly you do not, in general, begin with the right causal graph structure, so you need to infer it from data; you essentially need to solve an inverse problem of inferring graph structure from its outputs. This you can do in the standard way, by taking your dataset and performing approximate Bayesian inference to infer the latent causes. The simplest possible way to do this is maximum likelihood estimation over some parametrised model of the graph. This is training.
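In symbols (my notation, not theirs), the setup is roughly: fit the parameters $\theta$ of a latent-variable model of the graph by maximising the likelihood of the observed data, marginalising over the latent causes $z$:

$$
\theta^{*} \;=\; \arg\max_{\theta}\; \mathbb{E}_{x \sim \mathcal{D}}\!\left[\log p_{\theta}(x)\right]
\;=\; \arg\max_{\theta}\; \mathbb{E}_{x \sim \mathcal{D}}\!\left[\log \int p_{\theta}(x \mid z)\, p_{\theta}(z)\, dz\right].
$$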
Obviously, at the start you know neither the correct topology of the underlying causal graph nor the functions by which the different nodes interact with each other. Mathematically, perhaps the best way to represent this is as a factor graph, which expresses a particular factorisation of a general joint probability distribution. The variables in the joint density correspond to nodes in your graph, and sparsity in the underlying structure means that the joint can be factored into a set of factors, each of which involves only a small subset of the nodes. The factors themselves, however, can be arbitrary nonlinear functions of their component nodes.
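Concretely, a factor graph writes the joint density over the node variables $x_1, \dots, x_N$ as a product of local factors, where sparsity means each factor touches only a small neighbourhood $\partial a$ of the nodes:

$$
p(x_1, \dots, x_N) \;=\; \frac{1}{Z} \prod_{a} f_a\!\left(x_{\partial a}\right), \qquad |\partial a| \ll N,
$$

where $Z$ is a normalising constant and each $f_a$ may be an arbitrary nonlinear function of its arguments.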
In general, learning an arbitrary factor graph from data is a very hard problem, for two reasons. First, learning arbitrary nonlinear functions is challenging in general: done naively, you are optimizing directly in function space, which is infinite-dimensional and highly discontinuous. The second issue is discreteness. The choice of factorisation is discrete: a node is either connected to another or it is not. Discrete optimization is generally very difficult at scale because our classic gradient-based algorithms, like backpropagation, cannot be used.
To address the first issue, we need to stop optimizing in general function space and instead optimize the parameters of a function approximator. Theoretically this is fine, since we can approximate any function this way; it essentially just serves as a way to parametrize function space in terms of coefficients that we can numerically optimize on a computer. In practice, however, it also significantly increases the number of parameters to be optimized compared with somehow finding the optimal function directly.
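A minimal sketch of what this replacement looks like in practice (the class name and sizes here are illustrative): each unknown interaction function becomes a small MLP whose coefficients are what we actually optimize.

```python
import torch
import torch.nn as nn

class FactorMLP(nn.Module):
    """A parametrised stand-in for one unknown factor/interaction function."""
    def __init__(self, n_inputs, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_inputs, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        return self.net(x)

f = FactorMLP(n_inputs=3)
# The 'one true function' might be describable with a handful of numbers;
# its approximator carries thousands of coefficients instead.
print(sum(p.numel() for p in f.parameters()))
```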
One way to solve the second issue, discreteness, is through a continuous relaxation of the problem. Instead of nodes being connected or disconnected, every pair of nodes is connected with a given strength, which can be increased or decreased and learnt as a continuous parameter. While this may work, it essentially removes the inherent sparsity of the underlying representation.
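A sketch of the relaxation (again illustrative): replace the binary connected/not-connected decision with a dense matrix of learnable edge strengths, which is differentiable but no longer sparse.

```python
import torch
import torch.nn as nn

n_nodes = 16
edge_logits = nn.Parameter(torch.zeros(n_nodes, n_nodes))  # one parameter per possible edge
adjacency = torch.sigmoid(edge_logits)                      # soft 'strengths' in (0, 1)

# The relaxed structure is trainable by gradient descent; sparsity now has to be
# encouraged indirectly, e.g. with an L1-style penalty on the (positive) strengths.
sparsity_penalty = adjacency.sum()
```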
The final issue is the learning algorithm. Theoretically, this entire problem can be cast as Bayesian inference and solved with standard Bayesian methods. In practice, of course, the nonlinearities and the scale make exact Bayesian computation infeasible, so we must resort to approximations. One family of approximations is variational message passing: assume some factorisation of the posterior, then perform inference over the node values, and even, in theory, over the node parameters using EM-based approaches. However, existing methods tend to struggle to scale, either because they rely on MCMC sampling, which has poor computational asymptotics, or because it is hard to derive variational updates that are both numerically stable and cheap to compute. Interesting work is being done in this direction, however, and advances can be made here.
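For reference, these schemes are built around the standard mean-field update: assuming the posterior factorises as $q(z) = \prod_j q_j(z_j)$, the optimal update for each factor is

$$
\log q_j^{*}(z_j) \;=\; \mathbb{E}_{q_{-j}}\!\left[\log p(x, z)\right] + \text{const},
$$

where the expectation is taken over all the other factors $q_{-j}$. On a factor graph this expectation only involves a node's local neighbourhood, which is what makes message passing possible; the trouble is that for arbitrary nonlinear factors it has no closed form, which is exactly where the stability and cost problems come from.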
In general, though, the most scalable approach is black-box variational inference. Recall that variational inference involves minimising an upper bound on the negative log evidence (equivalently, maximising a lower bound on the log evidence) known as the variational free energy. How exactly this minimisation is done is irrelevant, and indeed it can be done by backprop through the ‘black box’ of a neural network to optimize the parameters of the variational distribution. Moreover, if we approximate further and give up on the variational distribution being a rich probability distribution, approximating it instead as a Dirac delta on a set of weights, then the VFE collapses (up to a prior term that acts as a regulariser) to the standard cross-entropy objective, meaning that we can interpret all existing NN training recipes through this probabilistic lens if we so choose.
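To spell out that collapse (standard material, nothing specific to any one paper): writing the variational free energy for an approximate posterior $q(\theta)$ over the weights,

$$
\mathcal{F}(q) \;=\; \mathbb{E}_{q(\theta)}\!\left[\log q(\theta) - \log p(x, \theta)\right]
\;=\; -\log p(x) + D_{\mathrm{KL}}\!\left[q(\theta)\,\|\,p(\theta \mid x)\right] \;\ge\; -\log p(x).
$$

Setting $q(\theta) = \delta(\theta - \theta^{*})$ and dropping the $\theta^{*}$-independent entropy of the delta leaves

$$
\mathcal{F}(\theta^{*}) \;=\; -\log p(x \mid \theta^{*}) - \log p(\theta^{*}),
$$

i.e. the usual negative log-likelihood (cross-entropy) objective plus a prior term playing the role of a weight regulariser.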
So, starting from an attempt to derive a workable and scalable algorithm for learning a sparse hierarchical graph model, isomorphic to what we hypothesise is the true underlying causal structure of naturalistic datasets, we have made a series of successive approximations to make the approach tractable, and we have ended up with something extremely similar to existing neural networks. Namely, we have a ‘graph system’ composed of sets of function approximators (e.g. MLPs, ReLU networks) trained by backprop, in an all-to-all topology arising from the continuous relaxation of the discrete graph.
Perhaps the only remaining difference is that we are still thinking in graph terms, with nodes and edges. Presumably, although we have performed a continuous relaxation of the graph representation, we still aim to end up with a sparse representation. However, unstructured sparsity of this kind is penalized by the existing hardware for efficiently running matrix ops, namely GPUs. Unstructured sparsity very rarely gives any kind of advantage, and is often a disadvantage, on GPU systems, given that they are designed for large parallel SIMD operations and large contiguous memory accesses [1].
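If you want to sanity-check that claim on your own machine, a rough comparison like the following (illustrative, and very dependent on hardware and library version) is enough to see the effect:

```python
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
n = 4096
dense = torch.randn(n, n, device=device)
mask = (torch.rand(n, n, device=device) > 0.9).float()   # keep ~10% of entries at random
sparse = (dense * mask).to_sparse()                       # unstructured (COO) sparse copy
x = torch.randn(n, n, device=device)

def timeit(fn, reps=10):
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(reps):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / reps

print("dense :", timeit(lambda: dense @ x))
print("sparse:", timeit(lambda: torch.sparse.mm(sparse, x)))
# Despite ~10x fewer multiply-adds, the unstructured sparse matmul is typically
# no faster, and often slower, at sparsity levels like this.
```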
What we end up with, then, is both a justification for why we need neural networks and a new perspective on what exactly they are doing.
Now, certainly, this may not be the only way, but the alternatives require advances beyond current technology. For instance, it is very possible that better variational message passing or MCMC algorithms will be discovered that allow a fuller treatment of uncertainty in parameter estimation, which would let us move from neural networks towards arbitrary factor graphs. Perhaps hardware designs will change so that unstructured sparsity becomes more beneficial, enabling much less regular neural network architectures to be competitive. Perhaps we will discover significantly better discrete optimization algorithms that can compete at scale with backprop, enabling us to train extremely large binary networks much as we can train large continuous-valued networks today. Certainly there is some possibility of this, given recent results in quantising, and ultimately training, very low-bit-precision networks. The combination of these advances would let us get much closer to the vision of truly matching the likely generative process of naturalistic data with an isomorphic model structure.
[1] Structured sparsity is extremely effective on GPUs when done correctly, and indeed almost all NN architectures are effectively a case of structured sparsity.