"""Contains classes for querying large language models.""" from math import ceil import os import time from tqdm import tqdm from abc import ABC, abstractmethod import openai gpt_costs_per_thousand = { "davinci": 0.0200, "curie": 0.0020, "babbage": 0.0005, "ada": 0.0004, } def model_from_config(config, disable_tqdm=True): """Returns a model based on the config.""" model_type = config["name"] if model_type == "GPT_forward": return GPT_Forward(config, disable_tqdm=disable_tqdm) elif model_type == "GPT_insert": return GPT_Insert(config, disable_tqdm=disable_tqdm) raise ValueError(f"Unknown model type: {model_type}") class LLM(ABC): """Abstract base class for large language models.""" @abstractmethod def generate_text(self, prompt): """Generates text from the model. Parameters: prompt: The prompt to use. This can be a string or a list of strings. Returns: A list of strings. """ pass @abstractmethod def log_probs(self, text, log_prob_range): """Returns the log probs of the text. Parameters: text: The text to get the log probs of. This can be a string or a list of strings. log_prob_range: The range of characters within each string to get the log_probs of. This is a list of tuples of the form (start, end). Returns: A list of log probs. """ pass class GPT_Forward(LLM): """Wrapper for GPT-3.""" def __init__(self, config, needs_confirmation=False, disable_tqdm=True): """Initializes the model.""" self.config = config self.needs_confirmation = needs_confirmation self.disable_tqdm = disable_tqdm def confirm_cost(self, texts, n, max_tokens): total_estimated_cost = 0 for text in texts: total_estimated_cost += ( gpt_get_estimated_cost(self.config, text, max_tokens) * n ) print(f"Estimated cost: ${total_estimated_cost:.2f}") # Ask the user to confirm in the command line if os.getenv("LLM_SKIP_CONFIRM") is None: confirm = input("Continue? 
(y/n) ") if confirm != "y": raise Exception("Aborted.") def auto_reduce_n(self, fn, prompt, n): """Reduces n by half until the function succeeds.""" try: return fn(prompt, n) except BatchSizeException as e: if n == 1: raise e return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n( fn, prompt, n // 2 ) def generate_text(self, prompt, n): if not isinstance(prompt, list): prompt = [prompt] if self.needs_confirmation: self.confirm_cost(prompt, n, self.config["gpt_config"]["max_tokens"]) batch_size = self.config["batch_size"] prompt_batches = [ prompt[i : i + batch_size] for i in range(0, len(prompt), batch_size) ] if not self.disable_tqdm: print( f"[{self.config['name']}] Generating {len(prompt) * n} completions, " f"split into {len(prompt_batches)} batches of size {batch_size * n}" ) text = [] for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm): text += self.auto_reduce_n(self.__generate_text, prompt_batch, n) return text def complete(self, prompt, n): """Generates text from the model and returns the log prob data.""" if not isinstance(prompt, list): prompt = [prompt] batch_size = self.config["batch_size"] prompt_batches = [ prompt[i : i + batch_size] for i in range(0, len(prompt), batch_size) ] if not self.disable_tqdm: print( f"[{self.config['name']}] Generating {len(prompt) * n} completions, " f"split into {len(prompt_batches)} batches of size {batch_size * n}" ) res = [] for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm): res += self.__complete(prompt_batch, n) return res def log_probs(self, text, log_prob_range=None): """Returns the log probs of the text.""" if not isinstance(text, list): text = [text] if self.needs_confirmation: self.confirm_cost(text, 1, 0) batch_size = self.config["batch_size"] text_batches = [ text[i : i + batch_size] for i in range(0, len(text), batch_size) ] if log_prob_range is None: log_prob_range_batches = [None] * len(text) else: assert len(log_prob_range) == len(text) log_prob_range_batches = [ log_prob_range[i : i + batch_size] for i in range(0, len(log_prob_range), batch_size) ] if not self.disable_tqdm: print( f"[{self.config['name']}] Getting log probs for {len(text)} strings, " f"split into {len(text_batches)} batches of (maximum) size {batch_size}" ) log_probs = [] tokens = [] for text_batch, log_prob_range in tqdm( list(zip(text_batches, log_prob_range_batches)), disable=self.disable_tqdm ): log_probs_batch, tokens_batch = self.__log_probs(text_batch, log_prob_range) log_probs += log_probs_batch tokens += tokens_batch return log_probs, tokens def __generate_text(self, prompt, n): """Generates text from the model.""" if not isinstance(prompt, list): text = [prompt] config = self.config["gpt_config"].copy() config["n"] = n # If there are any [APE] tokens in the prompts, remove them for i in range(len(prompt)): prompt[i] = prompt[i].replace("[APE]", "").strip() response = None while response is None: try: response = openai.Completion.create(**config, prompt=prompt) except Exception as e: if "is greater than the maximum" in str(e): raise BatchSizeException() print(e) print("Retrying...") time.sleep(5) return [response["choices"][i]["text"] for i in range(len(response["choices"]))] def __complete(self, prompt, n): """Generates text from the model and returns the log prob data.""" if not isinstance(prompt, list): text = [prompt] config = self.config["gpt_config"].copy() config["n"] = n # If there are any [APE] tokens in the prompts, remove them for i in range(len(prompt)): prompt[i] = prompt[i].replace("[APE]", 
"").strip() response = None while response is None: try: response = openai.Completion.create(**config, prompt=prompt) except Exception as e: print(e) print("Retrying...") time.sleep(5) return response["choices"] def __log_probs(self, text, log_prob_range=None): """Returns the log probs of the text.""" if not isinstance(text, list): text = [text] if log_prob_range is not None: for i in range(len(text)): lower_index, upper_index = log_prob_range[i] assert lower_index < upper_index assert lower_index >= 0 assert upper_index - 1 < len(text[i]) config = self.config["gpt_config"].copy() config["logprobs"] = 1 config["echo"] = True config["max_tokens"] = 0 if isinstance(text, list): text = [f"\n{text[i]}" for i in range(len(text))] else: text = f"\n{text}" response = None while response is None: try: response = openai.Completion.create(**config, prompt=text) except Exception as e: print(e) print("Retrying...") time.sleep(5) log_probs = [ response["choices"][i]["logprobs"]["token_logprobs"][1:] for i in range(len(response["choices"])) ] tokens = [ response["choices"][i]["logprobs"]["tokens"][1:] for i in range(len(response["choices"])) ] offsets = [ response["choices"][i]["logprobs"]["text_offset"][1:] for i in range(len(response["choices"])) ] # Subtract 1 from the offsets to account for the newline for i in range(len(offsets)): offsets[i] = [offset - 1 for offset in offsets[i]] if log_prob_range is not None: # First, we need to find the indices of the tokens in the log probs # that correspond to the tokens in the log_prob_range for i in range(len(log_probs)): lower_index, upper_index = self.get_token_indices( offsets[i], log_prob_range[i] ) log_probs[i] = log_probs[i][lower_index:upper_index] tokens[i] = tokens[i][lower_index:upper_index] return log_probs, tokens def get_token_indices(self, offsets, log_prob_range): """Returns the indices of the tokens in the log probs that correspond to the tokens in the log_prob_range.""" # For the lower index, find the highest index that is less than or equal to the lower index lower_index = 0 for i in range(len(offsets)): if offsets[i] <= log_prob_range[0]: lower_index = i else: break upper_index = len(offsets) for i in range(len(offsets)): if offsets[i] >= log_prob_range[1]: upper_index = i break return lower_index, upper_index class GPT_Insert(LLM): def __init__(self, config, needs_confirmation=False, disable_tqdm=True): """Initializes the model.""" self.config = config self.needs_confirmation = needs_confirmation self.disable_tqdm = disable_tqdm def confirm_cost(self, texts, n, max_tokens): total_estimated_cost = 0 for text in texts: total_estimated_cost += ( gpt_get_estimated_cost(self.config, text, max_tokens) * n ) print(f"Estimated cost: ${total_estimated_cost:.2f}") # Ask the user to confirm in the command line if os.getenv("LLM_SKIP_CONFIRM") is None: confirm = input("Continue? 
(y/n) ") if confirm != "y": raise Exception("Aborted.") def auto_reduce_n(self, fn, prompt, n): """Reduces n by half until the function succeeds.""" try: return fn(prompt, n) except BatchSizeException as e: if n == 1: raise e return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n( fn, prompt, n // 2 ) def generate_text(self, prompt, n): if not isinstance(prompt, list): prompt = [prompt] if self.needs_confirmation: self.confirm_cost(prompt, n, self.config["gpt_config"]["max_tokens"]) batch_size = self.config["batch_size"] assert batch_size == 1 prompt_batches = [ prompt[i : i + batch_size] for i in range(0, len(prompt), batch_size) ] if not self.disable_tqdm: print( f"[{self.config['name']}] Generating {len(prompt) * n} completions, split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}" ) text = [] for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm): text += self.auto_reduce_n(self.__generate_text, prompt_batch, n) return text def log_probs(self, text, log_prob_range=None): raise NotImplementedError def __generate_text(self, prompt, n): """Generates text from the model.""" config = self.config["gpt_config"].copy() config["n"] = n # Split prompts into prefixes and suffixes with the [APE] token (do not include the [APE] token in the suffix) prefix = prompt[0].split("[APE]")[0] suffix = prompt[0].split("[APE]")[1] response = None while response is None: try: response = openai.Completion.create( **config, prompt=prefix, suffix=suffix ) except Exception as e: print(e) print("Retrying...") time.sleep(5) # Remove suffix from the generated text texts = [ response["choices"][i]["text"].replace(suffix, "") for i in range(len(response["choices"])) ] return texts def gpt_get_estimated_cost(config, prompt, max_tokens): """Uses the current API costs/1000 tokens to estimate the cost of generating text from the model.""" # Get rid of [APE] token prompt = prompt.replace("[APE]", "") # Get the number of tokens in the prompt n_prompt_tokens = len(prompt) // 4 # Get the number of tokens in the generated text total_tokens = n_prompt_tokens + max_tokens engine = config["gpt_config"]["model"].split("-")[1] costs_per_thousand = gpt_costs_per_thousand if engine not in costs_per_thousand: # Try as if it is a fine-tuned model engine = config["gpt_config"]["model"].split(":")[0] costs_per_thousand = { "davinci": 0.1200, "curie": 0.0120, "babbage": 0.0024, "ada": 0.0016, } price = costs_per_thousand[engine] * total_tokens / 1000 return price class BatchSizeException(Exception): pass