import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
    Thus, we use float64.
    '''
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
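
# Illustrative sanity check (our addition, not part of the original script):
# by the Gumbel-max trick, argmax over the noised scores above is equivalent to
# sampling from softmax(logits / temperature), since
# log(exp(logits) / (-log u) ** t) = logits + t * Gumbel(0, 1) noise.
# The helper name below is hypothetical; call it manually to verify.
def _check_gumbel_sampling(temperature=0.7, n_samples=20000):
    logits = torch.tensor([1.0, 0.5, -0.5])
    samples = torch.stack([torch.argmax(add_gumbel_noise(logits, temperature))
                           for _ in range(n_samples)])
    empirical = torch.bincount(samples, minlength=3).double() / n_samples
    expected = F.softmax(logits.double() / temperature, dim=-1)
    # The empirical argmax frequencies should approximate the softmax probabilities.
    print(empirical, expected)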

def get_num_transfer_tokens(mask_index, steps):
    '''
    In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
    Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
    the expected number of tokens transitioned at each step should be consistent.
    This function is designed to precompute the number of tokens that need to be transitioned at each step.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1

    return num_transfer_tokens
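
# Worked example: with 10 masked tokens and steps=4, base = 10 // 4 = 2 and
# remainder = 10 % 4 = 2, so the per-step schedule is [3, 3, 2, 2] -- each step
# unmasks roughly the same number of tokens, matching the linear noise schedule:
#
#   >>> get_num_transfer_tokens(torch.ones(1, 10, dtype=torch.bool), steps=4)
#   tensor([[3, 3, 2, 2]])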

@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336):
    '''
    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, l).
        steps: Sampling steps, less than or equal to gen_length.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, semi-autoregressive remasking is used.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy: 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
    '''
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks
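
    # Decode block by block: each block gets `steps` denoising iterations, and
    # positions past the current block stay masked (their confidence is set to
    # -inf below), which is what makes decoding semi-autoregressive when
    # block_length < gen_length.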
    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Unsupervised CFG: a second pass with the prompt masked out serves as the unconditional branch.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Tokens beyond the current block can never be selected for unmasking.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Unmask the num_transfer_tokens[:, i] highest-confidence positions this step.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x
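
# Illustrative alternative call (our example; knob values are arbitrary): one
# full-length block with stochastic Gumbel sampling and classifier-free guidance.
#
#   out = generate(model, input_ids, steps=128, gen_length=128,
#                  block_length=128, temperature=1.0, cfg_scale=1.5,
#                  remasking='random')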

def main():
    device = 'cuda'

    model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

    prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    m = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    out = generate(model, input_ids, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])


if __name__ == '__main__':
    main()