Epistemic status: Highly speculative, basically shower thoughts. These are some thoughts I had a few months back but just got motivation to write them up today.
While it is now becoming increasingly obvious that the representations learned by large neural networks are significantly linear and compositional, to my mind there remains a big question of why this is the case. It is not at all clear a-priori why the representations learned by SGD should form such a nice structure, one which is easy for us to understand and manipulate through linear methods.
There are a couple of probably related explanations for this, offered explicitly or implicitly in the literature.
1.) The loss landscape has a natural bias towards simple, flat, and smooth solutions that generalise, and perhaps the representations in such solutions are naturally linear. The causal chain runs from flatness to generalisation to linearity. Flat solutions should be easier to find because they occupy much larger volumes in solution space than ‘sharp’ solutions. I make a similar argument here. The big question with this is why ‘flat’ solutions should tend to have representations which are linear, modular, and compositional; this doesn’t seem necessary to me a-priori.
2.) Neural networks naturally tend to learn smooth, reasonably low-frequency functions, and so don’t tend to massively overfit even when the network is strongly overparametrized, which would explain results around double descent. It is argued either that this is due to some implicit bias of SGD towards low-frequency solutions, or that some property of the architecture, i.e. of the parameter-space to function-space mapping, makes such solutions likely. The argument here is that although low-frequency functions take up a tiny volume in function space, they may take up a much larger volume in the parameter space of the networks that implement them, although why this would be the case is unclear to me.
3.) The learning algorithm itself, SGD (or variants like Adam), has some natural implicit or explicit bias towards learning simple functions. A large number of works have focused on various aspects of this hypothesis and have indeed shown that SGD has at least some such biases, although these are usually proven only in relatively toy scenarios.
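The parameter-to-function-space bias in (2) is easy to see in a toy experiment (the architecture, scales, and frequency cutoff here are my own illustrative choices, not taken from any particular paper): sample random one-hidden-layer tanh networks at initialisation and measure how much of their output power sits in the lowest Fourier frequencies. Random points in parameter space overwhelmingly map to smooth, low-frequency functions.

```python
import numpy as np

rng = np.random.default_rng(3)

# Sample random 1-hidden-layer tanh networks on [0, 1] and measure how much
# of their output power (excluding the mean) lies in the lowest frequencies.
x = np.linspace(0.0, 1.0, 256)

def random_net(width=100):
    W1 = rng.normal(size=(width,))               # input weights
    b1 = rng.normal(size=(width,))               # biases
    W2 = rng.normal(size=(width,)) / np.sqrt(width)  # output weights
    return np.tanh(np.outer(x, W1) + b1) @ W2

fracs = []
for _ in range(50):
    f = random_net()
    spec = np.abs(np.fft.rfft(f - f.mean())) ** 2
    fracs.append(spec[:8].sum() / spec.sum())    # power in the 8 lowest bins
print(f"mean low-frequency power fraction: {np.mean(fracs):.3f}")
```

With standard Gaussian initialisation essentially all of the power concentrates in the lowest few frequency bins, which is one concrete sense in which low-frequency functions occupy a large volume of parameter space.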
Here, I want to propose a slightly different argument, one close in spirit to anthropic arguments. Namely: SGD is a weak optimiser which can only really optimise effectively over smooth, low-frequency, convex-like loss landscapes, and the reason we see convex, linear-seeming representations is that these are the only kinds of representations SGD can find in the first place.
The story, then, is that when we train a deep learning model, what we are really doing is finding a subspace of parameters in which, by luck, the problem is mostly convex and linear, and then training effectively only in that subspace. The rest of the parameters are nominally being trained, but SGD makes effectively zero progress on them due to the difficulty of the loss landscape.
This hypothesis proposes that we should look at a DL system not as a full parameter space, but in terms of parameter subspaces. By basic combinatorics, the number of possible subspaces is vastly larger than the number of parameters (although of course the subspaces intersect considerably), and most subspaces contain only a fraction of the total parameters of the model. The argument is that while the optimisation landscape looks rough and un-navigable in almost all subspaces of the full parameter space, in just a few subspaces, by essentially pure chance, it looks nearly convex. Pretraining can then be thought of as optimising over a superposition of these subspaces. In almost all of them SGD makes relatively little progress because the dynamics are too challenging. However, this does not particularly impede optimisation in the others, because at initialisation each subspace essentially adds random Gaussian noise to the output, and this averages away across subspaces. During training, then, the few optimisable subspaces come to dominate while the rest fade away, so that their linear and convex representations dominate after training.
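The picture of training happening only inside a lucky low-dimensional subspace can be sketched directly, in the spirit of random-subspace experiments (the toy regression problem and all dimensions below are my own illustrative choices): fix a random low-dimensional affine subspace of parameter space and run gradient descent only on its coordinates, leaving the remaining directions untouched.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy regression problem with D nominal parameters.
D, d, n = 200, 10, 500                        # full dim, subspace dim, samples
X = rng.normal(size=(n, D))
w_true = rng.normal(size=D)
y = X @ w_true + 0.1 * rng.normal(size=n)

# Restrict training to a random d-dimensional affine subspace:
# w = w0 + P @ z, with P a fixed random orthonormal basis and z optimised.
w0 = rng.normal(size=D) * 0.01
P, _ = np.linalg.qr(rng.normal(size=(D, d)))  # orthonormal (D, d) basis
z = np.zeros(d)

def loss(w):
    return np.mean((X @ w - y) ** 2)

init_loss = loss(w0)
lr = 0.05
for _ in range(500):
    w = w0 + P @ z
    grad_w = 2 * X.T @ (X @ w - y) / n        # gradient in the full space
    z -= lr * (P.T @ grad_w)                  # project onto the subspace, then step
final_loss = loss(w0 + P @ z)
print(f"loss: {init_loss:.1f} -> {final_loss:.1f}")
```

Only the d subspace coordinates ever move, yet the loss still decreases; on this hypothesis, full-network SGD is doing something analogous, with the "lucky" subspaces carrying essentially all of the progress.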
This hypothesis makes some predictions which hold up surprisingly well against the literature. Firstly, it is very consistent with, and essentially very similar to, the lottery ticket hypothesis: that an overparametrized network already contains, by luck, a subnetwork representing a solution fairly close to the true one. This would suggest why some specific subspaces are linear and convex in operation: because by chance they were initialized essentially in the basin that contains the solution. It would also suggest why generalising solutions tend to be found: flat solutions take up exponentially more volume than sharp ones, so randomly initialised subnetworks are more likely to land in flat basins than in sharp ones (assuming, of course, that random initialisation gives something close to a uniform measure over the relevant region of solution space). The hypothesis also makes sense of pruning results, since it predicts that many neurons contribute nothing but noise to the network’s output, not being included in the subspace which is actually optimised. An additional interesting prediction, on which I haven’t seen any dramatic results, is pruning by subspaces instead of by neurons, i.e. first rotating the network into the most prunable basis and then pruning. This, however, is fairly obvious from a number of perspectives, since neurons are clearly not the correct unit of analysis in any case. It is also consistent with the findings that neural network training appears to happen primarily in small subspaces, with only a few eigenvectors of the Hessian driving most of the dynamics, and that gradient updates can be well approximated by a low-rank matrix while retaining most of their efficacy. Finally, it makes clear why significantly overparametrized models are necessary for DL systems to perform well.
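A minimal sketch of what pruning by subspaces rather than by neurons could look like (using a hypothetical weight matrix of my own construction, not a trained network): rotate the matrix into its SVD basis and keep only the top singular directions, then compare against keeping the same budget of highest-norm neurons.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical weight matrix: approximately low-rank structure plus noise
# (the "nothing but noise" contributions of the hypothesis).
m, n, r = 64, 64, 4
W = rng.normal(size=(m, r)) @ rng.normal(size=(r, n)) + 0.05 * rng.normal(size=(m, n))

k = 8  # pruning budget: keep 8 "units"

# Neuron pruning: zero out all but the k largest-norm rows.
keep = np.argsort(np.linalg.norm(W, axis=1))[-k:]
W_neuron = np.zeros_like(W)
W_neuron[keep] = W[keep]

# Subspace pruning: rotate into the SVD basis and keep the top k directions.
U, s, Vt = np.linalg.svd(W, full_matrices=False)
W_subspace = (U[:, :k] * s[:k]) @ Vt[:k, :]

def rel_err(A):
    return np.linalg.norm(W - A) / np.linalg.norm(W)

print(f"neuron pruning error:   {rel_err(W_neuron):.3f}")
print(f"subspace pruning error: {rel_err(W_subspace):.3f}")
```

When the useful structure is spread across neurons but concentrated in a few directions, as the hypothesis suggests, the rotated basis is dramatically more prunable than the neuron basis.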
While the hypothesis seems to predict most of these known and a-priori facts about neural networks, many other hypotheses also make these predictions (or at least fit them), and it is unclear what novel predictions this specific hypothesis might make. One is that the subspace in which training primarily occurs should be approximately consistent throughout the training process, i.e. the top eigenvectors of the gradients and Hessian should remain stable. Another is that it is better to think of the lottery ticket hypothesis in terms of subspaces and not subnetworks, i.e. we should be able to find subspaces which already get decent performance at initialisation time. It might also be possible to identify early in training the principal subspaces being used, and prune away the rest of the network at that point without a significant hit to performance. If that is possible, it would be a strong validation of the hypothesis.
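The stability prediction can be probed cheaply without computing Hessians (the anisotropic logistic-regression problem and all settings below are my own toy construction): collect minibatch gradients in two consecutive windows of training, extract the top principal directions of each window, and measure how well the two subspaces align.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy logistic regression with anisotropic features, so the gradient
# covariance has a clear dominant subspace.
D, n = 50, 2000
scales = 1.0 / np.arange(1, D + 1)             # decaying feature variances
X = rng.normal(size=(n, D)) * scales
w_true = rng.normal(size=D)
y = (X @ w_true + rng.normal(size=n) > 0).astype(float)

def grad(w, idx):
    p = 1.0 / (1.0 + np.exp(-(X[idx] @ w)))
    return X[idx].T @ (p - y[idx]) / len(idx)

w = np.zeros(D)
windows, grads = [], []
for step in range(400):
    idx = rng.choice(n, size=32, replace=False)
    g = grad(w, idx)
    w -= 0.1 * g
    grads.append(g)
    if len(grads) == 200:                      # close a window every 200 steps
        windows.append(np.array(grads))
        grads = []

k = 5  # compare the top-k principal gradient directions

def top_dirs(G):
    _, _, Vt = np.linalg.svd(G - G.mean(0), full_matrices=False)
    return Vt[:k].T                            # (D, k) orthonormal basis

Q1, Q2 = top_dirs(windows[0]), top_dirs(windows[1])
# Mean squared cosine of principal angles: 1 = identical subspaces,
# ~k/D for two random subspaces.
overlap = np.linalg.norm(Q1.T @ Q2) ** 2 / k
print(f"subspace overlap: {overlap:.2f} (random baseline ~{k/D:.2f})")
```

An overlap far above the random baseline, sustained across training, is the kind of signature the hypothesis predicts; an overlap that decays towards the baseline would count against it.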
Additionally, it is unclear whether we can or should consider the parameter subspaces of the network in this way. For one thing, we can’t really consider all the subspaces independently, since they share many individual weights with one another. This substantially reduces the number of ‘effective’ subspaces, due to the correlations between them. However, the lottery ticket results show that it is nonetheless often possible to find decent subnetworks by pure chance, so presumably the same applies (perhaps more so) to subspaces. Secondly, it is likely that SGD can in fact optimise nonconvex functions decently well (although nowhere near as well as convex ones), so SGD operates over a mixture of nearly convex and more nonconvex, nonlinear loss landscapes, and the representations learnt by the network will be some combination of linear and nonlinear. It is very possible that we simply lack the tools to properly describe and extract such nonlinear representations, and hence they are effectively invisible to us, lost in the noise; i.e. we model the problem as linear + noise when in fact it is linear + nonlinear + noise, and we are attributing to noise what is really the nonlinear term plus a small noise term. Finally, it is not clear whether this hypothesis adds anything major over the lottery ticket hypothesis, and it likely inherits many of the objections to that idea. Nevertheless, I think it is a reasonable hypothesis which would be worthwhile to explore.
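The linear + nonlinear + noise worry is easy to make concrete with a synthetic example of my own construction: fit a linear probe to data whose true generating process contains a structured nonlinear term, and note that the probe's residual, which we would naively call noise, is in fact dominated by that structure.

```python
import numpy as np

rng = np.random.default_rng(4)

# Synthetic data: a linear part, a structured nonlinear part, and true noise.
n, D = 5000, 20
X = rng.normal(size=(n, D))
w = rng.normal(size=D)
nonlinear = np.sin(3 * X[:, 0]) * X[:, 1]   # structured, invisible to a linear probe
noise = 0.1 * rng.normal(size=n)
y = X @ w + nonlinear + noise

# A linear probe lumps everything it cannot fit into its residual "noise".
w_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
resid = y - X @ w_hat
print(f"residual variance:   {resid.var():.3f}")
print(f"true noise variance: {noise.var():.3f}")
```

The residual variance is far larger than the true noise variance, so a purely linear analysis of this system would report a lot of "noise" that is really unmodelled nonlinear structure.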