# Author: AlexanderKazakov
# Last commit: cfc7185 — "add falcon-180B-chat"
# (file-viewer residue: raw | history | blame | 1.6 kB)
import logging
import os
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
# Optional local file holding a Hugging Face access token (one line).
key_file = 'data/hftoken.txt'


def read_hf_token(path=key_file):
    """Return the Hugging Face access token to use, or None if none is configured.

    The token is read from *path* when that is an existing file with a
    non-blank value; otherwise the ``HF_TOKEN`` environment variable is
    used (None when that is unset too).

    Args:
        path: location of the token file; defaults to ``key_file``.

    Returns:
        The stripped token string, or None.
    """
    if os.path.isfile(path):
        with open(path, encoding='utf-8') as f:
            token = f.read().strip()
        # An empty/whitespace-only file used to produce '' here, which
        # silently disabled the environment-variable fallback below.
        if token:
            return token
    return os.getenv('HF_TOKEN')


HF_TOKEN = read_hf_token(key_file)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# noinspection PyTypeChecker
class HuggingfaceGenerator:
    """Chat-completion generator backed by the Hugging Face Inference API.

    Chat messages are rendered into a prompt with the model's own chat
    template (loaded through ``AutoTokenizer``), then sent to
    ``InferenceClient.text_generation``. ``generate`` is a Python generator
    of text pieces: one per streamed token when ``stream`` is true, or a
    single piece holding the whole completion otherwise.
    """

    def __init__(
            self, model_name,
            temperature: float = 0.9, max_new_tokens: int = 512,
            top_p: float = None, repetition_penalty: float = None,
            stream: bool = True,
            seed: int = 42,
    ):
        """Load the tokenizer and create the inference client for *model_name*.

        Args:
            model_name: Hub model id, used both for the tokenizer (chat
                template) and as the Inference API model.
            temperature: sampling temperature. Values below 0.1 are raised to
                0.1 (presumably because the server rejects non-positive
                temperatures when sampling; confirm).
            max_new_tokens: maximum number of tokens to generate.
            top_p: nucleus-sampling probability mass; None leaves it unset.
            repetition_penalty: repetition penalty; None leaves it unset.
            stream: if true, ``generate`` yields tokens as they arrive;
                otherwise it yields the full completion once.
            seed: sampling seed sent with every request. The default of 42
                keeps the previous deterministic behavior; None omits it.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        self.hf_client = InferenceClient(model_name, token=HF_TOKEN)
        self.stream = stream
        self.generate_kwargs = {
            'temperature': max(temperature, 0.1),
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'repetition_penalty': repetition_penalty,
            'do_sample': True,
            'seed': seed,
        }

    def generate(self, messages):
        """Generate the assistant's reply to *messages*.

        Args:
            messages: chat history in the form accepted by the tokenizer's
                ``apply_chat_template`` (presumably a list of
                ``{'role': ..., 'content': ...}`` dicts; confirm with callers).

        Yields:
            str: in streaming mode, the text of each non-special generated
            token; otherwise the complete generated text, once. The prompt
            itself is never included.
        """
        # add_generation_prompt appends the template's assistant-turn prefix
        # (when it defines one), so the model replies instead of continuing
        # the last message. Templates without such a prefix ignore it.
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        logger.info('Start HuggingFace generation, model %s ...', self.hf_client.model)
        output = self.hf_client.text_generation(
            formatted_prompt, **self.generate_kwargs,
            stream=self.stream, details=True, return_full_text=False,
        )
        if not self.stream:
            # A non-streaming call returns one output object, not an iterable
            # of token events; iterating it (as before) raised TypeError.
            yield output.generated_text
            return
        for response in output:
            # Special tokens (e.g. the end-of-sequence marker) are control
            # tokens, not reply text.
            if response.token.special:
                continue
            yield response.token.text