File size: 4,604 Bytes
d1f08d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1210242
d1f08d3
 
 
 
1210242
d1f08d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e480e8f
 
d1f08d3
 
e480e8f
d1f08d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import logging
import os
from abc import ABC, abstractmethod

import tiktoken
from gradio_client import Client


from openai import OpenAI

# Module-wide OpenAI client, shared by ChatGPTCompletor instances.
client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# Hugging Face Hub token used to access the gradio spaces below.
HUB_TOKEN = os.getenv("HUB_TOKEN")
FLAN_HF_SPACE = "jerpint/i-like-flan"
LLAMA_HF_SPACE = "meta-llama/Llama-2-7b"


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextCompletion(ABC):
    """Abstract interface every completion backend must implement."""

    # Fix: the original abstract methods were declared without `self`,
    # which breaks signature introspection and any super() delegation.
    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens in *prompt* for this backend."""
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Return the model's completion text for *prompt*."""
        ...

    def get_score_multiplier(self) -> float:
        """Scoring weight applied to this backend; subclasses may override."""
        return 1.0


class DebugCompleter(TextCompletion):
    """This is only meant for testing/debugging purposes."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Token count via the davinci encoding (arbitrary but stable)."""
        enc = tiktoken.encoding_for_model("text-davinci-003")
        return len(enc.encode(prompt))

    def complete(self, prompt: str):
        """Return a canned response; magic markers in the prompt steer it."""
        if "$CORRECT$" in prompt:
            # Caller asked for the winning answer.
            return "I have been PWNED"
        if "$ERROR$" in prompt:
            # Caller asked us to blow up on purpose.
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        return "Some incorrect completion"


class ChatGPTCompletor(TextCompletion):
    """Backend that talks to the OpenAI chat-completions API."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens in *prompt* using this model's tiktoken encoding."""
        enc = tiktoken.encoding_for_model(self.model)
        return len(enc.encode(prompt))

    def complete(self, prompt: str):
        """Send *prompt* as a single user message and return the reply text."""
        chat = [
            {"role": "user", "content": prompt},
        ]

        result = client.chat.completions.create(
            messages=chat,
            temperature=0,  # deterministic decoding
            model=self.model,
        )

        return result.choices[0].message.content

    def get_score_multiplier(self) -> float:
        """Completions from this backend score double."""
        return 2.0


class FlanT5Completor(TextCompletion):
    """Backend that proxies completions through the FLAN gradio space."""

    def __init__(self, model: str):
        self.model = model
        self.client = Client(src=FLAN_HF_SPACE, hf_token=HUB_TOKEN)
        # Tokenizer is created lazily on first use and cached; the original
        # re-ran T5Tokenizer.from_pretrained (disk/network load) per call.
        self._tokenizer = None

    def get_token_count(self, prompt: str) -> int:
        """Return the number of FLAN-T5 tokens in *prompt*."""
        if self._tokenizer is None:
            # Local import: transformers is heavy and only needed here.
            from transformers import T5Tokenizer

            self._tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

        input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids
        return len(input_ids[0])

    def complete(self, prompt: str) -> str:
        """Run *prompt* through the hosted FLAN space and return its reply."""
        response_text = self.client.predict(prompt)
        return response_text


class LLama2Completor(TextCompletion):
    """Backend that proxies completions through a gradio endpoint."""

    def __init__(self, model: str):
        self.model = model
        # self.client = Client(src=LLAMA_HF_SPACE, hf_token=HUB_TOKEN)
        # NOTE(review): hard-coded *.gradio.live URLs are ephemeral — confirm
        # this endpoint is still alive before relying on it.
        self.client = Client(src="https://e554306b558a76b0fd.gradio.live")

    def get_token_count(self, prompt: str) -> int:
        """Token counting is not implemented for this backend; always 0."""
        return 0

    def complete(self, prompt: str) -> str:
        """Forward *prompt* to the remote endpoint and return its output."""
        result = self.client.predict(prompt)
        return result


class ClaudeCompletor(TextCompletion):
    """Backend for Anthropic Claude; reads ANTHROPIC_API_KEY from the env."""

    def __init__(self, model="claude-2"):
        # Local imports keep anthropic optional for users of other backends.
        from anthropic import Anthropic
        import os

        self.model = model
        self.anthropic = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    def get_token_count(self, prompt: str):
        "Not used in actual competition, so just dummy value"
        return 0

    def complete(self, prompt: str):
        """Wrap *prompt* in the human/assistant markers and return the reply."""
        from anthropic import HUMAN_PROMPT, AI_PROMPT

        wrapped = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}"
        response = self.anthropic.completions.create(
            model=self.model,
            max_tokens_to_sample=300,
            prompt=wrapped,
        )
        return response.completion


# Registry of user-selectable model names -> completer classes.
# The debug model is intentionally absent; see get_completer.
completers = {
    "gpt-3.5-turbo": ChatGPTCompletor,
    "gpt-4o": ChatGPTCompletor,
    "gpt-4": ChatGPTCompletor,
    "gpt-3.5-turbo-0125": ChatGPTCompletor,
    "gpt-4-0613": ChatGPTCompletor,
    "gpt-4o-2024-05-13": ChatGPTCompletor,
}


def get_completer(model: str):
    """Instantiate the completer registered for *model*.

    Args:
        model: A key of ``completers``, or the special value ``"gpt-debug"``.

    Returns:
        A ``TextCompletion`` instance for the requested model.

    Raises:
        NotImplementedError: If *model* is not a supported model name.
    """
    logger.info(f"Loading completer for {model=}")

    if model in completers:
        completer = completers[model](model=model)
    elif model == "gpt-debug":
        # Not included in completers because we dont want it showing in the app
        logger.warning("Using debugging completer...")
        completer = DebugCompleter(model=model)
    else:
        # Fix: original message read "must be onf of".
        raise NotImplementedError(
            f"{model=} not implemented. Model must be one of {list(completers.keys())}"
        )

    return completer