The Free Transformer
Directly from the paper.
Suppose we train a model to generate movie reviews. The training set might contain clearly positive reviews ("A masterpiece; I was hooked from the first scene"), clearly negative ones ("A complete waste of two hours"), and so on, with the good and bad reviews kept clearly separate.
With enough data, a reasonably sized transformer model would learn to generate both types of reviews. But when producing a review, the model generates several tokens before it has, in effect, decided whether the review is positive or negative.
The model doesn't explicitly choose to generate good or bad reviews. Instead, the choice is implicit in token selection.
It would be much easier if the latent variables weren't latent. If our training data exposed all of the latent variables, this would be a straightforward conditioning task: fix some choice of the latent variables $Z$ (for instance, a positive/negative label attached to each review) and feed it into the model.
IDEA: Introduce a variational autoencoder (VAE), a generative deep learning model that learns to infer a distribution over the latent variables $Z$ from the training data.
What might this look like?
The Ideal (Expensive) Transformer
For training, we insert a (conditional) VAE and hope that it learns to produce useful latent variables $Z$. This might look like a transformer encoder running alongside the decoder. (It can't be purely autoregressive, since it needs to see the whole sequence to pick up long-range structure.)
There are setup and training subtleties.
One must limit the information carried by $Z$ (otherwise the encoder could simply stash the entire sequence in $Z$). There are standard VAE strategies (an ELBO objective with a KL penalty) that penalize packing too much information into the learned latent distribution of $Z$.
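As a rough illustration, here is a minimal sketch of such a penalty. It assumes (as an illustration, not as the paper's exact formulation) one Bernoulli latent bit per latent dimension with a uniform prior, and a free-bits-style budget that plays the role of the paper's information allowance; the function and argument names are mine.

```python
import math

import torch
import torch.nn.functional as F


def vae_style_loss(logits, targets, bit_logits, kl_budget=0.1):
    """Reconstruction + capped-KL objective (an ELBO-style sketch, not the paper's exact loss).

    logits:     (B, T, vocab) next-token predictions from the decoder
    targets:    (B, T) target token ids (shifting for next-token prediction omitted)
    bit_logits: (B, T, H) encoder outputs, one Bernoulli logit per latent bit
    kl_budget:  per-bit information allowance in nats (plays the role of kappa;
                the name and exact formulation here are assumptions)
    """
    # Reconstruction term: ordinary token-level cross-entropy.
    recon = F.cross_entropy(logits.flatten(0, 1), targets.flatten())

    # KL divergence of the Bernoulli(q) posterior from a uniform Bernoulli(0.5) prior,
    # computed independently for each latent bit.
    q = torch.sigmoid(bit_logits)
    kl_per_bit = (
        q * torch.log(q.clamp_min(1e-8))
        + (1 - q) * torch.log((1 - q).clamp_min(1e-8))
        + math.log(2.0)
    )

    # Only information above the budget is penalized: the encoder may use up to
    # kl_budget nats per bit for free, but pays for anything beyond that.
    kl_penalty = torch.clamp(kl_per_bit - kl_budget, min=0.0).mean()

    return recon + kl_penalty
```

Below the budget the KL term contributes nothing, so the encoder is free to put some information into $Z$; above it, extra information is paid for, which keeps $Z$ from simply memorizing the sequence.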
Generation is straightforward: inject a randomly sampled state $Z$ and decode as usual.
We won't know which latent variables are being selected, but if training worked, the model has learned how to use them. For some values of $Z$, for example, the model might "know" that it is producing a good movie review.
The paper doesn't stop there: this model is a bit more than twice as expensive to train as a standard decoder-only model. (The results would presumably be better, but the paper doesn't report this.)
Main idea: Reuse as much of the decoder-transformer computation as possible. If there are $L$ layers of decoder blocks, split them in half. After $L/2$ layers, insert a single non-causal transformer block.
Interpret its output as a latent variable $Z$, then apply the remaining $L/2$ layers.
Fix $H = 16$, say. For each token $t$, the encoder block should produce a vector of dimension $H$, sample† a value in $\{0, \ldots, 2^H - 1 \}$, and ultimately one-hot encode this value as a $2^H = 65{,}536$ dimensional vector.
This is the latent variable $Z_t$ associated to the token $t$.
To choose latent variables at generation time, pick a sequence of values in $\{0, \ldots, 2^H - 1 \}$ (e.g. uniformly at random) and one-hot encode them.
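Here is a minimal sketch of that bookkeeping, assuming $H$ independent bits per token (the function names are mine, not the paper's). During training the bits come from the encoder's logits; at generation time, with no encoder pass, they can simply be drawn uniformly at random. Note that sampling is not differentiable, so real training also needs something like a straight-through estimator, which I omit here.

```python
import torch
import torch.nn.functional as F

H = 16  # latent bits per token; 2**H = 65,536 possible values


def bits_to_onehot(bits):
    """Pack H bits per token into an integer in {0, ..., 2**H - 1},
    then one-hot encode it as a 2**H-dimensional vector."""
    powers = 2 ** torch.arange(H, device=bits.device)        # (H,)
    index = (bits.long() * powers).sum(dim=-1)                # (B, T)
    return F.one_hot(index, num_classes=2**H).float()         # (B, T, 2**H)


def sample_z_train(bit_logits):
    """Training: sample each bit from the encoder's Bernoulli logits."""
    bits = torch.bernoulli(torch.sigmoid(bit_logits))         # (B, T, H)
    return bits_to_onehot(bits)


def sample_z_generate(batch, seq_len, device="cpu"):
    """Generation: no encoder, so draw the bits uniformly at random."""
    bits = torch.randint(0, 2, (batch, seq_len, H), device=device)
    return bits_to_onehot(bits)
```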
Before the second set of decoder blocks (i.e. once the encoder has provided $Z$), add a single fully connected linear layer applied to $Z$. Its output is added to the key-value input tensors of the next decoder block.
The remaining $L/2 - 1$ decoder layers are typical autoregressive transformer decoders.
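Putting the pieces together, here is a rough skeleton of how such a forward pass could be organized. Everything here (class and argument names, the use of stock `nn.TransformerEncoderLayer` blocks) is my own placeholder rather than the paper's code, and I simplify the injection step by adding the projected $Z$ to the hidden state instead of only to the key/value inputs of the next block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FreeTransformerSketch(nn.Module):
    """Skeleton of the layer layout described above (a sketch, not the paper's code)."""

    def __init__(self, d_model=512, n_layers=8, n_heads=8, H=16):
        super().__init__()
        self.H = H

        def block():
            return nn.TransformerEncoderLayer(
                d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True
            )

        half = n_layers // 2
        self.first_half = nn.ModuleList([block() for _ in range(half)])
        self.encoder_block = block()                 # run without a causal mask
        self.to_bit_logits = nn.Linear(d_model, H)
        self.z_proj = nn.Linear(2**H, d_model)       # one-hot Z -> model dimension
        self.second_half = nn.ModuleList([block() for _ in range(n_layers - half)])

    def forward(self, x, z_onehot=None):
        # x: (B, T, d_model) token embeddings; positional encoding and the
        # final language-model head are omitted.
        T = x.size(1)
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)  # True = masked

        for blk in self.first_half:
            x = blk(x, src_mask=causal)

        if z_onehot is None:
            # Training path: the non-causal block sees the whole sequence and
            # produces H Bernoulli logits per token (same bit-packing as above).
            bit_logits = self.to_bit_logits(self.encoder_block(x))
            bits = torch.bernoulli(torch.sigmoid(bit_logits))
            index = (bits.long() * (2 ** torch.arange(self.H, device=x.device))).sum(-1)
            z_onehot = F.one_hot(index, num_classes=2**self.H).float()
        # Generation path: z_onehot is supplied, e.g. drawn uniformly at random.

        # Simplified injection: add the projected latent to the hidden state
        # (the paper adds it only to the key/value inputs of the next block).
        x = x + self.z_proj(z_onehot)

        for blk in self.second_half:
            x = blk(x, src_mask=causal)
        return x
```

Since $Z_t$ is one-hot, the big linear layer is effectively an embedding table indexed by the latent value; at generation time you would call the model with a `z_onehot` produced by something like `sample_z_generate` from the previous sketch.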
As with the simpler model, the Free Transformer generates new outputs like a typical transformer, but with randomness injected at a middle layer. This randomness (in theory) fixes latent decisions up front and (in theory) makes the output more coherent.
Compared to a typical decoder architecture with $L$ transformer blocks, the Free Transformer essentially adds one block for the VAE. This adds a time and memory overhead of approximately $1/L$ (about 3% for a 32-layer model).
Compared to the ideal (expensive) version noted earlier, this architecture needs roughly half the time and memory. So yes, it is very inexpensive!
The paper compared a 1.5B model (which I don't discuss here) and an 8B Llama-3-style model with 32 layers. The 8B model was trained on either 200B or 1T tokens. The training setup was tuned for the baseline model and then reused for the Free Transformer variant. These models were then evaluated on a variety☆ of benchmarks.
Aside: the models trained for this paper used (at least) the equivalent of 256 H100 GPUs for 2 weeks.
The paper also trained a toy model on synthetic data to check whether $Z$ was actually being used.
In short: the model gives modest improvements on some tasks and seems to condition on unknown latent variables during training.
The paper also notes that no hyperparameters were tuned for the proposed model (apart from some tuning of $\kappa$, the information rate per token allowed in $Z$).
I only have one negative observation.
The new architecture introduces several new hyperparameters ($\kappa$, $H$, and where to place the encoder block, to name the most important), and it is only "free" because no effort is made to optimize over them. Perhaps broad guidelines are possible, but good settings are likely application- and architecture-specific.
Thank you. If you have any comments, please direct them to my short post.