# Guidelines
Here, we provide guidelines for the model architecture, pre-training, SFT, and inference of LLaDA.
## Model Architecture
LLaDA employs a Transformer Encoder as the network architecture for its mask predictor.
In terms of trainable parameters, the Transformer Encoder is identical to the Transformer
Decoder. Starting from an autoregressive model, we derive the backbone of LLaDA by simply
removing the causal mask from the self-attention mechanism.
In addition, LLaDA designates a reserved token as the mask token (i.e., 126336).
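
To make the difference concrete, the following library-free sketch builds the attention visibility pattern with and without the causal mask. The helper `attention_visibility` is a hypothetical name used only for illustration and is not part of the LLaDA code.

```python
MASK_ID = 126336  # reserved [MASK] token id stated above


def attention_visibility(seq_len, causal):
    """Return a seq_len x seq_len matrix of booleans.

    Entry [i][j] is True when query position i may attend to key position j.
    """
    return [[(j <= i) or not causal for j in range(seq_len)]
            for i in range(seq_len)]


causal = attention_visibility(4, causal=True)           # autoregressive decoder
bidirectional = attention_visibility(4, causal=False)   # mask predictor without causal mask

for name, pattern in (("causal", causal), ("bidirectional", bidirectional)):
    print(name)
    for row in pattern:
        print(" ".join("x" if visible else "." for visible in row))
```

With the causal mask, position `i` sees only positions `0..i`; without it, every position sees the whole sequence, which is what lets the mask predictor use context on both sides of a masked token.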
## Pre-training
Pre-training LLaDA is straightforward. Starting from existing training code for an
autoregressive model, only a few lines need to be modified.
We provide the core code (i.e., the loss computation) here.
```python
import torch
import torch.nn.functional as F


def forward_process(input_ids, eps=1e-3):
    b, l = input_ids.shape
    # Sample one masking ratio per sequence and map it to [eps, 1).
    t = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, l)

    # Mask each token independently with probability p_mask.
    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    # 126336 is used for [MASK] token
    noisy_batch = torch.where(masked_indices, 126336, input_ids)
    return noisy_batch, masked_indices, p_mask


# The data is an integer tensor of shape (b, 4096),
# where b represents the batch size and 4096 is the sequence length.
input_ids = batch["input_ids"]

# We set 1% of the pre-training data to a random length that is uniformly sampled from the range [1, 4096].
# The following implementation is not elegant and involves some data waste.
# However, the data waste is minimal, so we ignore it.
if torch.rand(1) < 0.01:
    random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
    input_ids = input_ids[:, :random_length]

noisy_batch, masked_indices, p_mask = forward_process(input_ids)
logits = model(input_ids=noisy_batch).logits

# Cross-entropy on masked positions only, reweighted by 1 / p_mask.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
```
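
The masking and the `1 / p_mask` weighting can also be traced by hand on a single sequence. The sketch below uses only the standard library; `forward_process_single` and `weighted_loss` are hypothetical helpers written for this illustration, and the per-token losses are made-up numbers rather than model outputs.

```python
import random

MASK_ID = 126336  # [MASK] token id used in the code above


def forward_process_single(tokens, eps=1e-3, rng=random):
    """Mask one token sequence the same way as forward_process, without torch."""
    t = rng.random()                # masking ratio t in [0, 1)
    p_mask = (1 - eps) * t + eps    # probability in [eps, 1), never exactly 0
    masked = [rng.random() < p_mask for _ in tokens]
    noisy = [MASK_ID if m else tok for tok, m in zip(tokens, masked)]
    return noisy, masked, p_mask


def weighted_loss(token_nlls, masked, p_mask):
    """Sum the losses of masked positions, weight by 1 / p_mask, average over length."""
    total = sum(nll / p_mask for nll, m in zip(token_nlls, masked) if m)
    return total / len(token_nlls)


rng = random.Random(0)
tokens = [11, 22, 33, 44, 55, 66, 77, 88]
noisy, masked, p_mask = forward_process_single(tokens, rng=rng)
dummy_nlls = [1.0] * len(tokens)  # made-up per-token losses
print(noisy, round(p_mask, 3), weighted_loss(dummy_nlls, masked, p_mask))
```

On average only a `p_mask` fraction of the tokens is masked and contributes to the sum, so dividing by `p_mask` keeps the expected scale of the loss comparable across sequences with very different masking ratios.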
## SFT
First, please refer to Appendix B.1 for the preprocessing of the SFT data. After preprocessing,
the data has the following format. For simplicity, we treat each word as a token and set the batch size to 2
in the following visualization.
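```text
input_ids:
<BOS><start_id>user<end_id>\nWhat is the capital of France?<eot_id><start_id>assistant<end_id>\nParis.<EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS>
<BOS><start_id>user<end_id>\nWhat is the capital of Canada?<eot_id><start_id>assistant<end_id>\nThe capital of Canada is Ottawa, located in Ontario.<EOS>

prompt_lengths:
[17, 17]
```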
After preprocessing the SFT data, we can obtain the SFT code by making simple modifications to the pre-training code.
The key difference from pre-training is that SFT does not add noise to the prompt.
```python
# forward_process is the function defined in the pre-training code above.
input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]

noisy_batch, _, p_mask = forward_process(input_ids)

# Do not add noise to the prompt
token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(noisy_batch.size(0), noisy_batch.size(1))
prompt_mask = (token_positions < prompt_lengths.unsqueeze(1))
noisy_batch[prompt_mask] = input_ids[prompt_mask]

# Calculate the answer length (including the padded <EOS> tokens)
prompt_mask = prompt_mask.to(torch.int64)
answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1])

# After restoring the prompt, only answer positions can still be masked.
masked_indices = (noisy_batch == 126336)

logits = model(input_ids=noisy_batch).logits

token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]
```
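
The two SFT-specific steps, restoring the prompt and computing the answer length, can be checked on one sequence without torch. This is a minimal sketch: `restore_prompt` and `answer_length` are hypothetical helpers, and the 29-token sequence with a 17-token prompt is a made-up example size.

```python
MASK_ID = 126336  # [MASK] token id used in the code above


def restore_prompt(noisy, clean, prompt_length):
    """Undo masking on the first prompt_length positions of one sequence."""
    return [c if i < prompt_length else n
            for i, (n, c) in enumerate(zip(noisy, clean))]


def answer_length(seq_len, prompt_length):
    """Number of non-prompt positions, padding included."""
    return seq_len - prompt_length


clean = list(range(100, 129))      # 29 placeholder token ids
noisy = [MASK_ID] * len(clean)     # worst case: every position was masked
noisy = restore_prompt(noisy, clean, prompt_length=17)

print(noisy[:17] == clean[:17])          # prompt is untouched
print(answer_length(len(clean), 17))     # normalizer for this sequence's loss
```

Because the prompt is always restored, the loss is computed on answer positions only, which is why the SFT code normalizes by the answer length instead of the full sequence length.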
## Sampling
Overall, we categorize LLaDA's sampling process into three types: fixed-length, semi-autoregressive-origin, and semi-autoregressive-padding.
**It is worth noting that the semi-autoregressive-origin method was not mentioned in our paper, nor did we provide the corresponding code**.
However, we include it here because we believe that sharing both our failures and insights from the exploration process is valuable.
These three sampling methods are illustrated in the figure below.
At each step of the three sampling processes above, as detailed in Section 2.4 of our paper, the mask predictor
first predicts all masked tokens simultaneously. Then, a certain proportion of these predictions is remasked.
To determine which predicted tokens should be remasked, we can adopt one of two strategies: *random remasking* or
*low-confidence remasking*. Notably, both remasking strategies can be applied to all three sampling processes
mentioned above.
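
The two remasking strategies can be sketched for a single step as follows. This is an illustration under stated assumptions, not the released sampling code: `remask` is a hypothetical helper, the confidences are made-up values, and the number of predictions to keep (`num_keep`) is passed in directly instead of being derived from a sampling schedule.

```python
import random

MASK_ID = 126336  # [MASK] token id


def remask(predicted, confidences, was_masked, num_keep,
           strategy="low_confidence", rng=random):
    """Keep num_keep of the newly predicted tokens and remask the rest.

    predicted:   token ids after the mask predictor filled every masked slot
    confidences: probability assigned to each predicted token (made-up here)
    was_masked:  True where the position was masked before this step
    """
    candidates = [i for i, m in enumerate(was_masked) if m]
    if strategy == "low_confidence":
        # Keep the most confident predictions; the least confident are remasked.
        ranked = sorted(candidates, key=lambda i: confidences[i], reverse=True)
    elif strategy == "random":
        ranked = candidates[:]
        rng.shuffle(ranked)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    keep = set(ranked[:num_keep])
    return [tok if (not was_masked[i] or i in keep) else MASK_ID
            for i, tok in enumerate(predicted)]


predicted = [5, 9, 7, 3]
confidences = [1.0, 0.2, 0.9, 0.6]
was_masked = [False, True, True, True]

low_conf = remask(predicted, confidences, was_masked, num_keep=2)
rand = remask(predicted, confidences, was_masked, num_keep=2,
              strategy="random", rng=random.Random(0))
print(low_conf)
print(rand)
```

In both cases positions that were already unmasked are never touched; the strategies differ only in how the surviving predictions are chosen.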
For the LLaDA-Base model, we apply low-confidence remasking to the three sampling processes mentioned above.
We find that fixed-length and semi-autoregressive-padding achieve similar results, whereas semi-autoregressive-origin
performs slightly worse.
For the LLaDA-Instruct model, the situation is slightly more complex.
First, if the semi-autoregressive-origin method is used,
the Instruct model performs poorly. This is because, during SFT, each sequence is a complete sentence (whereas in pre-training,
many sequences are truncated sentences). As a result, during sampling, given a generated length, regardless of whether it is
long or short, the Instruct model tends to generate a complete sentence. Unlike the Base model, it does not encounter cases
where a sentence is only partially generated and needs to be continued.
When performing fixed-length sampling with a long answer length (e.g., greater than 512),
we find that low-confidence remasking results in an unusually high proportion of `<EOS>` tokens in
the generated sentences, which severely degrades the model's performance. In contrast, this
issue does not arise when random remasking is used.
Furthermore, since low-confidence remasking achieved better results with the Base model, we also hoped to apply it to
the Instruct model. We found that combining low-confidence remasking with semi-autoregressive-padding effectively mitigates
the issue of generating an excessively high proportion of `<EOS>` tokens. Moreover, this combination achieves
slightly better results than random remasking with fixed-length sampling.
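
As a rough illustration of the block structure behind semi-autoregressive-padding, the sketch below assumes that the whole answer is initialized with mask tokens, split into equal-sized blocks, and finalized block by block from left to right. That reading, as well as the helpers `block_ranges` and `eligible_positions`, is an assumption made for this example; the paper remains the reference for the exact procedure.

```python
def block_ranges(prompt_length, answer_length, block_length):
    """Yield (start, end) index pairs for consecutive answer blocks, left to right."""
    if answer_length % block_length != 0:
        raise ValueError("answer_length must be a multiple of block_length")
    for start in range(prompt_length, prompt_length + answer_length, block_length):
        yield start, start + block_length


def eligible_positions(is_masked, block_end):
    """Masked positions that may be finalized while working on the block ending at block_end."""
    return [i for i, m in enumerate(is_masked) if m and i < block_end]


# Hypothetical sizes: a 17-token prompt followed by a 12-token answer in blocks of 4.
blocks = list(block_ranges(prompt_length=17, answer_length=12, block_length=4))
print(blocks)

# While the first block is active, later (still masked) positions are not eligible.
is_masked = [False] * 17 + [True] * 12
print(eligible_positions(is_masked, block_end=blocks[0][1]))
```

Restricting which masked positions may be finalized at each stage is one way to keep low-confidence remasking from committing to tokens far to the right too early.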
You can find more details about the sampling method in our paper.