File size: 1,601 Bytes
34b78ab 1f71841 34b78ab 1f71841 34b78ab cfc7185 34b78ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import logging
import os
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
key_file = 'data/hftoken.txt'


def load_hf_token(path, env_var='HF_TOKEN'):
    """Return a Hugging Face access token, or ``None`` if none is configured.

    The file at *path* takes precedence. If it is missing or blank, the
    *env_var* environment variable is used instead. Surrounding whitespace is
    stripped from either source, and an empty value counts as "not configured"
    so that a blank file never shadows the environment variable.

    Args:
        path: Path to a text file containing only the token.
        env_var: Name of the environment variable used as the fallback.

    Returns:
        The token string, or ``None`` when neither source provides one.
    """
    try:
        with open(path, encoding='utf-8') as token_file:
            # `or None` so a blank file falls through to the env-var fallback
            # instead of producing an empty-string token.
            token = token_file.read().strip() or None
    except FileNotFoundError:
        token = None
    if token is None:
        token = os.getenv(env_var, '').strip() or None
    return token


# Token passed to the tokenizer download and the inference client; may be None.
HF_TOKEN = load_hf_token(key_file)
# Per-module logger; messages from this file go through it.
logger = logging.getLogger(__name__)
# Request INFO-level console output from the root logger. Per the logging docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# noinspection PyTypeChecker
class HuggingfaceGenerator:
    """Generate chat replies through the Hugging Face Inference API.

    Chat messages are rendered to a prompt string with the model's own
    tokenizer chat template. The prompt is then sent to
    ``InferenceClient.text_generation``.
    """

    def __init__(
        self, model_name,
        temperature: float = 0.9, max_new_tokens: int = 512,
        top_p: float = None, repetition_penalty: float = None,
        stream: bool = True, seed: int = 42,
    ):
        """Load the tokenizer and create an inference client for *model_name*.

        Args:
            model_name: Hub model id. It is used both to load the tokenizer
                and as the inference client's model.
            temperature: Sampling temperature. Values below 0.1 are raised
                to 0.1.
            max_new_tokens: Value forwarded as ``max_new_tokens``.
            top_p: Nucleus-sampling threshold, or ``None`` to leave it unset.
            repetition_penalty: Repetition penalty, or ``None`` to leave it
                unset.
            stream: If true, ``generate`` yields token texts as they arrive.
                Otherwise it yields the complete reply once.
            seed: Sampling seed forwarded to the backend. The default of 42
                matches the value previously hard-coded.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        self.hf_client = InferenceClient(model_name, token=HF_TOKEN)
        self.stream = stream
        self.generate_kwargs = {
            # Floor at 0.1: presumably the backend rejects a non-positive
            # temperature when do_sample=True. TODO: confirm.
            'temperature': max(temperature, 0.1),
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'repetition_penalty': repetition_penalty,
            'do_sample': True,
            'seed': seed,
        }

    def generate(self, messages):
        """Yield the model's reply to *messages* as text chunks.

        Args:
            messages: Chat history in the form accepted by the tokenizer's
                ``apply_chat_template``. This is presumably a list of
                ``{'role': ..., 'content': ...}`` dicts; confirm with callers.

        Yields:
            str: In streaming mode, the text of each non-special generated
            token. Otherwise, the complete generated text, exactly once.
            Because ``return_full_text=False``, the prompt is never echoed.
        """
        # add_generation_prompt appends the assistant-turn header so the model
        # starts a reply instead of continuing the last message.
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        logger.info('Start HuggingFace generation, model %s ...', self.hf_client.model)
        output = self.hf_client.text_generation(
            formatted_prompt, **self.generate_kwargs,
            stream=self.stream, details=True, return_full_text=False,
        )
        if not self.stream:
            # Without streaming, details=True returns one TextGenerationOutput
            # instead of a stream of token chunks, so the per-token loop below
            # does not apply.
            yield output.generated_text
            return
        for response in output:
            # Skip special tokens (e.g. end-of-sequence) so markers such as
            # '</s>' do not leak into the generated text.
            if response.token.special:
                continue
            yield response.token.text
|