Jae-Kyung Cho Being unique is better than being perfect

Diary - Diffusion language models and the LLM reversal curse

A few months ago I saw the news that a model called Gemini Diffusion had come out. First of all, I was curious about how on earth you build a language model with diffusion. Second, I watched the demo video and the insanely fast speed really caught my eye. (You absolutely have to watch the demo video!)

So I decided to read the paper that this is based on. It’s the paper from Alibaba’s Ant Group, Large Language Diffusion Models, which proposes a model known as LLaDA (though the final A isn’t actually in the title… it stands for Large Language Diffusion with mAsking).

I had a lot of fun reading it, and partway through it introduces a concept called the LLM reversal curse, which I found really fascinating. The discussion I had with a teammate about it was fun too, so I’m writing it down.

What is LLaDA?

Let’s take a quick look at what kind of model LLaDA is, and the concept behind how they trained a language model using diffusion. First, the authors argue that the core capabilities of large language models (scalability, in-context learning, instruction following) are not confined to the auto-regressive architecture alone, and can also be achieved through diffusion models.

People who’ve worked on LLMs are familiar with the auto-regressive probability formulation below.
\(p_\theta(x) = p_\theta(x^1)\prod_{i=2}^{L} p_\theta\left(x^i \mid x^1,\dots,x^{i-1}\right)\)
This says the probability of a sequence factorizes into predicting each next token conditioned on the tokens before it. If you minimize the negative log likelihood (NLL) of this over the training data, you’ve trained a language model.
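As a tiny worked example of the auto-regressive objective, here is a sketch that turns per-token conditional probabilities \(p_\theta(x^i \mid x^{<i})\) into a sequence NLL. The probabilities are made-up numbers. In a real model each one would come from the softmax at that position.

```python
import math

def autoregressive_nll(cond_probs):
    """NLL of one sequence.

    cond_probs[i] is the probability the model assigned to the true token at
    position i, conditioned only on the tokens before it.
    """
    # log p(x) = sum_i log p(x^i | x^<i), so NLL = -sum_i log p(x^i | x^<i)
    return -sum(math.log(p) for p in cond_probs)

probs = [0.5, 0.25, 0.8, 0.9]   # hypothetical per-token probabilities
nll = autoregressive_nll(probs)
print(round(nll, 4))
# The NLL is just -log of the product of the conditionals.
print(math.isclose(math.exp(-nll), 0.5 * 0.25 * 0.8 * 0.9))
```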

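As a stand-in for this NLL minimization, LLaDA proposes the following surrogate loss function.
\(\mathcal{L}(\theta) = -\mathbb{E}_{t,x_0,x_t}\left[\frac{1}{t}\sum_{i=1}^{L} 1\left[x_t^i=M\right]\log p_\theta\left(x_0^i \mid x_t\right)\right]\)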

Here, \(x_t\) is a text sequence in which each token has been masked with probability t, and \(x_0\) is the sequence with nothing masked, i.e. the ground truth. In short, the loss above has the model predict the original tokens from a masked sequence, and it computes the loss only over the masked tokens. (That’s what \(1\left[x_t^i=M\right]\) means.)
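To make the "loss only on masked tokens" part concrete, here is a minimal single-sample sketch of the objective: sum \(-\log p_\theta(x_0^i \mid x_t)\) over the masked positions and scale by 1/t. The inputs (`true_token_probs`, `is_masked`) are hypothetical stand-ins for what a real mask predictor would produce. Batching, and any extra normalization a real training script might apply, are left out.

```python
import math

def llada_loss_single(true_token_probs, is_masked, t):
    """Single-sample estimate of the masked diffusion loss.

    true_token_probs[i]: probability the model assigns to the ground-truth
        token x_0^i at position i, given the masked sequence x_t.
    is_masked[i]: True if position i is masked in x_t (the 1[x_t^i = M] term).
    t: the masking probability that produced x_t, with 0 < t <= 1.
    """
    if not 0.0 < t <= 1.0:
        raise ValueError("t must be in (0, 1]")
    # Only masked positions contribute; unmasked tokens are given for free.
    masked_nll = sum(-math.log(p) for p, m in zip(true_token_probs, is_masked) if m)
    # 1/t compensates for only about t * L positions being masked.
    return masked_nll / t

probs = [0.9, 0.5, 0.99, 0.25]       # hypothetical model outputs
masked = [False, True, False, True]   # positions 1 and 3 were masked
print(llada_loss_single(probs, masked, t=0.5))
```

Note that changing the probability at an unmasked position (e.g. the 0.9) does not change the loss at all.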

The authors saw this as being the same concept as diffusion. They viewed masking tokens in a text sequence as adding noise, and producing the text sequence from a masked token sequence as de-noising.

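There’s one point we should address here: does minimizing this LLaDA loss actually minimize the NLL? It has been shown (in work from DeepMind) that this loss is an upper bound on the NLL, so minimizing it is a principled proxy for minimizing the NLL. The proof is a bit hard, though, so only read it if you want to (I gave up partway through myself, haha).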

Looking at LLaDA through pictures

So far we’ve looked at LLaDA through formulas, and naturally it’s hard to grasp. So what exactly is being trained and how is it used? Thankfully, the authors drew really nice diagrams.

LLaDA pre-training

Anyway, since this is also a language model, it has pre-training and post-training stages. Let’s start with the pre-training stage.

LLaDA pre-training method

  1. Sample t ~ U(0,1)
  2. Mask each token independently with probability t (so roughly a fraction t of the tokens gets masked)
  3. Apply the masked token prediction loss (the LLaDA loss function above)
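Steps 1-2 are just data preparation, so they fit in a few lines. This is a minimal sketch assuming a hypothetical `MASK_ID` of -1 and made-up token ids. The small `eps` that keeps t away from 0 (so the 1/t weight stays finite) is an assumption of this sketch, not something stated above.

```python
import random

MASK_ID = -1  # hypothetical id for the [MASK] token

def forward_mask(tokens, t, rng, mask_id=MASK_ID):
    """Forward (noising) process: mask each token independently with prob t."""
    return [mask_id if rng.random() < t else tok for tok in tokens]

def sample_pretraining_example(tokens, rng, eps=1e-3):
    """Steps 1-2 of pre-training: draw t ~ U(0,1), then mask with prob t.

    eps (an assumption of this sketch) keeps t away from 0.
    """
    t = eps + (1.0 - eps) * rng.random()
    return forward_mask(tokens, t, rng), t

rng = random.Random(0)
x0 = [101, 2023, 2003, 1037, 7099, 102]   # hypothetical token ids
xt, t = sample_pretraining_example(x0, rng)
print(round(t, 3), xt)
```

Step 3 then feeds `xt` to the mask predictor and applies the loss on the positions equal to `MASK_ID`.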

Pretty simple, right? As you can see from this diagram, LLaDA has one enormous advantage over conventional auto-regressive LMs: there is no causal mask, so it attends over the entire sequence. The masked token predictor uses attention over all the tokens in the sequence to predict each masked token. In other words, it can use information from tokens that come after it.
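The difference is easiest to see in the attention mask itself. A small sketch: `allowed[q][k]` says whether query position q may attend to key position k. Printing them shows the lower triangle for an auto-regressive LM versus the full square for a LLaDA-style mask predictor.

```python
def causal_mask(n):
    """Auto-regressive LM: each position sees only itself and earlier positions."""
    return [[k <= q for k in range(n)] for q in range(n)]

def bidirectional_mask(n):
    """LLaDA-style mask predictor: every position sees every position."""
    return [[True] * n for _ in range(n)]

for name, m in [("causal", causal_mask(4)), ("bidirectional", bidirectional_mask(4))]:
    print(name)
    for row in m:
        print("".join("x" if ok else "." for ok in row))
```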

(🤪 read it or not) The paper’s detailed points:

  • It uses an architecture similar to LLaMA3. However, for simplicity it uses plain MHA instead of GQA (LLaDA can’t use a KV-cache anyway, so GQA’s main benefit doesn’t apply).
  • It uses a sequence length of 4K (4096). But for about 1% of the pre-training data, the sequence length is sampled randomly from [1, 4096] to secure generalization over various sequence lengths.
  • An 8B-class model is trained on 2.3T tokens.

LLaDA post-training

If you also perform post-training (instruction tuning) using a similar method, it becomes a model you can actually use.
Instruction tuning works much like regular LLM training: you leave the prompt untouched and train by masking only the response side. You also apply a chat template to get the formatting right.

There’s one thing we should address here: LLaDA has a fatal drawback. It has to fix the sequence length in advance at inference time, because de-noising starts from a fully masked sequence of a fixed length. For pre-training that’s fine, since the corpus can be cut to whatever length you want, but in post-training the lengths of prompt-response pairs vary wildly. So padding is needed, and here LLaDA uses a trick: instead of padding tokens, it pads with EOS (End-of-Sequence) tokens.

A conventional auto-regressive LM doesn’t compute loss over padding tokens, but LLaDA computes loss over all the EOS tokens, treating them like ordinary response tokens. In other words, the number of EOS tokens it generates at the end determines the response length.

LLaDA post-training method

  1. Pad the prompt-response pair with <|EOS|> tokens up to the fixed sequence length (the EOS padding counts as part of the response)
  2. Apply masking only to the response part, EOS padding included
  3. The rest is the same as pre-training
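Here is a minimal sketch of building one SFT sample this way. `MASK_ID`, `EOS_ID` and the token ids are hypothetical, and the prompt is simply never masked, while every response position (EOS padding included) is masked with probability t.

```python
import random

MASK_ID, EOS_ID = -1, 2   # hypothetical special-token ids

def build_sft_example(prompt, response, seq_len, t, rng):
    """Build one LLaDA-style SFT sample: (x0, xt, loss positions)."""
    pad = seq_len - len(prompt) - len(response)
    if pad < 0:
        raise ValueError("prompt + response longer than seq_len")
    # 1. Pad with EOS; the padding is treated as part of the response.
    full_response = response + [EOS_ID] * pad
    x0 = prompt + full_response
    # 2. Mask only response positions (EOS padding included) with prob t.
    xt = list(prompt)
    for tok in full_response:
        xt.append(MASK_ID if rng.random() < t else tok)
    # Loss goes on every masked position, so the model also learns where EOS goes.
    loss_positions = [i for i, tok in enumerate(xt) if tok == MASK_ID]
    return x0, xt, loss_positions

x0, xt, pos = build_sft_example([11, 12, 13], [21, 22], seq_len=8, t=0.6, rng=random.Random(0))
print(x0)   # [11, 12, 13, 21, 22, 2, 2, 2]
print(xt, pos)
```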

(🤪 read it or not) The paper’s detailed points:

  • Uses 4.5 million SFT samples

LLaDA inference

Now the most important part. So how exactly do you use it?

Basic Inference

  1. Start with the response fully masked and choose the number of de-masking (sampling) steps.
  2. Predict all the masked response tokens at once
  3. Randomly re-mask some of the newly predicted tokens (fewer than in the previous step)
  4. Repeat steps 2-3
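The basic loop can be sketched like this, assuming a hypothetical `predict(x)` that returns one token per position (prompt conditioning is folded into it). Following the paper’s random-remasking sampler as I understand it, already-decoded tokens stay fixed, and each newly predicted token is re-masked with probability s/t as the noise level goes from t = k/steps down to s = (k-1)/steps.

```python
import random

MASK = "<M>"  # hypothetical mask token

def random_remask_sample(predict, length, steps, rng):
    """Basic LLaDA-style sampling with random re-masking (a sketch)."""
    x = [MASK] * length
    for k in range(steps, 0, -1):
        t, s = k / steps, (k - 1) / steps   # noise level goes from t down to s
        pred = predict(x)
        new_x = []
        for i, tok in enumerate(x):
            if tok != MASK:
                new_x.append(tok)            # already decoded: keep it fixed
            elif rng.random() < s / t:
                new_x.append(MASK)           # re-mask with probability s/t
            else:
                new_x.append(pred[i])        # accept the new prediction
        x = new_x
    return x                                 # at the last step s = 0, so nothing stays masked

def oracle(x):
    # Toy stand-in for the model: always "predicts" the same answer.
    return list("diffusion")

print("".join(random_remask_sample(oracle, length=9, steps=4, rng=random.Random(0))))
```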

Amazingly, it doesn’t generate tokens one at a time: it predicts all the masked tokens at once, then masks some of them again. So at every step, every still-masked position in the response is predicted again, and the set of tokens that stay fixed grows step by step.
With this approach, if you set the number of de-masking steps equal to the response length, one token gets fixed per step, so it ends up looking like it’s generating tokens one at a time. More de-masking steps give better quality; fewer steps give higher inference throughput.

Now the key is deciding which tokens to re-mask. The basic approach picks them at random, but the authors propose better methods. The first is to re-mask the tokens that the model predicted with low probability.

Low-confidence inference

  1. Start with the response fully masked and choose the number of de-masking (sampling) steps.
  2. Predict all the masked response tokens at once
  3. Re-mask the low-confidence (=low-probability) predictions (fewer than in the previous step)
  4. Repeat steps 2-3
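A minimal sketch of low-confidence re-masking, assuming a hypothetical `predict(x)` that returns a token and its probability for every position. After step k, the sketch keeps floor(length * k / steps) tokens decoded: already-decoded ones first, then the most confident new predictions. Everything else goes back to the mask. The function returns the intermediate states so you can watch the decoding order.

```python
MASK = "<M>"  # hypothetical mask token

def low_confidence_sample(predict, length, steps):
    """LLaDA-style sampling with low-confidence re-masking (a sketch)."""
    x = [MASK] * length
    history = []
    for k in range(1, steps + 1):
        tokens, probs = predict(x)
        n_keep = length * k // steps     # integer floor(length * k / steps)
        # Already-decoded positions rank first, then masked ones by confidence.
        ranked = sorted(range(length), key=lambda i: (x[i] != MASK, probs[i]), reverse=True)
        keep = set(ranked[:n_keep])
        x = [(x[i] if x[i] != MASK else tokens[i]) if i in keep else MASK
             for i in range(length)]
        history.append(list(x))
    return history

def toy_predict(x):
    # Toy model: fixed predictions with made-up confidences.
    return list("diffusion"), [0.2, 0.9, 0.5, 0.8, 0.1, 0.7, 0.3, 0.6, 0.4]

for state in low_confidence_sample(toy_predict, length=9, steps=3):
    print(" ".join(state))
```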

But interestingly, if you do this, the EOS tokens get generated first. That’s because EOS is the token seen most often during post-training (each sample contains many EOS padding tokens, so the absolute amount of EOS training is huge), so the model predicts EOS with very high confidence. As a result, the end of the response fills up with EOS early and you don’t get proper answers. That’s why they proposed the new method below.

Semi-autoregressive inference

  1. Start with the response fully masked and choose the number of de-masking (sampling) steps.
  2. Divide the response into fixed-size blocks
  3. Starting from the front block, predict the masked tokens in the current block
  4. Re-mask the low-confidence tokens within that block (fewer than in the previous step)
  5. Once a block is fully decoded, move on to the next block and repeat 3-4
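The block-wise variant is a small change to low-confidence re-masking: the ranking and decoding only happen inside the current block. This sketch assumes a hypothetical `predict(x)` returning `(tokens, probs)` and a fixed number of steps per block. The toy model is deliberately overconfident about trailing EOS tokens.

```python
MASK, EOS = "<M>", "<EOS>"  # hypothetical special tokens

def semi_ar_sample(predict, length, block_size, steps_per_block):
    """Semi-autoregressive sampling: low-confidence re-masking, block by block."""
    x = [MASK] * length
    history = []
    for start in range(0, length, block_size):
        block = list(range(start, min(start + block_size, length)))
        for k in range(1, steps_per_block + 1):
            tokens, probs = predict(x)
            n_keep = len(block) * k // steps_per_block
            # Rank only positions in the current block; decoded ones come first.
            ranked = sorted(block, key=lambda i: (x[i] != MASK, probs[i]), reverse=True)
            for i in ranked[:n_keep]:
                if x[i] == MASK:
                    x[i] = tokens[i]     # later blocks stay masked until their turn
            history.append(list(x))
    return history

def toy_predict(x):
    # Toy model that is (over)confident about the EOS tail.
    return ["Hi", "there", EOS, EOS], [0.4, 0.3, 0.99, 0.99]

for state in semi_ar_sample(toy_predict, length=4, block_size=2, steps_per_block=2):
    print(state)
```

With plain low-confidence re-masking over the whole response, positions 2 and 3 (probability 0.99) would be decoded first. Here the block structure forces the front of the answer to be decided before the EOS tail.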

Instead of generating the entire response in one shot, it divides it into blocks and generates them starting from the front block. This way you can prevent the phenomenon of EOS being generated first.

So how good is LLaDA, really? (feat. reversal curse)

Naturally, the paper is full of benchmarks showing LLaDA is good. The performance wasn’t anything that really stands out, so I won’t record it separately (those curious can go look it up themselves). Instead, there’s a claim that LLaDA solved a really interesting phenomenon, the LLM reversal curse, so I’ll cover that.

Reversal curse

The authors propose a task called Poem completion. It uses a dataset of 496 very famous Chinese poems, each made up of two lines. The forward task gives the first line and asks for the second line, and the reversal task gives the second line and asks for the first. These poems are so famous that they are almost certainly included in the model’s training data. In other words, they measured forward/reversal task performance on a model that has already learned these 496 poems.

The results are as follows.
The auto-regressive LM has very poor reversal performance, whereas LLaDA performs similarly on both tasks. Even with the same data fed in, an LLM that learned text uni-directionally (left to right, under a causal mask) can’t properly solve the reversal task. On the other hand, LLaDA, which learned text bi-directionally, shows only a small gap between the forward and reversal tasks!! (Of course its performance is low to begin with, but the authors didn’t mention that part…)

The conversation I had with a teammate about this part was really interesting.
My argument was:

If an LLM has acquired some knowledge, but whether it can answer depends on the direction of the question, that isn’t true intelligence. So the reversal curse experiment shows that the auto-regressive LM is only pretending to be intelligent and is in fact nothing more than a stochastic parrot.

My teammate’s argument was:

We all know the English alphabet. But if someone asked you to recite it backward, how many people could do it right away? It depends on how you define knowledge, but it’s not strange for knowledge to have a direction. Reciting it backward is hard precisely because you’ve never practiced it, but since you do know the alphabet, you can derive the backward order with a bit of reasoning.

In the end I was convinced!!

📌 Actually, this is similar to the FIM (Fill-in-the-middle) task in the code domain

Proposed in the paper Efficient training of language models to fill in the middle

  • A typical left-to-right LLM can’t naturally solve the problem of being given the code before and after a gap and filling in the middle
  • On the code side, Coder models are trained with separate FIM-formatted data specifically to handle this
  • You can solve it by feeding in FIM data, but you have to feed in an enormous amount
  • So LLaDA, which sees context in both directions, has potential as a code-specialized model (this is also what Gemini diffusion is paying attention to)
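To show why the two approaches differ, here is a sketch contrasting the FIM data transformation in prefix-suffix-middle (PSM) order with how a bidirectional masked model would see the same infilling task. The sentinel strings `<PRE>`, `<SUF>`, `<MID>` and `<M>` are placeholders, not any specific tokenizer’s tokens.

```python
def fim_psm(doc, i, j, pre="<PRE>", suf="<SUF>", mid="<MID>"):
    """FIM transformation: split at i and j, then reorder as prefix, suffix, middle
    so a left-to-right model has seen both sides before generating the middle."""
    prefix, middle, suffix = doc[:i], doc[i:j], doc[j:]
    return pre + prefix + suf + suffix + mid + middle

def masked_infill_input(tokens, i, j, mask="<M>"):
    """A bidirectional masked model needs no reordering: keep prefix and suffix
    in place and mask the middle span. The middle length (j - i) must be
    chosen up front, just like LLaDA's fixed response length."""
    return tokens[:i] + [mask] * (j - i) + tokens[j:]

print(fim_psm("x = a + b", 4, 7))
print(masked_infill_input(["def", "add(a,", "b):", "return", "a", "+", "b"], 4, 6))
```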

LLaDA’s limitations

LLaDA actually still has more drawbacks. That’s probably why Gemini diffusion has only put up benchmarks and a demo video on its web page and hasn’t been able to release it yet.

Inefficiency of inference

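First, auto-regressive LMs have various algorithms and kernels built on top of the causal mask, and the KV-cache is one of them. Under a causal mask, a token’s keys and values never change once computed, so they can be cached and each new token only has to attend over the cache. This hugely important trick makes the cost of each generated token grow linearly with sequence length instead of re-running the whole sequence. LLaDA, however, uses bidirectional attention, so whenever tokens get unmasked, the representations at every position can change. It has to recompute attention over the entire sequence at every step, so it can’t use a KV-cache.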

You can see the difference by doing a simple time-complexity calculation.

  • C : prompt length, L : response length, d : hidden dimension, T : sampling steps
    • LLaDA: \(O(T\cdot[(C+L)^2d + (C+L)d^2])\)
    • ARM (w/o KV-cache): \(O(\sum_{i=1}^{L}[(C+i)^2d + (C+i)d^2])\)
    • ARM (w/ KV-cache, ignoring the one-time prompt prefill): \(O(L\cdot[d^2+Cd])\)
  • Plugging in C=1024, L=256, d=4096, T=32 gives
    • LLaDA: 9.0E+11
    • ARM (w/o KV-cache): 6.35E+12
    • ARM (w/ KV-cache): 5.4E+9
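These numbers are easy to reproduce. A small sketch, assuming the per-token cost is attention (n²·d) plus linear layers (n·d²) for a sequence of length n. Using d² for the linear term is what reproduces the 9E+11 and 6.35E+12 figures. Constants and the number of layers are ignored, and so is the prompt prefill in the KV-cache case.

```python
def llada_cost(C, L, d, T):
    # T full forward passes over C + L tokens: attention n^2 d + linear layers n d^2
    n = C + L
    return T * (n * n * d + n * d * d)

def arm_cost_no_cache(C, L, d):
    # Each of the L decoding steps re-runs the whole sequence of length C + i
    return sum((C + i) ** 2 * d + (C + i) * d * d for i in range(1, L + 1))

def arm_cost_kv_cache(C, L, d):
    # Per new token: linear layers (d^2) + attention over ~C cached keys (C d);
    # the one-time prompt prefill is ignored here
    return L * (d * d + C * d)

C, L, d, T = 1024, 256, 4096, 32
for name, cost in [("LLaDA", llada_cost(C, L, d, T)),
                   ("ARM w/o KV-cache", arm_cost_no_cache(C, L, d)),
                   ("ARM w/ KV-cache", arm_cost_kv_cache(C, L, d))]:
    print(f"{name:>18}: {cost:.2e}")
```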

Unfortunately, an efficient attention algorithm for LLaDA hasn’t been developed yet. That’s a shame, because if one were, LLaDA would be well worth pitting against auto-regressive LMs.

Validation on large models

As with most academic papers, LLaDA hasn’t yet been validated on large models. The largest model used in the paper is 8B, which is way too small compared to SOTA models. I’m just waiting for Google to release Gemini diffusion soon: https://deepmind.google/models/gemini-diffusion/

Just the speeds Google reports are pretty staggering.

  • Gemini diffusion is 1479 tokens/sec (based on 32k length generation)
  • Gemini 2.0 flash lite is 181 tokens/sec

For reference, Qwen3-0.6B has a generation throughput of 414.17 tokens/sec (measured with sglang). At Gemini diffusion’s speed, inference is basically finished the moment you submit the prompt.

Some scattered thoughts from actually using it

  • If you reduce the sampling steps too much, the output is a total mess, especially in semi-autoregressive mode.
  • But the fact that the user can adjust the speed at all is pretty interesting.
  • The answer quality was better than I expected.
  • It might turn out to be genuinely useful on the code side.