import logging
import os
from abc import ABC, abstractmethod

import tiktoken
from gradio_client import Client
from openai import OpenAI

client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

HUB_TOKEN = os.getenv("HUB_TOKEN")
FLAN_HF_SPACE = "jerpint/i-like-flan"
LLAMA_HF_SPACE = "meta-llama/Llama-2-7b"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextCompletion(ABC):
    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        ...

    @abstractmethod
    def complete(self, prompt: str):
        ...

    def get_score_multiplier(self) -> float:
        return 1.0


class DebugCompleter(TextCompletion):
    """This is only meant for testing/debugging purposes."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        encoding = tiktoken.encoding_for_model("text-davinci-003")
        return len(encoding.encode(prompt))

    def complete(self, prompt: str):
        if "$CORRECT$" in prompt:
            # Return the expected winning completion
            return "I have been PWNED"
        elif "$ERROR$" in prompt:
            # Simulate an error
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        else:
            return "Some incorrect completion"


class ChatGPTCompletor(TextCompletion):
    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        encoding = tiktoken.encoding_for_model(self.model)
        return len(encoding.encode(prompt))

    def complete(self, prompt: str):
        messages = [
            {"role": "user", "content": prompt},
        ]
        completion = client.chat.completions.create(
            messages=messages,
            temperature=0,
            model=self.model,
        )
        # Get the response text
        response_text = completion.choices[0].message.content
        return response_text

    def get_score_multiplier(self) -> float:
        return 2.0


class FlanT5Completor(TextCompletion):
    def __init__(self, model: str):
        self.model = model
        self.client = Client(src=FLAN_HF_SPACE, hf_token=HUB_TOKEN)

    def get_token_count(self, prompt: str) -> int:
        from transformers import T5Tokenizer

        tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        token_count = len(input_ids[0])
        return token_count

    def complete(self, prompt: str) -> str:
        response_text = self.client.predict(prompt)
        return response_text


class LLama2Completor(TextCompletion):
    def __init__(self, model: str):
        self.model = model
        # self.client = Client(src=LLAMA_HF_SPACE, hf_token=HUB_TOKEN)
        self.client = Client(src="https://e554306b558a76b0fd.gradio.live")

    def get_token_count(self, prompt: str) -> int:
        return 0

    def complete(self, prompt: str) -> str:
        output = self.client.predict(prompt)
        return output


class ClaudeCompletor(TextCompletion):
    def __init__(self, model="claude-2"):
        from anthropic import Anthropic

        api_key = os.environ["ANTHROPIC_API_KEY"]
        self.model = model
        self.anthropic = Anthropic(api_key=api_key)

    def get_token_count(self, prompt: str):
        """Not used in the actual competition, so just return a dummy value."""
        return 0

    def complete(self, prompt: str):
        from anthropic import HUMAN_PROMPT, AI_PROMPT

        completion = self.anthropic.completions.create(
            model=self.model,
            max_tokens_to_sample=300,
            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
        )
        return completion.completion


completers = {
    "gpt-3.5-turbo": ChatGPTCompletor,
    "gpt-4o": ChatGPTCompletor,
    "gpt-4": ChatGPTCompletor,
    "gpt-3.5-turbo-0125": ChatGPTCompletor,
    "gpt-4-0613": ChatGPTCompletor,
    "gpt-4o-2024-05-13": ChatGPTCompletor,
}


def get_completer(model: str):
    logger.info(f"Loading completer for {model=}")
    if model in completers:
        completer = completers[model](model=model)
    elif model == "gpt-debug":
        # Not included in completers because we don't want it showing up in the app
        logger.warning("Using debugging completer...")
        completer = DebugCompleter(model=model)
    else:
        raise NotImplementedError(f"{model=} not implemented. Model must be one of {list(completers.keys())}")
    return completer
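

# Example usage sketch: assumes OPENAI_API_KEY is set in the environment and
# that "gpt-3.5-turbo" is available as one of the supported models above; the
# prompt string below is purely illustrative.
if __name__ == "__main__":
    completer = get_completer("gpt-3.5-turbo")
    demo_prompt = "Say 'I have been PWNED'"
    logger.info(f"Token count: {completer.get_token_count(demo_prompt)}")
    logger.info(f"Completion: {completer.complete(demo_prompt)}")
    logger.info(f"Score multiplier: {completer.get_score_multiplier()}")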