# Guidelines

Here, we provide guidelines for the model architecture, pre-training, SFT, and inference of LLaDA.

## Model Architecture

LLaDA employs a Transformer Encoder as the network architecture for its mask predictor. In terms of trainable parameters, the Transformer Encoder is identical to the Transformer Decoder. Starting from an autoregressive model, we derive the backbone of LLaDA by simply removing the causal mask from the self-attention mechanism, as follows.
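As a minimal sketch of this change (our illustration, not the actual LLaDA implementation; `self_attention` is a simplified stand-in):

```python
import torch
import torch.nn.functional as F

def self_attention(q, k, v, causal: bool = False):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    if causal:
        # Autoregressive decoder: position i attends only to positions <= i.
        seq_len = q.shape[-2]
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device),
            diagonal=1,
        )
        scores = scores.masked_fill(future, float("-inf"))
    # LLaDA's mask predictor uses causal=False, i.e., every token attends
    # to the entire sequence in both directions.
    return F.softmax(scores, dim=-1) @ v
```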
In addition, LLaDA designates a reserved token as the mask token (i.e., 126336).

## Pre-training

The pre-training of LLaDA is straightforward. Starting from an existing autoregressive model training codebase, only a few lines need to be modified. We provide the core code (i.e., the loss computation) here.

```python
import torch
import torch.nn.functional as F

def forward_process(input_ids, eps=1e-3):
    b, l = input_ids.shape
    # Sample a noise level t per sequence and mask each token with
    # probability p_mask, which is (approximately) linear in t.
    t = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, l)

    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    # 126336 is used for [MASK] token
    noisy_batch = torch.where(masked_indices, 126336, input_ids)
    return noisy_batch, masked_indices, p_mask

# The data is an integer tensor of shape (b, 4096),
# where b represents the batch size and 4096 is the sequence length.
input_ids = batch["input_ids"]

# We set 1% of the pre-training data to a random length that is uniformly sampled from the range [1, 4096].
# The following implementation is not elegant and involves some data waste.
# However, the data waste is minimal, so we ignore it.
if torch.rand(1) < 0.01:
    random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
    input_ids = input_ids[:, :random_length]

noisy_batch, masked_indices, p_mask = forward_process(input_ids)
logits = model(input_ids=noisy_batch).logits

# Compute the loss on masked positions only, weighting each token by 1 / p_mask.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
```

## SFT

First, please refer to Appendix B.1 for the preprocessing of the SFT data. After preprocessing, the data format is as follows. For simplicity, we treat each word as a token and set the batch size to 2 in the following visualization.

```
input_ids:
<BOS><start_id>user<end_id>\nWhat is the capital of France?<eot_id><start_id>assistant<end_id>\nParis.<EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS>
<BOS><start_id>user<end_id>\nWhat is the capital of Canada?<eot_id><start_id>assistant<end_id>\nThe capital of Canada is Ottawa, located in Ontario.<EOS>

prompt_lengths:
[17, 17]
```

After preprocessing the SFT data, we can obtain the SFT code by making simple modifications to the pre-training code. The key difference from pre-training is that SFT does not add noise to the prompt.

```python
input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]

noisy_batch, _, p_mask = forward_process(input_ids)

# Do not add noise to the prompt
token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(noisy_batch.size(0), noisy_batch.size(1))
prompt_mask = (token_positions < prompt_lengths.unsqueeze(1))
noisy_batch[prompt_mask] = input_ids[prompt_mask]

# Calculate the answer length (including the padded <EOS> tokens)
prompt_mask = prompt_mask.to(torch.int64)
answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1])

masked_indices = (noisy_batch == 126336)

logits = model(input_ids=noisy_batch).logits

token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]
```
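For a concrete picture of how such batches can be produced, here is a minimal sketch of the padding logic, assuming prompts and responses are already tokenized; `preprocess_sft_batch` and its arguments are hypothetical names for illustration, and the actual pipeline follows Appendix B.1:

```python
import torch

def preprocess_sft_batch(pairs, eos_id):
    # pairs: list of (prompt_ids, response_ids) token-id lists (hypothetical
    # format). Each response is terminated with <EOS> and then padded with
    # further <EOS> tokens so that all sequences share the same length; the
    # padded <EOS> tokens count as part of the answer and enter the loss.
    seqs = [p + r + [eos_id] for p, r in pairs]
    max_len = max(len(s) for s in seqs)
    input_ids = torch.tensor([s + [eos_id] * (max_len - len(s)) for s in seqs])
    prompt_lengths = torch.tensor([len(p) for p, _ in pairs])
    return input_ids, prompt_lengths
```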
## Sampling

Overall, we categorize LLaDA's sampling process into three types: fixed-length, semi-autoregressive-origin, and semi-autoregressive-padding. **It is worth noting that the semi-autoregressive-origin method was not mentioned in our paper, nor did we provide the corresponding code**. However, we include it here because we believe that sharing both our failures and insights from the exploration process is valuable. These three sampling methods are illustrated in the figure below.
For each step of the three sampling processes above, as detailed in Section 2.4 of our paper, the mask predictor first predicts all masked tokens simultaneously. Then, a certain proportion of these predictions is remasked. To determine which predicted tokens should be remasked, we can adopt two strategies: *random remasking* or *low-confidence remasking*. Notably, both remasking strategies can be applied to all three sampling processes.

For the LLaDA-Base model, we apply low-confidence remasking to all three sampling processes. We find that fixed-length and semi-autoregressive-padding achieve similar results, whereas semi-autoregressive-origin performs slightly worse.

For the LLaDA-Instruct model, the situation is slightly more complex.

First, if the semi-autoregressive-origin method is used, the Instruct model performs poorly. This is because, during SFT, each sequence is a complete sentence (whereas in pre-training, many sequences are truncated sentences). As a result, during sampling, given a generation length, whether long or short, the Instruct model tends to generate a complete sentence. Unlike the Base model, it does not encounter cases where a sentence is only partially generated and needs to be continued.

Second, when performing fixed-length sampling with a large answer length (e.g., greater than 512), we find that low-confidence remasking results in an unusually high proportion of `<EOS>` tokens in the generated sentences, which severely impacts the model's performance. In contrast, this issue does not arise when random remasking is used.

Furthermore, since low-confidence remasking achieved better results on the Base model, we hoped it could also be applied to the Instruct model. We found that combining low-confidence remasking with semi-autoregressive-padding effectively mitigates the issue of generating an excessively high proportion of `<EOS>` tokens. Moreover, this combination achieves slightly better results than random remasking with fixed-length sampling.

You can find more details about the sampling methods in our paper.
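To make the predict-then-remask step concrete, here is a minimal sketch of low-confidence remasking; `denoise_step`, the greedy prediction, and the fixed per-step unmasking budget are our simplifying assumptions, not the released sampler:

```python
import torch
import torch.nn.functional as F

MASK_ID = 126336

@torch.no_grad()
def denoise_step(model, x, num_to_unmask):
    # x: (batch, seq_len) token ids, with MASK_ID at still-masked positions.
    masked = (x == MASK_ID)
    logits = model(input_ids=x).logits  # predict all masked tokens at once
    conf, pred = F.softmax(logits, dim=-1).max(dim=-1)
    # Only currently masked positions are candidates for unmasking.
    conf = torch.where(masked, conf, torch.full_like(conf, -1.0))
    # Commit the num_to_unmask most confident predictions per sequence;
    # all other masked positions stay masked (i.e., are "remasked").
    keep = torch.zeros_like(masked)
    keep.scatter_(1, conf.topk(num_to_unmask, dim=-1).indices, True)
    x = x.clone()
    x[keep & masked] = pred[keep & masked]
    return x
```

In the fixed-length scheme, such a step would be applied repeatedly to the fully masked answer until no masked tokens remain; the semi-autoregressive variants instead apply it within one block at a time.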