https://arxiv.org/abs/2510.17558
Abstract:
We propose an extension of the decoder Transformer that conditions its generative process on random latent variables which are learned without supervision thanks to a variational procedure. Experimental evaluations show that allowing such a conditioning translates into substantial improvements on downstream tasks.
Author: François Fleuret
Links:
Homepage: https://ykilcher.com
YouTube: https://www.youtube.com/c/yannic
Hello there. Today we are looking at the Free Transformer by François Fleuret at Meta. This transformer extends the classic decoder-based transformer with a series of latent variables that can be used to make underlying decisions about the sequence to be generated. Take the example the paper gives: say you want to train a movie review model. Any given movie, good or bad, will have both good reviews and bad reviews, and just assume for now that these form two distinct groups: "I like this movie because of this, this and this" and "I don't like this movie because of this, this and this." If the movie is good, most of the reviews will be good but some will be bad. If the movie is bad, most of the reviews will be bad but some will be good. But whether your distribution looks
like this or like this (movie goodness on this axis, number of reviews on that one), there is always this bimodal distribution. Now let's say you are training a transformer on movie reviews. You take a whole bunch of movie reviews across all movies, and maybe the prompt is the description of the movie, so you train the model to generate a review given a description. If you have a big enough transformer and enough data, the model will learn to generate these good and bad reviews in the correct proportions. So in the limit, if you supply a good movie to your trained model and
ask it to write a review, it will come up with a good review, let's say, 90% of the time and a bad review 10% of the time. The question is how it does that. When you sample once, you get one concrete answer, a good review or a bad review, and that answer needs to be internally consistent. You can say "I like this movie because..." followed by reasons why it's good, and if you say "this movie is bad because...", the text that follows is very different. So how does a transformer, given the same movie description, produce good reviews 90% of the time and bad reviews 10% of the time if you sample enough times? The answer is that there is obviously a degree of
randomness in the transformer, and that randomness always sits at the very end, at the token-sampling step. Assume you have your tokens: these are the tokens of the movie description, followed by the first tokens of the movie review. Generation is autoregressive, so you sample this token, then this token, then this one, and so on. On top of the tokens you have your model, several transformer layers stacked up, and the question is how the next token gets generated. Once the forward pass is through, you will not
get a token directly; you get a distribution over tokens. Specifically, you get logits, which you normalize, over your whole vocabulary: "a", "and", "dog", "good", all the tokens you could generate. The result is a distribution: a little mass on this one because it makes no sense, a lot on that one because it makes a lot of sense. Does "dog" fit as the next token? It depends on what the text is. Either way, the transformer gives you a distribution over next tokens, which you then sample from, and this is the only place where there is randomness in a transformer. I'm overlooking the randomness that comes from batching, the randomness that comes from GPU errors that don't get corrected, and the randomness that comes
from quantum fluctuations in two-nanometer processors. The sampling randomness is much bigger than all of those, so we'll focus on it. So let's look specifically at good and bad movie reviews. The only thing a classic transformer can do is be self-consistent. Say you again have your movie description and you start generating the review. At the point where you start, all possibilities are open: the first word might be "This", then "movie", then "is", and now, when you sample the next token,
let's say it's a good movie. The transformer will correctly give you a token distribution in which "good" is nine times more likely than "bad". You sample from it, and let's say you get "good", because it's way more likely. You place it here, and from then on every token you generate needs to be consistent with everything before it. There is still variation in those tokens, but they have to be consistent, so what you'll most likely get is reasons why the movie is good. If instead, just by sampling from this distribution, you had drawn the word "bad", self-consistency would give you a review that highlights why the movie isn't particularly good.
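Here is a minimal Python sketch of the mechanism just described. The transcript has no code, so the language is our choice. The two-word vocabulary, the logits, the canned continuations and every function name are hypothetical. All it is meant to show is that a plain decoder's only randomness is a draw from softmax(logits), and that one early draw decides which bucket the rest of the review comes from.

```python
import math
import random


def softmax(logits):
    """Normalize raw logits into probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(vocab, logits, rng):
    """The only stochastic step in a plain decoder: one draw from softmax(logits)."""
    return rng.choices(vocab, weights=softmax(logits), k=1)[0]


# Hypothetical logits for the slot after "This movie is": 'good' is 9x as likely as 'bad'.
VOCAB = ["good", "bad"]
LOGITS = [math.log(9.0), 0.0]
# Hypothetical continuations. A real model samples these too, but conditioned on the verdict.
REASONS = {"good": "great acting and a tight plot", "bad": "flat acting and a messy plot"}


def toy_review(rng):
    verdict = sample_next_token(VOCAB, LOGITS, rng)
    # Everything after this point has to agree with the single draw above.
    return f"This movie is {verdict} because of {REASONS[verdict]}."


rng = random.Random(0)
reviews = [toy_review(rng) for _ in range(10_000)]
share_good = sum(r.startswith("This movie is good") for r in reviews) / len(reviews)
print(f"model P(good) = {softmax(LOGITS)[0]:.2f}, sampled share = {share_good:.3f}")
```

The 90/10 split between good and bad reviews exists nowhere in the model except inside that single categorical draw. Nothing like a "verdict variable" appears before it.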
Now, I'm not saying everything after this is fixed. All the tokens are obviously still sampled, but the distribution you sample from has to make sense in the context of what came before. Up until this point, both good and bad reviews are plausible continuations. But once you make this sampling choice, once you flip the 90/10 coin and one side comes up, you place a token here. Because this token is so decisive for whether this is a good or a bad review, the rest of the text needs to be consistent with it. This is not limited to reviews that start with "This movie is"; that's just the most glaring example. I hope you can see that sampling certain tokens is what ultimately decides from which bucket you're
going to get the whole answer. The whole answer depends on these intermediate sampling steps. And this is the question the paper asks: wouldn't it be better if, right here at the beginning, you made a choice and said "I would like to generate a good review"? Then every single token you generate from this point on could look at that choice and be informed by it. You would get simpler math and broader consistency, and you wouldn't have to start vague and decide on these things incrementally. So this transformer variant
is specifically concerned with cases where it makes sense for the underlying data to depend on some latent variables. We call this a latent variable because it's just in your head. You think "Okay, I like this movie," then you write the review, and the whole review depends on that latent choice, not so much on the sampling of individual tokens. This transformer tries to model that. The bad part is that your training data only has movie descriptions and movie reviews. You don't have the latent variables, and it's not as easy as always having one classification token that
you shove into a particular place. Actually, something like that exists, and it's called reasoning. Here's the prompt, and then there's a deliberate place to put explicit tokens that weren't in the pretraining data. They appear during reinforcement learning, and the rest of the answer is conditioned on them. So that already exists. What we want here instead is to introduce latent variables that make the choices that would usually be left to the random sampling of output tokens. I have already explained much of the first section, so we won't dive too deep into the paper. But this example here is quite illustrative when it comes to the math. So they're saying: okay,
consider a random variable Z that is Bernoulli(1/2), so a fair coin flip, and T other random variables X_1, …, X_T that are equal to Z with independent flips of probability ε. So I first flip the fair coin to get Z, and then I flip one ε-coin per position. If the ε-coin hits, X_t is the opposite of Z; if it doesn't hit, X_t equals Z. If I explicitly factor out Z, that is, if I can condition on it, the math is simple. The probability that any given X_t is one,
given that I know Z, is P(X_t = 1 | Z = z) = 1 − ε if z = 1 and ε if z = 0. Very, very simple. Now suppose we didn't know Z and had to express this autoregressively: given the other X's, what is my next X? Then I would have to infer Z from the past tokens and ask what my latent variable could be. It's the same as generating a movie review. Either I look at my latent decision of whether I want a good or a bad review, or I look at all the previous tokens and somehow try to derive from them whether this is supposed to be a good or a bad review. Unsurprisingly, if you have to look at all the previous tokens and infer that, the math gets a
lot more complex. If k of the first t − 1 tokens are ones, then P(X_t = 1 | X_1, …, X_(t−1)) = [(1 − ε)^(k+1) ε^(t−1−k) + ε^(k+1) (1 − ε)^(t−1−k)] / [(1 − ε)^k ε^(t−1−k) + ε^k (1 − ε)^(t−1−k)]. This is obviously an extreme example, but you can see that the expressions get a lot trickier. This translates directly to machine learning: a model will have a much easier time learning the version conditioned on Z than the autoregressive one, simply because it is a much less complex function. So they say that purely autoregressive density models suffer from drawbacks. First, they need unnecessarily complicated computation, which means they require greater capacity: the model needs to be bigger. Second, they may be sent off track during generation. For example, a few tokens might be generated erroneously, as mistakes, and those
tokens can have a big influence over the trajectory, so you get off track. And lastly, key concepts do not appear spontaneously due to the "natural" factorization. Say I model P(X_1, …, X_T) autoregressively, as P(X_T | X_1, …, X_(T−1)) × P(X_(T−1) | X_1, …, X_(T−2)) × … That is going to turn out way more complex than modeling the sequence as effectively independent given Z, as a product of terms P(X_t | Z). Now the problem is
that it's ultimately going to be a mix. Even if you know the latent concept, you still have to be consistent with the previous words, at least linguistically, so each factor is really P(X_t | X_1, …, X_(t−1), Z). But I think the point still holds: if Z carries a lot of information, the dependence on the previous tokens can be a lot simpler. Maybe you only have to be linguistically consistent and not conceptually, because the concept is already captured by Z. They then say that any latent random value, whatever its statistical dependence on the tokens and on other latent variables, can be expressed, under reasonable assumptions, as a
function of the tokens, the other latents, and a value sampled from a random generator. So if you have some source of randomness, you can shape it into basically whatever you want, and the goal is to introduce that source of randomness into the transformer. This takes us into the territory of variational autoencoders: the Free Transformer effectively takes what a variational autoencoder does and shoves it into the middle of a transformer to introduce these latent variables. So let's talk about image generation for a while, because that's where I first encountered the VAE, and it's
illustrative in a way. Let's assume you're generating cat images: on one hand a plain cat, and on the other hand a cat with big sunglasses, because, let's say, half the cats have sunglasses. You want to train a model to generate this bimodal distribution. You end up with a generative model, a box you can ask "please give me a cat". Half the time it gives you a cat without sunglasses and half the time a cat with sunglasses. This is exactly the situation we're in with the movie reviews. Now you might say: well, I'm just going to
sample a random variable from a coin flip, a Bernoulli(1/2), and feed it in here. If it's heads, the model creates this, and if it's tails, it creates that. Fantastic, that is a great idea. The question is how you teach the model to take this variable into account and to map it like that, and that is actually the problem. How do you teach the model to consider these random variables at all, and to map them to the things you want? The second question has no answer. Say the model actually pays attention to the random variable you feed in. How do you make it generate the sunglasses
when it's heads and no sunglasses when it's tails? That has no answer, unless you have training data that tells you which cats have sunglasses and which don't. In that case it's no longer a random variable: you are supplying the label, and the whole thing becomes a conditional generative model. Otherwise, you are relying on some underlying disentanglement-discovery mechanism. Suppose your data is really bimodal and this is the biggest source of variance: the cats are otherwise very uniformly distributed, except that half have sunglasses and half don't. Then you could reasonably expect that any model that actually pays attention to the random variable is going to map it
like so. But it's a bit more tricky than that. You have to force the model to learn it, to go "Oh, okay, I can pay attention to this random variable, and I should make it so that it helps me." The whole name of the game of variational autoencoders is to make the model learn, during training, that this variable carries helpful information. If it learns that during training, it will incorporate the variable into its process, and by doing so the variable starts representing the data. Then at inference we can simply flip a coin. Boom, we get heads. Because during training the model has learned to associate heads with one of the two categories, we're good. So the
name of the game is: how, during training, do we teach a variational autoencoder to pay attention to a latent variable we don't have any training data for? The answer is that we cheat, and the cheat is called an encoder. Here is your training data. What you ultimately train a variational autoencoder to do is reconstruct the training data: what comes out should look like what goes in. In a regular autoencoder, the encoder simply compresses the data and the decoder learns to decompress it again. That's fantastic as a compression tool, but it doesn't work as a generative model. When you want to generate something new, you don't have anything to compress. So we need to be a bit smarter. We also create an
encoder, but this encoder only generates this Z right here, not the full thing, and Z carries some information about this particular sample. Why doesn't it fall into the same trap as the regular autoencoder? Because, in addition to the reconstruction loss, we severely limit the amount of information that can be transmitted via Z. We hope to set that limit to just the right amount, so that in this case it can transmit exactly one bit of information. And because the data is so
bimodal and we train all of this jointly, the most useful thing the model can do is twofold. First, (a) it makes this one bit of information be about sunglasses versus no sunglasses. Second, (b) the decoder learns to pay attention to that bit and map it to sunglasses or no sunglasses. So we want to limit the information just enough that the most important things get captured in the latent variables, and we limit it in two ways. First, we can simply not make the bandwidth available: if Z is literally one bit, then that's all that can be transmitted. Second, we regularize heavily, which has an additional benefit: we regularize the Z's that come out of the encoder to follow the distribution we ultimately want to sample from. So we
may say that, in practice, we want a half-and-half distribution here: at inference time we want to sample Z from this distribution. Then we have to make sure that during training the Z's produced by the encoder follow that distribution. Otherwise, at inference we sample from a distribution the decoder isn't familiar with, and we'll be mostly out of distribution for the decoder. So the loss of a variational autoencoder takes the form of a reconstruction loss plus the KL divergence between the distribution Q
which is what the encoder produces during training, and P(Z), the distribution we would like to sample from. We want these two to be close. So we force the encoder to produce Z's that, in a distributional sense, align with the distribution we ultimately want to sample from. That teaches the decoder to make sense of this distribution and not some other one, which ensures that at inference we can sample from it and get meaningful outputs. All the while, we limit the information so that the encoder can't just cheat and give everything to the decoder.
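Here is a minimal Python sketch of that objective for the one-bit cat example. It assumes the latent is a set of independent binary bits and that the prior P(Z) is uniform, i.e. the half-and-half coin from above. The function names and the `kl_weight` knob are hypothetical. The reconstruction term is passed in as a plain number rather than computed by a real decoder.

```python
import math
import random


def bernoulli_kl_to_uniform(q):
    """KL(Bernoulli(q) || Bernoulli(1/2)) in nats for a single latent bit."""
    eps = 1e-12                      # keeps log() finite when q is exactly 0 or 1
    q = min(max(q, eps), 1.0 - eps)
    return q * math.log(2.0 * q) + (1.0 - q) * math.log(2.0 * (1.0 - q))


def vae_loss(recon_nll, bit_probs, kl_weight=1.0):
    """Reconstruction term plus KL(Q(Z|X) || P(Z)), with P(Z) uniform over independent bits.

    recon_nll -- decoder's negative log-likelihood of X given the encoder's Z (computed elsewhere)
    bit_probs -- encoder's Q(Z_h = 1 | X) for each latent bit h
    kl_weight -- hypothetical knob: a larger value lets less information pass through Z
    """
    kl = sum(bernoulli_kl_to_uniform(q) for q in bit_probs)
    return recon_nll + kl_weight * kl, kl


def sample_z(bit_probs, rng, training):
    """Training: draw Z from the encoder's Q. Inference: draw from the uniform prior P."""
    if training:
        return [int(rng.random() < q) for q in bit_probs]
    return [rng.randrange(2) for _ in bit_probs]


# An encoder that is sure this cat wears sunglasses pays about 0.64 nats (close to log 2, one bit).
loss, kl = vae_loss(recon_nll=3.2, bit_probs=[0.99])
# An encoder that ignores its input pays nothing, but also tells the decoder nothing.
loss0, kl0 = vae_loss(recon_nll=3.9, bit_probs=[0.5])
```

A single bit can never cost more than log 2 nats of KL against a fair coin. So a one-bit latent caps the information flow by construction, and the KL term is what pushes Q toward the same half-and-half distribution that `sample_z` uses at inference.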
That's essentially it; that's the Free Transformer. On the left-hand side you see a pure decoder transformer. I have my tokens, I pass them through a decoder-only transformer, and it gives me my next token. During training it gives me all the next tokens, since a causal transformer can be trained jointly over the whole sequence. In (b) you see the ideal case: at inference we simply sample a Z and then produce tokens based on the tokens so far and the latent random variable. So the sampling of Z is what makes the latent decision. During training, however, we don't supply Z. We replace it by an encoder, and the encoder gets to cheat: it gets to look ahead in time. Here we have the whole sequence. Due to the causal mask during
training of the decoder, there's no look-ahead, but the encoder gets to cheat: it looks ahead at the whole thing. As before, the encoder gets to see the very thing that should ultimately be reconstructed, encodes it into a latent variable, and supplies that to the decoder. Because the encoder cheats during training, the decoder learns to pay attention to what the encoder says. Each prediction is then based not just on the tokens but also on the encoder's cheated encoding. Suppose we can limit the information that goes through the encoder, and make the encoder's outputs follow, in a distributional sense, the way we want to sample Z at inference. Then we have ourselves a
really nice variational autoencoder. In practice, the paper does this in a more resource-conscious way by splitting the transformer into two halves. They run the sequence through the first half. Then a single small encoder block operates on the outputs of those decoder layers. So you don't need a full, super big model for the encoder; you share a lot of the computation with the decoder. The decoder layers can't cheat or look ahead, so the look-ahead still has to happen in this little encoder block, but you save yourself a lot of computation. They put it in the middle for two reasons. If you put it too early, the encoder doesn't have enough power,
because it depends on the computation below it. If you put it too late, there aren't enough layers left to pay attention to the random variable and make good use of it. Going a bit deeper: they have a learned query vector feeding a non-causal transformer block, and that's where we cheat. Then a fully connected layer gives us a set of binary variables, from which we sample. Another fully connected layer maps the sample back, and we add the result into the residual stream. That is where Z comes from during training; during inference we simply sample it from the prior distribution. The regular transformer's pseudocode is: take tokens,
push them through the blocks in succession. And this is the Free Transformer. We go through half the blocks, and then during training we use the encoder to determine the random variables. If we're not training, we use a uniform sampler, or rather a one-hot of a uniform sampler, to generate our Z. We add this back to the stream and carry on. Here it says "train or prefill". The idea is that when you're prefilling, you're loading a KV cache, and your Z variables from the past need to be consistent with the
decisions that were made in the past. But since you're prefilling, you already know the tokens, so you can be consistent with them. I guess you could also simply store the Z variables you sampled and add them to the cache. Honestly, I'm not entirely sure; or is prefilling the part where I load in the prompt? I have to say I'm not the best person to talk about these kinds of things, and it's not super important here. They also do some experiments on synthetic data, where they construct the dataset like this: at the beginning there is a
letter, followed by a long run of underscores. Somewhere in that run there is a block of eight repetitions of the letter. Some random positions also get an exclamation mark, which acts as noise. The latent decision is obviously where to place the block of eight. They then vary how much information they let through the encoder, and these experiments show the effect quite nicely. I believe the blue one is always like a
regular transformer. Sorry, not a regular one: in the blue box each sequence gets a different Z, and within each green box the same Z is sampled for all sequences. On the top left you can see there's pretty much no difference. The sequences in the blue box all have different Z's, the ones in each green box share a Z, and it all looks the same. This means not enough information goes through the encoder; it can't transmit any information at all. If we let too much information through, which is the case here, you can see that it actually
starts erroring. This is inference time, and the model starts generating invalid sequences: they're no longer blocks of eight. This means the decoder hasn't learned anything. During training, the encoder had so much capacity to transmit information that it simply told the decoder the solution, and the decoder just copied whatever it was told. It's like copying your homework and then being asked to take the test: you are the decoder. In the middle, however, is where it actually works. Here all of these have different Z's, but within a green box they share the same Z. So
you can see that Z actually captures the position of the block in the sequence. This transformer has actually learned that there are latent variables it can make use of. And, this is the part where I said there is no answer, it has intrinsically associated the latent variable with the position, because that happened to be the most useful thing to do given the structure of the data. That's what we're relying on: the structure of the overall data dictates how the decoder and encoder associate outcomes with latent variables. The only important thing is that we actually train latent variables that follow the distribution we then want to sample from, because if we don't,
then we will always be out of distribution when we actually sample from our base distribution. The decoder will be like: "I've never seen anything like that; what am I supposed to do with this?" All right. I don't want to go too much into the actual results, because benchmarks are always so tricky to interpret properly. In this particular case, the Free Transformer does excel at things like coding and math, and it does not excel at things like question answering and knowledge. Draw your own conclusions about why exactly that is. That's about all I had to say about this paper. I do think this is definitely an interesting investigation. Whether this
is going to be a smash hit in the world of large language models, I'm not so sure. I'm saying that because we're effectively introducing this cheating mechanism, and then you have to tune the hyperparameters exactly. There's always a trade-off. Even if you limit the amount of information going through and penalize the KL divergence to your prior, the encoder's outputs still won't fully match the distribution you want to sample from. You still end up with a bit of cheating here and there. All of that hurts at inference: you're not really sampling from the same distribution, and your decoder has learned to rely on something that is no longer there. So, while I can definitely see this helping for
specific use cases where there is clearly this kind of latent structure in the desired output distribution, I am a bit more bearish on this being the next big thing. But that's just my opinion; let me know yours, I'd be interested. We do paper discussions almost every Saturday evening on Discord. Feel free to come listen in, or present your own paper or someone else's. Most people present someone else's paper, and there's no requirement that you actually wrote it. That's it. Thank you very much, I'll see you around. Bye-bye.
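The architecture walked through earlier is only described in words, so here is a deliberately tiny, pure-Python sketch of where Z enters the forward pass. Everything concrete in it is an assumption, not the paper's code: the sizes H and D, the stand-in "blocks" (which just add a vector), the mean-pooling stand-in for the non-causal encoder block with its learned query, and all names. What it tries to mirror from the description is the following. It runs the first half of the blocks. An encoder that can see the whole sequence produces Z during training or prefill, via a fully connected readout to binary variables followed by sampling. A second fully connected layer's output is added to the residual stream. At generation time Z comes from a uniform one-hot sample.

```python
import math
import random

H = 4   # hypothetical: latent bits per position, so Z is a one-hot vector of size 2**H
D = 8   # hypothetical model width


def add(u, v):
    return [a + b for a, b in zip(u, v)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bits_to_index(bits):
    """Map H sampled bits to the position of the 1 in a one-hot vector of size 2**H."""
    return sum(b << i for i, b in enumerate(bits))


class FreeTransformerSketch:
    def __init__(self, n_blocks, seed=0):
        self.rng = random.Random(seed)
        g = self.rng.gauss
        self.n_blocks = n_blocks
        # Stand-in blocks: each adds a fixed vector. A real model uses causal attention + MLP.
        self.block_bias = [[g(0, 0.1) for _ in range(D)] for _ in range(n_blocks)]
        # Encoder readout: one weight row per latent bit (the fully connected layer to H logits).
        self.enc_readout = [[g(0, 1) for _ in range(D)] for _ in range(H)]
        # Post-sampler layer: a one-hot times a matrix is a row lookup, so store 2**H rows.
        self.post_sampler = [[g(0, 0.1) for _ in range(D)] for _ in range(2 ** H)]

    def _block(self, xs, b):
        return [add(x, self.block_bias[b]) for x in xs]

    def _encode_and_sample(self, xs):
        # Stand-in for the non-causal encoder block: every position also sees the mean of
        # the WHOLE sequence, including future positions. This look-ahead is the "cheat".
        ctx = [sum(col) / len(xs) for col in zip(*xs)]
        indices = []
        for x in xs:
            h = add(x, ctx)
            bits = [int(self.rng.random() < sigmoid(dot(w, h))) for w in self.enc_readout]
            indices.append(bits_to_index(bits))
        return indices

    def forward(self, xs, train_or_prefill):
        half = self.n_blocks // 2
        for b in range(half):                      # first half of the decoder
            xs = self._block(xs, b)
        if train_or_prefill:
            idx = self._encode_and_sample(xs)      # Z from the encoder
        else:
            # Z from the prior: a uniform one-hot, i.e. H independent fair bits.
            idx = [self.rng.randrange(2 ** H) for _ in xs]
        xs = [add(x, self.post_sampler[i]) for x, i in zip(xs, idx)]  # inject Z mid-stream
        for b in range(half, self.n_blocks):       # second half, now conditioned on Z
            xs = self._block(xs, b)
        return xs, idx


model = FreeTransformerSketch(n_blocks=4)
tokens = [[0.1 * t] * D for t in range(5)]         # 5 hypothetical token embeddings
out_train, z_train = model.forward(tokens, train_or_prefill=True)
out_gen, z_gen = model.forward(tokens, train_or_prefill=False)
```

Because Z is one-hot, the second fully connected layer reduces to a table lookup. Sampling one of 2^H categories uniformly is the same as sampling H fair bits, which is why the generation branch can simply call `randrange`. For simplicity this sketch samples Z for all positions in one call; during real generation it would be drawn position by position as tokens are produced.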