from typing import Any, Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MAX_TOKENS_IN_BATCH = 4_000  # Hard limit to prevent OOMs
DEFAULT_MAX_NEW_TOKENS = 10  # By default limit the output to 10 tokens


class EndpointHandler:
    """
    This class is used to handle the inference with pre and post process for
    text2text models. See
    https://huggingface.co/docs/inference-endpoints/guides/custom_handler
    for more details.
    """

    def __init__(self, path: str = ""):
        """
        Load the tokenizer and the seq2seq model found at `path`.

        Arguments
        ---------
        path (str): Location passed to both `from_pretrained` calls.

        Raises
        ------
        Whatever the `from_pretrained` calls raise; the installed accelerate
        version is printed first to help debug the failure.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                path, device_map="auto"
            )
        except Exception:
            # Best-effort diagnostics only: a failing `import accelerate` must
            # not replace the original loading error.
            try:
                import accelerate

                print(f"ACCELERATE VERSION: {accelerate.__version__}")
            except ImportError:
                print("ACCELERATE VERSION: <not installed>")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        This method is called when the endpoint is called.

        Arguments
        ---------
        data (Dict[str, Any]): Must contain the input text(s) under the
            `inputs` key. Optional inference parameters live under
            `parameters`: `max_new_tokens` (defaults to
            DEFAULT_MAX_NEW_TOKENS), `check_first_tokens` (a list of strings
            whose first-step probabilities should be reported) and any other
            keyword argument accepted by `model.generate`.

        Returns
        -------
        output (List[Dict[str, Any]]): One dictionary per row of generation
            scores, containing `generated_text`, `perplexity` and, only when
            `check_first_tokens` was given, `first_token_probs`.

        Raises
        ------
        KeyError: if `data` has no `inputs` key.
        AssertionError: if the tokenized batch is not 2-dimensional or the
            batch would exceed MAX_TOKENS_IN_BATCH.
        """
        input_texts = data["inputs"]
        # Copy before popping so the caller's payload is never mutated;
        # `or {}` additionally tolerates an explicit `"parameters": None`.
        generate_kwargs = dict(data.get("parameters") or {})

        # This is not technically a generate_kwarg, but needs to live under parameters
        check_first_tokens = generate_kwargs.pop("check_first_tokens", None)
        max_new_tokens = (
            generate_kwargs.pop("max_new_tokens", None) or DEFAULT_MAX_NEW_TOKENS
        )

        # Tokenizing input texts
        inputs = self.tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )["input_ids"]

        # Make sure not to OOM if too many inputs. These are raised explicitly
        # (rather than with `assert`) so the guard survives `python -O`; the
        # exception type is kept as AssertionError for backward compatibility.
        if inputs.dim() != 2:
            raise AssertionError(f"Inputs have dimension {inputs.dim()} != 2")
        total_tokens = inputs.shape[0] * (inputs.shape[1] + max_new_tokens - 1)
        if total_tokens > MAX_TOKENS_IN_BATCH:
            raise AssertionError(
                f"Passed {total_tokens} (shape: {inputs.shape}, "
                f"max_new_tokens: {max_new_tokens}), which is greater than "
                f"limit of {MAX_TOKENS_IN_BATCH}"
            )

        # Run inference on the first GPU when there is one, otherwise on CPU
        # instead of crashing on a hard-coded CUDA device.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )
        # Only the per-step scores are needed for post-processing; move them
        # to CPU and drop the rest of the generation output.
        scores = [s.to("cpu") for s in outputs.scores]
        del outputs

        # process outputs
        to_return: Dict[str, Any] = {
            "generated_text": self._output_text_from_scores(scores),
            "perplexity": [float(p) for p in self._perplexity(scores)],
        }
        if check_first_tokens:
            to_return["first_token_probs"] = self._get_first_token_probs(
                check_first_tokens, scores
            )

        # Reformat output to conform to HF Pipeline format
        # (dict of per-sequence lists -> list of per-sequence dicts)
        return [
            {key: values[ndx] for key, values in to_return.items()}
            for ndx in range(len(to_return["generated_text"]))
        ]

    def _output_text_from_scores(self, scores: List[torch.Tensor]) -> List[str]:
        """
        Returns the decoded text from the scores.

        `scores` holds one 2-D tensor per generation step, indexed as
        [sequence, vocabulary]. The highest-scoring token of every step is
        taken, the sequence is cut right before the first EOS token, and the
        result is decoded with the tokenizer.

        TODO (ENG-20823): Use the returned sequences so we pay attention to
        things like bad_words, force_words etc.
        """
        # Always return list format: one list of plain-int token ids per
        # sequence, computed with a single vectorized argmax over all steps.
        batch_token_ids: List[List[int]] = (
            torch.stack(scores, dim=1).argmax(dim=-1).tolist()
        )

        # Fix for new tokens being generated after EOS
        eos_token_id = self.tokenizer.eos_token_id
        new_batch_token_ids = []
        for token_ids in batch_token_ids:
            try:
                new_token_ids = token_ids[: token_ids.index(eos_token_id)]
            except ValueError:
                # No EOS was produced: the final token is dropped.
                # NOTE(review): presumably intentional for outputs truncated
                # by max_new_tokens - confirm before changing.
                new_token_ids = token_ids[:-1]
            new_batch_token_ids.append(new_token_ids)
        return self.tokenizer.batch_decode(new_batch_token_ids)

    def _perplexity(self, scores: List[torch.Tensor]) -> List[float]:
        """
        Returns the model confidence of the outputted text, one value per
        sequence.

        e^( sum(ln(p(word))) / N)

        where p(word) is the probability of the most likely token at each of
        the N generation steps. Despite the name, this is the geometric mean
        of those probabilities, i.e. a value in (0, 1] where higher means more
        confident; it is the reciprocal of the textbook perplexity.

        TODO (ENG-20823): don't include the trailing pad tokens in perplexity
        """
        # log_softmax is the fused form of softmax().log().
        top_log_probs = torch.stack(
            [torch.log_softmax(score, dim=1).max(dim=1).values for score in scores]
        )
        return torch.exp(top_log_probs.sum(dim=0) / len(scores)).tolist()

    def _get_first_token_probs(
        self, tokens: List[str], scores: List[torch.Tensor]
    ) -> List[Dict[str, float]]:
        """
        Return the softmaxed probabilities of the specific tokens for each
        output.

        Only the first generation step (`scores[0]`) is inspected. The result
        has one dictionary per sequence mapping every entry of `tokens` to its
        probability; entries that cannot be mapped to a single token id get a
        probability of 0.0.
        """
        first_step_probs = scores[0].softmax(dim=1)

        # Finding the correct token IDs
        # TODO (ENG-20824): Support multi-token words
        single_token_ids: Dict[str, int] = {}
        for token in tokens:
            encoded_token: List[int] = self.tokenizer.encode(token)
            # More than two ids means the tokenizer broke the token up into
            # multiple parts (presumably a single-piece word encodes to its id
            # plus one trailing special token - confirm for the tokenizer in
            # use). An empty encoding cannot be looked up either.
            if 1 <= len(encoded_token) <= 2:
                single_token_ids[token] = encoded_token[0]

        # Now finding the scores for each token in the list
        return [
            {
                token: (
                    float(first_step_probs[seq_ndx, single_token_ids[token]])
                    if token in single_token_ids
                    else 0.0
                )
                for token in tokens
            }
            for seq_ndx in range(first_step_probs.shape[0])
        ]