File size: 1,601 Bytes
34b78ab
1f71841
34b78ab
 
 
 
1f71841
 
 
 
 
 
 
 
 
 
34b78ab
 
 
 
 
 
 
 
 
 
 
 
 
cfc7185
34b78ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import logging
import os

from huggingface_hub import InferenceClient
from transformers import AutoTokenizer


# Hugging Face API token. A local key file takes precedence; the HF_TOKEN
# environment variable is the fallback. Stays None when neither supplies one.
HF_TOKEN = None
key_file = 'data/hftoken.txt'
try:
    with open(key_file, encoding='utf-8') as f:
        # An empty / whitespace-only file counts as "no token", so that the
        # environment fallback below still applies instead of using ''.
        HF_TOKEN = f.read().strip() or None
except FileNotFoundError:
    # The key file is optional.
    pass

if HF_TOKEN is None:
    HF_TOKEN = os.getenv('HF_TOKEN')


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# noinspection PyTypeChecker
class HuggingfaceGenerator:
    """Text generator backed by the Hugging Face Inference API.

    Chat messages are rendered into a prompt with the model tokenizer's chat
    template and sent to ``InferenceClient.text_generation``.
    """

    def __init__(
            self, model_name,
            temperature: float = 0.9, max_new_tokens: int = 512,
            top_p: float = None, repetition_penalty: float = None,
            stream: bool = True,
    ):
        """Load the tokenizer, create the inference client and store sampling options.

        Args:
            model_name: Model identifier passed to both
                ``AutoTokenizer.from_pretrained`` and ``InferenceClient``.
            temperature: Sampling temperature; values below 0.1 are raised to 0.1.
            max_new_tokens: Forwarded unchanged to ``text_generation``.
            top_p: Forwarded unchanged to ``text_generation`` (``None`` included).
            repetition_penalty: Forwarded unchanged to ``text_generation``
                (``None`` included).
            stream: Whether ``generate`` requests a token stream (``True``) or
                a single complete response (``False``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        self.hf_client = InferenceClient(model_name, token=HF_TOKEN)
        self.stream = stream

        self.generate_kwargs = {
            # Clamp to a floor of 0.1 because sampling is always enabled below.
            'temperature': max(temperature, 0.1),
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'repetition_penalty': repetition_penalty,
            'do_sample': True,
            # Fixed seed is sent with every request.
            'seed': 42,
        }

    def generate(self, messages):
        """Generate a reply for ``messages`` and yield it as pieces of text.

        This is a generator: the prompt is built and the request is sent only
        once iteration starts.

        Args:
            messages: Conversation passed to the tokenizer's
                ``apply_chat_template``.

        Yields:
            With ``stream=True``: the text of each streamed token, in order.
            With ``stream=False``: exactly one item, the ``generated_text`` of
            the response (requested with ``return_full_text=True``).
        """
        # NOTE(review): ``add_generation_prompt=True`` is not passed; some chat
        # templates need it to cue an assistant turn - confirm for the models used.
        formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)

        logger.info('Start HuggingFace generation, model %s ...', self.hf_client.model)
        result = self.hf_client.text_generation(
            formatted_prompt, **self.generate_kwargs,
            stream=self.stream, details=True, return_full_text=not self.stream,
        )

        if not self.stream:
            # Without streaming the client returns one response object rather
            # than an iterable of token events, so it must not be iterated.
            yield result.generated_text
            return

        for response in result:
            yield response.token.text