import gc
import copy
import time

from tenacity import RetryError
from tenacity import retry, stop_after_attempt, wait_fixed

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    MinLengthLogitsProcessor
)


def get_output_batch(
    model, tokenizer, prompts, generation_config
):
    if len(prompts) == 1:
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded


# StreamModel is borrowed from the basaran project
# please find more info about it -> https://github.com/hyperonym/basaran
class StreamModel:
    """StreamModel wraps around a language model to provide stream decoding."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.processor = LogitsProcessorList()
        self.processor.append(TemperatureLogitsWarper(0.9))
        self.processor.append(TopPLogitsWarper(0.75))

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt."""
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Number of generation steps to accumulate before yielding a
        # decoded chunk; must be bigger than 1.
        chunk_size = 2
        chunk_count = 0

        # Generate completion tokens.
        # Use an integer dtype so the tokenizer always receives valid token IDs.
        final_tokens = torch.empty(0, dtype=torch.long)

        for tokens in self.generate(
            input_ids[None, :].repeat(n, 1),
            logprobs=logprobs,
            min_new_tokens=min_tokens,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            if chunk_count < chunk_size:
                chunk_count = chunk_count + 1

            final_tokens = torch.cat((final_tokens, tokens.to("cpu")))

            if chunk_count == chunk_size - 1:
                chunk_count = 0
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

        if chunk_count > 0:
            yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

        del final_tokens, input_ids
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def _infer(self, model_fn, **kwargs):
        with torch.inference_mode():
            return model_fn(**kwargs)

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model."""

        # Store the original batch size and input length.
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Separate model arguments from generation config.
        config = self.model.generation_config
        config = copy.deepcopy(config)
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True

        # Collect special token IDs.
        pad_token_id = config.pad_token_id
        bos_token_id = config.bos_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # Generate from eos if no input is specified.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # Keep track of which sequences are already finished.
        unfinished = input_ids.new_ones(batch_size)

        # Start auto-regressive generation.
        while True:
            inputs = self.model.prepare_inputs_for_generation(
                input_ids, **kwargs
            )  # noqa: E501
            outputs = self._infer(
                self.model,
                **inputs,
                # return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Pre-process the probability distribution of the next tokens.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = self.processor(input_ids, logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # Select deterministic or stochastic decoding strategy.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)

            tokens = tokens.squeeze(1)

            # Finished sequences should have their next token be a padding.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            # Append selected tokens to the inputs.
            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # Mark sequences with eos tokens as finished.
            if eos_token_id is not None:
                not_eos = sum(tokens != i for i in eos_token_id)
                unfinished = unfinished.mul(not_eos.long())

            # Set status to -1 if exceeded the max length.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            # Yield predictions and status.
            yield tokens

            # Stop when finished or exceeded the max length.
            if status.max() <= 0:
                break
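

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The checkpoint name "gpt2"
# and the sampling settings below are illustrative assumptions only, chosen
# to show how get_output_batch and StreamModel are intended to be called;
# substitute whichever causal LM the rest of the project actually loads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GenerationConfig

    model_name = "gpt2"  # assumed placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    if torch.cuda.is_available():
        model = model.cuda()

        # Batch (non-streaming) generation; this path moves inputs with
        # .cuda(), so it is only exercised when a GPU is present.
        gen_config = GenerationConfig(temperature=0.9, top_p=0.75, do_sample=True)
        print(get_output_batch(model, tokenizer, ["Hello, my name is"], gen_config))

    # Streaming generation: the cumulative decoded text is yielded as chunks
    # of tokens are produced, so a UI can render the completion incrementally.
    stream_model = StreamModel(model, tokenizer)
    for partial_text in stream_model(
        "Hello, my name is", max_tokens=32, temperature=0.9, top_p=0.75
    ):
        print(partial_text)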