Longest Path Search

Throwing darts in latent space

For this post, I would recommend having at least some passing familiarity with sparse autoencoders, though a full understanding isn’t necessary. I’ll explain the high-level idea here.

The tl;dr is that, as I talked about a bit in my previous post, we were somewhat skeptical that sparse autoencoders actually do anything explicitly meaningful beyond some much simpler techniques. Other people have expressed similar skepticism, stated in different ways.

For example, this paper finds that sparse autoencoders generally do not beat some (basic, but not completely trivial) baselines on a number of tasks you would expect SAEs to be good at.

Possibly much more topically, there is this paper, which shows that sparse autoencoders can “interpret” randomly initialized transformer models very similarly to trained transformer models. This (roughly) shows that the weights appear almost irrelevant for interpreting features, though there’s a second possibility, which is closer to what I will talk about below.

Either way, structurally, I’ve found sparse autoencoders (SAEs, from here on out) to feel less directly interesting and more “magical” both in how they are explained and why they work. I will preface this post heavily with the fact that I have no current major data to back any of this up, though some small numerical experiments seem to point roughly, but not fully, in the right direction. (More specifically, observations were informed by some early data work that Henry de Valence and I did trying to understand SAEs, though this post expresses my opinions on the topic, which Henry need not share! Huge thanks to him for bearing with me and running some of these experiments.)

Random vectors

First off, it’ll be useful to explain some properties of random vectors. We will assume that we have a list of $m$ vectors $x_1, \dots, x_m \in \mathbf{R}^n$ which are constructed in the following way. For each index $i$, let $z_i$ be a vector with entries drawn, uniformly and independently, from $\pm 1$, and define $x_i = z_i/\sqrt{n}$. In other words, $x_i$ is the normalized $z_i$ vector. For now, assume $m \ge n$; $n$ here will be a constant. Then, the following things are true, even when $m$ is exponentially(!) large in $n$.

  1. With high probability, $x_i^T x_j$ is small for all $i \ne j$
  2. For any fixed vector $y \in \mathbf{R}^n$ and for each $i$, with at least some constant probability we have $x_i^T y > \tau \|y\|_2/\sqrt{n}$, where $\tau > 0$ is some fixed parameter. In particular, with extremely high probability there is at least one $i$ for which this is true, for $m$ large enough.1
  3. In fact, for $y$ whose entries “don’t differ too much” from each other, the distribution of $x_i^T y$ is nearly Gaussian, so very few of the inner products $x_i^T y$ will be close to the maximum value. Almost all of them will actually cluster around zero, with standard deviation $\|y\|_2/\sqrt{n}$.2

So, in a very general sense, (a) the $x_i$ form a nearly-orthogonal, incredibly “over-complete” basis for $\mathbf{R}^n$ and (b) for any fixed vector $y$, there is sure to be some $x_i$ which is nearly collinear with $y$. That is, there is some $i$ such that $x_i$ and $y$ point in “almost” the same direction.

It’s worth doing some numerical experiments here, just to see this. I personally recommend something like $n = 1024$ and $m = 32n = 32768$ or similar, but it’s up to you. (In fact, I’d recommend messing around with $m$, making it range from $n$ to $32n$, and then comparing the maximum absolute inner product.3)
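
For instance, here is one rough sketch of that experiment in Julia (my own code, using only the construction above; the exact numbers will depend on the random draw):

using LinearAlgebra

n = 1024
for m in n .* (1, 2, 4, 8, 16, 32)
    X = rand([-1, 1], n, m) ./ sqrt(n)   # columns are the x_i
    G = X' * X                           # pairwise inner products x_i' * x_j
    # note: the m = 32n case builds a 32768 x 32768 Gram matrix, so it needs a
    # fair amount of memory (and a little patience)
    println((m = m, max_offdiag = maximum(abs, G - I)))
end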

Random directions in latent space

Ok, let’s say we do have some LLM trained on a bunch of data. Intermediate layer activations of the LLM, also known as “residual streams”, essentially act as a function which maps some complicated space of text into an $n$-dimensional representation of this text. (I’ll call this $n$-dimensional output space the “latent space” as I guess that’s kind of old-school-cool now.) It stands to reason that there are some directions $y$ which are meaningful in that they map to concepts that are “interpretable”.

Now, here’s the rub.

Of course, when you have exponentially many vectors $\{x_i\}$, at least one of those vectors will have very high inner product with such a $y$! Indeed, for any reasonable number of directions $y$, there will exist a (small, but nonzero) number of vectors $x_i$ with large inner product $x_i^T y$.

Here’s a unicode plot for a simple numerical example I ran in Julia, where $y$ is a vector with i.i.d. normal entries (mean 0, variance 1):

julia> histogram(X' * y)
                ┌                                        ┐ 
   [-4.5, -4.0) ┤▏ 4                                       
   [-4.0, -3.5) ┤▏ 8                                       
   [-3.5, -3.0) ┤▍ 46                                      
   [-3.0, -2.5) ┤█▎ 213                                    
   [-2.5, -2.0) ┤███▍ 623                                  
   [-2.0, -1.5) ┤████████▌ 1 545                           
   [-1.5, -1.0) ┤█████████████████▎ 3 157                  
   [-1.0, -0.5) ┤██████████████████████████▌ 4 858         
   [-0.5,  0.0) ┤████████████████████████████████▊ 6 026   
   [ 0.0,  0.5) ┤█████████████████████████████████  6 043  
   [ 0.5,  1.0) ┤██████████████████████████▎ 4 796         
   [ 1.0,  1.5) ┤████████████████▎ 2 960                   
   [ 1.5,  2.0) ┤████████▊ 1 606                           
   [ 2.0,  2.5) ┤███▍ 621                                  
   [ 2.5,  3.0) ┤█▎ 199                                    
   [ 3.0,  3.5) ┤▍ 48                                      
   [ 3.5,  4.0) ┤▏ 14                                      
   [ 4.0,  4.5) ┤▏ 1                                       
                └                                        ┘ 
                                 Frequency   

This nicely illustrates points 2 and 3 above.
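
(For reference, a plot like the one above can be generated with something along these lines, assuming the UnicodePlots.jl package; the exact counts will of course differ from run to run.)

using LinearAlgebra, UnicodePlots

n, m = 1024, 32 * 1024
X = rand([-1, 1], n, m) ./ sqrt(n)   # columns are the x_i
y = randn(n)                         # i.i.d. normal entries, mean 0, variance 1
histogram(X' * y)                    # the inner products x_i' * y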

Interpretability?

Ok, so we now can make the following (weak) claim.

Given some text input which has some interesting features (for example, a text about dogs), there will essentially always exist some direction $y$ which is highly correlated with whatever is in the latent space (in our example, probably with respect to “dogs”). From the previous discussion, there will be some $x_i$ (or a very small number of such $x_i$) which have large inner products with this $y$. In turn, this makes those $x_i$ themselves interpretable.

But any meaningful text will have some direction $y$ in latent space, which in turn will be nearly collinear with some (very small!) number of vectors $x_i$.

But then… why do we need to train SAEs at all if this is the case?

Indeed, here’s a simple “sparse autoencoder”: take the $k$ largest inner products $x_i^T y$ with the latent vector $y$, and zero out the rest. We will then say a “learned” feature $i$ is active if $x_i^T y$ is in the top $k$ values, over all possible $i$. This forms the “activations” for the sparse autoencoder. The output mapping from these activations can really be anything, but one simple example is learning a least-squares approximation to the output, given some input corpus.
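
To make that concrete, here is a minimal sketch of what such a random “sparse autoencoder” could look like in Julia. (The function name, and the corpus matrix Y in the comments, are mine, purely for illustration.)

using LinearAlgebra

# Columns of X are the random directions x_i; y is a single latent vector.
function topk_activations(X, y, k)
    s = X' * y                               # inner products x_i' * y
    idx = partialsortperm(s, 1:k, rev=true)  # indices of the k largest values
    a = zeros(eltype(s), length(s))
    a[idx] .= s[idx]                         # keep the top k, zero out the rest
    return a
end

# One simple output map: a least-squares decoder fit over a corpus whose latent
# vectors are the columns of Y, e.g.
#   A = hcat((topk_activations(X, Y[:, j], k) for j in 1:size(Y, 2))...)
#   D = Y / A   # minimizes ||D * A - Y|| over D, in the least-squares sense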

In fact, I’m going to go slightly further and make the following (more annoying, but also harder to prove) claim: almost all directions in latent space are probably meaningful. Indeed, there are probably enough of these meaningful directions that many, if not most, of the $x_i$ will themselves be interpretable. (This is much stronger than the previous claim that there will certainly be some $x_i$ that are interpretable.)

Too many directions!

The problem, at least for the weak claim, comes down to this.

By virtue of our choice of random directions, we are putting forward an exponential number $m$ of hypotheses about “meaningful” directions (the $x_i$), all of which are ~uncorrelated (by item 1 above). On the other hand, these hypotheses “live” in a much smaller number of dimensions $n \ll m$. We should be extremely suspicious whenever we do this, since essentially any direction in the latent space that is “meaningful” will be picked up by the $x_i$.

It is, in turn, not clear that additional training of the SAEs is doing what we expect, other than roughly “fitting the corpus to reasonable dimensions”, but it’s not even clear that this is necessary!

Surely someone has some basic data on this (even if just from the initialization of some SAE training run?), so I’d be curious to see if it confirms the claim. My suspicion is that the features from these random $x_i$ won’t be quite as interpretable as those of a “trained” SAE, but my guess is that it’s probably not as far off as one would expect. This seems at least slightly suggested by the second paper linked above, but, unfortunately, I guess we can’t know until we try!
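
One cheap way to get a first look, without training anything, is sketched below: collect latent vectors over a corpus, then, for each random direction, read the inputs that activate it most strongly. (The corpus matrix Y and the function name here are hypothetical, just to pin the idea down.)

using LinearAlgebra

# Columns of Y are latent vectors for a corpus of inputs; columns of X are the
# random directions x_i. Returns the indices of the k inputs that most strongly
# activate "feature" i; reading those inputs is one way to judge whether the
# (random!) feature looks interpretable.
function top_activating_examples(X, Y, i, k)
    acts = Y' * X[:, i]                      # activation of feature i on each input
    return partialsortperm(acts, 1:k, rev=true)
end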


Footnotes

  1. For the nerds: to see this, use Paley–Zygmund and the bound $\|y\|_4^4 \le \|y\|_2^4$.

  2. See, e.g., here.

  3. A simple Julia one-liner for point 1 is X = rand([-1, 1], n, m) ./ sqrt(n); Z = X' * X; maximum(abs.(Z - I)). Make sure to run using LinearAlgebra to get I to work. Don’t forget to also define m and n! Using the n and m given here takes around 10 or so seconds to run on my laptop.