File size: 9,743 Bytes
8a5e8bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# Display name of the model; copied onto the handle in load_model_info and passed to
# get_local_llm_predict_fns at the bottom of this file.
model_name = "InternLM"
# Install hint copied onto the handle in load_model_info.
# NOTE(review): this points at the chatglm requirements file — confirm it is the intended one for InternLM.
cmd_to_install = "`pip install -r request_llm/requirements_chatglm.txt`"

from transformers import AutoModel, AutoTokenizer
import time
import threading
import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns, SingletonLocalLLM


# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 Local Model Utils
# ------------------------------------------------------------------------------------------------------------------------
def try_to_import_special_deps():
    """Probe for the ``sentencepiece`` package.

    Raises:
        ImportError: if ``sentencepiece`` is not installed.
    """
    importlib.import_module("sentencepiece")

def combine_history(prompt, hist):
    """Build the full InternLM chat prompt from past turns plus the new query.

    Args:
        prompt: the new user query (str).
        hist: iterable of past turns; element ``[0]`` of each turn is the user
            text and element ``[1]`` is the bot reply.

    Returns:
        One string: every past turn rendered as
        ``<|User|>:...<eoh>\\n<|Bot|>:...<eoa>\\n``, followed by the new query
        rendered as ``<|User|>:...<eoh>\\n<|Bot|>:`` (left open for the model
        to complete).
    """
    user_template = "<|User|>:{user}<eoh>\n"
    bot_template = "<|Bot|>:{robot}<eoa>\n"
    query_template = "<|User|>:{user}<eoh>\n<|Bot|>:"

    pieces = []
    for turn in hist:
        pieces.append(user_template.replace("{user}", turn[0]))
        pieces.append(bot_template.replace("{robot}", turn[1]))
    # The open-ended query always comes last.
    pieces.append(query_template.replace("{user}", prompt))
    return "".join(pieces)

# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 Local Model
# ------------------------------------------------------------------------------------------------------------------------
@SingletonLocalLLM
class GetInternlmHandle(LocalLLMHandle):
    """Local-LLM handle for the ``internlm/internlm-chat-7b`` chat model.

    Subclass of ``LocalLLMHandle`` wrapped by ``SingletonLocalLLM``. It supplies
    the model metadata, a dependency probe, the model/tokenizer loader and a
    token-by-token streaming generator. According to the original inline notes,
    the loader and generator are executed in the child process.
    """

    def load_model_info(self):
        """Record the model name and install hint on the handle."""
        # (original note) executed in the child process
        self.model_name = model_name
        self.cmd_to_install = cmd_to_install

    def try_to_import_special_deps(self, **kwargs):
        """
        import something that will raise error if the user does not install requirement_*.txt
        """
        import sentencepiece

    def load_model_and_tokenizer(self):
        """Load (or reuse) the InternLM model and tokenizer.

        Returns:
            ``(model, tokenizer)``. If a model is already stored on the handle,
            the stored pair is returned; the original code hit an
            ``UnboundLocalError`` in that case. Otherwise the pretrained
            weights are loaded, cast to bfloat16, moved to CUDA unless the
            ``LOCAL_MODEL_DEVICE`` setting is ``'cpu'``, and put in eval mode.
        """
        # (original note) executed in the child process
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        if self._model is not None:
            # Already loaded: hand back the stored objects instead of falling
            # through to a `return` on names that were never bound.
            return self._model, self._tokenizer

        device, = get_conf('LOCAL_MODEL_DEVICE')
        tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16)
        if device != 'cpu':
            model = model.cuda()
        model = model.eval()
        return model, tokenizer

    def llm_stream_generator(self, **kwargs):
        """Stream a reply from the model, one sampled token at a time.

        Keyword Args:
            query: the new user message.
            history: past turns, in the form accepted by ``combine_history``.
            All keyword arguments (e.g. ``max_length``, ``top_p``,
            ``temperature``) are also passed to ``GenerationConfig.update``;
            the ones it does not recognise are returned by it and forwarded
            as model kwargs.

        Yields:
            str: the decoded reply generated so far (cumulative, with a
            trailing end-of-sequence token stripped), once per generated token.
        """
        import torch
        import logging
        import copy
        import warnings
        import torch.nn as nn
        from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList

        # (original note) executed in the child process
        model = self._model
        tokenizer = self._tokenizer
        prompt = combine_history(kwargs['query'], kwargs['history'])

        prefix_allowed_tokens_fn = None
        logits_processor = None
        stopping_criteria = None
        # Extra end-of-sequence id, hard-coded as in the InternLM web demo referenced below.
        # NOTE(review): presumably the id of the `<eoa>` token — confirm against the tokenizer.
        additional_eos_token_id = 103028
        # Generation loop adapted from:
        # https://github.com/InternLM/InternLM/blob/efbf5335709a8c8faeac6eaf07193973ff1d56a1/web_demo.py#L25

        inputs = tokenizer([prompt], padding=True, return_tensors="pt")
        # Follow the model's device instead of forcing CUDA, so the 'cpu'
        # configuration supported by load_model_and_tokenizer also works here.
        input_ids = inputs["input_ids"].to(model.device)
        input_length = input_ids.shape[-1]
        input_ids_seq_length = input_length

        # Work on a private copy so per-request overrides never leak into the model's defaults.
        generation_config = copy.deepcopy(model.generation_config)
        model_kwargs = generation_config.update(**kwargs)

        eos_token_id = generation_config.eos_token_id
        if eos_token_id is None:
            eos_token_id = []
        elif isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        else:
            eos_token_id = list(eos_token_id)
        eos_token_id.append(additional_eos_token_id)

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                # `logging.warning` takes only the message here; the original passed a stray
                # `UserWarning` argument to the deprecated `logging.warn`.
                logging.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co./docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "input_ids"
            logging.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = model._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = model._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = model._get_logits_warper(generation_config)

        # One flag per batch row: 1 while that row is still generating, 0 once it emitted an EOS.
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # Forward pass to get the next token. Inference only, so autograd is disabled
            # just around the forward call (not across `yield`, so the caller's grad mode
            # is never altered while the generator is suspended).
            with torch.no_grad():
                outputs = model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = model._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=False
            )
            # The row stays "unfinished" only if the new token differs from every EOS id.
            # For the single-prompt batch built above, `min` over these boolean tensors
            # acts as a logical AND.
            unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())

            output_token_ids = input_ids[0].cpu().tolist()
            output_token_ids = output_token_ids[input_length:]
            for each_eos_token_id in eos_token_id:
                # Guard against an empty list: stripping one EOS can leave nothing behind
                # when the very first generated token was an EOS.
                if output_token_ids and output_token_ids[-1] == each_eos_token_id:
                    output_token_ids = output_token_ids[:-1]
            response = tokenizer.decode(output_token_ids)

            yield response
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                return

    
# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 GPT-Academic Interface
# ------------------------------------------------------------------------------------------------------------------------
# Module-level entry points built by get_local_llm_predict_fns from the handle class and the
# model name. NOTE(review): presumably these two callables are what the rest of the application
# imports from this module — confirm against the caller.
predict_no_ui_long_connection, predict = get_local_llm_predict_fns(GetInternlmHandle, model_name)