import math
import random

import gradio as gr
from distilabel.llms import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
from distilabel.steps.tasks import TextGeneration

from synthetic_dataset_generator.constants import (
    API_KEYS,
    DEFAULT_BATCH_SIZE,
    HUGGINGFACE_BASE_URL,
    MAGPIE_PRE_QUERY_TEMPLATE,
    MODEL,
    OLLAMA_BASE_URL,
    OPENAI_BASE_URL,
    TOKENIZER_ID,
    VLLM_BASE_URL,
)

TOKEN_INDEX = 0


def _get_next_api_key():
    """Rotate through the configured API keys in round-robin order."""
    global TOKEN_INDEX
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key


def _get_prompt_rewriter():
    generation_kwargs = {
        "temperature": 1,
    }
    system_prompt = (
        "You are a prompt rewriter. You are given a prompt and you need to rewrite it "
        "keeping the same structure but highlighting different aspects of the original "
        "without adding anything new."
    )
    prompt_rewriter = TextGeneration(
        llm=_get_llm(generation_kwargs=generation_kwargs),
        system_prompt=system_prompt,
        use_system_prompt=True,
    )
    prompt_rewriter.load()
    return prompt_rewriter


def get_rewriten_prompts(prompt: str, num_rows: int):
    """Return the original prompt plus roughly one rewritten variant per 100 requested rows."""
    prompt_rewriter = _get_prompt_rewriter()
    # create prompt rewrites, one input per ~100 rows
    inputs = [
        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
        for _ in range(math.floor(num_rows / 100))
    ]
    n_processed = 0
    prompt_rewrites = [prompt]
    # process the rewrite inputs in batches
    while n_processed < len(inputs):
        batch = list(
            prompt_rewriter.process(
                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
            )
        )
        prompt_rewrites += [entry["generation"] for entry in batch[0]]
        n_processed += DEFAULT_BATCH_SIZE
    random.seed(a=random.randint(0, 2**32 - 1))
    return prompt_rewrites


def _get_llm(use_magpie_template=False, **kwargs):
    if OPENAI_BASE_URL:
        # OpenAI-compatible APIs use "stop" and do not accept "do_sample".
        if "generation_kwargs" in kwargs:
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            # Ollama uses "num_predict"/"stop" and nests generation settings under "options".
            if "max_new_tokens" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["num_predict"] = kwargs[
                    "generation_kwargs"
                ]["max_new_tokens"]
                del kwargs["generation_kwargs"]["max_new_tokens"]
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
            options = kwargs["generation_kwargs"]
            kwargs["generation_kwargs"] = {"options": options}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # dedicated Hugging Face Inference Endpoint: force sampling
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    elif VLLM_BASE_URL:
        if "generation_kwargs" in kwargs:
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = ClientvLLM(
            base_url=VLLM_BASE_URL,
            model=MODEL,
            tokenizer=TOKENIZER_ID or MODEL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    else:
        # default: serverless Hugging Face Inference Endpoints with Magpie support
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            use_magpie_template=use_magpie_template,
            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
            **kwargs,
        )
    return llm
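

# Illustrative example of the per-backend normalisation above (values are made up):
# a caller passing
#     generation_kwargs={"temperature": 0.7, "max_new_tokens": 256, "stop_sequences": ["\n"]}
# reaches OllamaLLM as
#     generation_kwargs={"options": {"temperature": 0.7, "num_predict": 256, "stop": ["\n"]}}
# while the OpenAI branch renames "stop_sequences" to "stop" and drops "do_sample",
# and the vLLM client only drops "do_sample".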


# Smoke-test the configured LLM at import time so misconfiguration surfaces early.
llm = None
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    gr.Error(f"Error loading {llm.__class__.__name__ if llm else 'LLM'}: {e}")