import torch
import torch.nn.functional as F


def apply_temperature(scores, tempt):
    """Divide ``scores`` by the temperature ``tempt``.

    A non-positive temperature leaves ``scores`` unchanged.
    """
    if tempt > 0:
        scores = scores / tempt
    return scores


def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """Nucleus (top-p) filtering.

    Replaces with ``filter_value`` the lowest-probability tokens whose
    cumulative probability is at most ``1 - top_p``, always keeping at least
    ``min_tokens_to_keep`` tokens. Only active when ``0 < top_p < 1``.

    The mask is scattered back along dim 1, so ``scores`` is assumed to be
    2-D (presumably (batch, vocab) — confirm against callers).
    """
    if 0 < top_p < 1:
        # Ascending sort: the tokens to drop accumulate at the front.
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Remove tokens whose cumulative probability stays inside the
        # discarded (1 - top_p) tail.
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        if min_tokens_to_keep > 1:
            # The last columns hold the most likely tokens; always keep them.
            sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
        # Map the mask from sorted order back to the original token order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores


def apply_top_k(logits, top_k):
    """Keep only the ``top_k`` largest logits; set the rest to ``-inf``.

    ``logits`` is modified in place and also returned. ``top_k`` is clamped
    to the last dimension's size; ``top_k <= 0`` disables filtering.
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Everything strictly below the k-th largest value is removed.
        kth_largest = torch.topk(logits.float(), top_k)[0][..., -1, None]
        indices_to_remove = logits < kth_largest
        logits[indices_to_remove] = -float("Inf")
    return logits


def apply_advanced_repetition_penalty(
    input_ids, scores, penalty_range, penalty_slope, penalty
):
    """Penalize tokens that already occur in ``input_ids``.

    For every token id present in (the last ``penalty_range`` positions of)
    ``input_ids``, its score is multiplied by ``penalty`` if non-positive and
    divided by ``penalty`` otherwise. ``scores`` is updated in place via
    ``scatter_`` and also returned.

    Args:
        input_ids: 2-D tensor of previously seen token ids (indexed along dim 1).
        scores: 2-D score tensor gathered/scattered along dim 1.
        penalty_range: number of most recent positions to consider; ``<= 0``
            means all positions.
        penalty_slope: when non-zero (and ``penalty_range > 1``), the penalty
            ramps along the window so that older positions are penalized less.
        penalty: penalty strength; ``1.0`` disables the penalty entirely.
    """
    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)

    if penalty != 1.0:
        if penalty_range > 0:
            if clipped_penalty_range < input_ids.shape[1]:
                input_ids = input_ids[..., -clipped_penalty_range:]

            # The ramp normalizes positions by (penalty_range - 1); with a
            # single-position window that is 0/0 and poisons the scores with
            # NaN. A one-point ramp evaluated at its end equals the scalar
            # penalty anyway, so the scalar is used in that case.
            if penalty_slope != 0 and penalty_range > 1:
                _penalty = (
                    torch.arange(
                        penalty_range, dtype=scores.dtype, device=scores.device
                    )
                    / (penalty_range - 1)
                ) * 2.0 - 1
                _penalty = (penalty_slope * _penalty) / (
                    1 + torch.abs(_penalty) * (penalty_slope - 1)
                )
                _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
                penalty = _penalty[..., -clipped_penalty_range:]

        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score <= 0, score * penalty, score / penalty)
        scores.scatter_(1, input_ids, score)

    return scores


class LmGeneration:
    """Batched autoregressive sampler around a language model and tokenizer."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, args, prompts, cut_off=None, cut_off_times=1):
        """Sample continuations for ``prompts`` and return the decoded rows.

        Args:
            args: namespace read for ``batch_size``, ``tokenizer``,
                ``seq_length``, ``top_k``, ``top_p``, ``temperature``,
                ``repetition_penalty_range``, ``repetition_penalty_slope`` and
                ``repetition_penalty``.
            prompts: list of prompt strings; at most ``args.batch_size``.
            cut_off: optional stop string. A row stops once its decoded text
                ends with it for the ``cut_off_times``-th time.
            cut_off_times: how many ``cut_off`` matches a row may see before
                it stops.

        Returns:
            One decoded string per prompt (prompt tokens included), cut at
            the first pad token and at the first EOS token.
        """
        if not prompts:
            return []
        if cut_off is not None:
            # One independent countdown per prompt.
            cut_off_times = [cut_off_times for _ in range(len(prompts))]

        batch = len(prompts)
        assert batch <= args.batch_size

        # NOTE(review): prompts are encoded with ``args.tokenizer`` while
        # padding, EOS detection and decoding use ``self.tokenizer``;
        # presumably the same object — confirm.
        prompt_tokens = [args.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        min_prompt_len = min(len(x) for x in prompt_tokens)
        total_len = args.seq_length

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokens = torch.full(
            (batch, total_len),
            self.tokenizer.pad_token,
            dtype=torch.long,
            device=device,
        )
        for idx, t in enumerate(prompt_tokens):
            tokens[idx, : len(t)] = torch.tensor(t).long()
        # True where a position is already filled by the prompt; those
        # positions are copied through instead of being sampled.
        mask = tokens != self.tokenizer.pad_token

        prev_pos = 0
        active = list(range(batch))
        with torch.no_grad():
            for cur_pos in range(min_prompt_len, total_len):
                # Only the not-yet-seen slice [prev_pos, cur_pos) is fed, plus
                # the active row indices; presumably the model keeps its own
                # per-row cache and returns next-token logits of shape
                # (len(active), vocab) — confirm.
                logits = self.model.forward(
                    tokens[active, prev_pos:cur_pos], prev_pos, active
                ).float()
                next_token_scores = apply_top_k(logits, top_k=args.top_k)
                next_token_scores = apply_top_p(next_token_scores, args.top_p)
                next_token_scores = apply_temperature(next_token_scores, args.temperature)
                next_token_scores = apply_advanced_repetition_penalty(
                    tokens[active, :cur_pos],
                    next_token_scores,
                    args.repetition_penalty_range,
                    args.repetition_penalty_slope,
                    args.repetition_penalty,
                )
                probs = F.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                next_token = next_token.reshape(-1)
                # Rows whose prompt is longer than cur_pos keep their prompt token.
                next_token = torch.where(
                    mask[active, cur_pos], tokens[active, cur_pos], next_token
                )
                tokens[active, cur_pos] = next_token
                prev_pos = cur_pos

                active = self._still_active(
                    tokens, active, cur_pos, cut_off, cut_off_times
                )
                if not active:
                    break

        return [self._decode_row(t, args.seq_length) for t in tokens.tolist()]

    def _still_active(self, tokens, active, cur_pos, cut_off, cut_off_times):
        """Return the rows of ``active`` that should keep generating.

        A row stops once it contains the EOS token or, when ``cut_off`` is set,
        once its decoded prefix up to ``cur_pos`` ends with ``cut_off`` and its
        countdown in ``cut_off_times`` is exhausted (the countdown list is
        decremented in place). Only currently active rows are examined, so a
        row that already stopped is never resumed.
        """
        still_active = []
        for i, t in zip(active, tokens[active].tolist()):
            if self.tokenizer.eos_token in t:
                continue
            if cut_off is not None:
                text = self.tokenizer.decode(t[: cur_pos + 1])
                if text[-len(cut_off):] == cut_off:
                    if cut_off_times[i] == 1:
                        continue
                    cut_off_times[i] -= 1
            still_active.append(i)
        return still_active

    def _decode_row(self, t, max_len):
        """Decode one row of token ids, cut at the first pad and the first EOS."""
        t = t[:max_len]
        # Each cut is applied independently, so a row without padding still
        # has its EOS (and anything after it) removed.
        for stop_token in (self.tokenizer.pad_token, self.tokenizer.eos_token):
            if stop_token in t:
                t = t[: t.index(stop_token)]
        return self.tokenizer.decode(t)