"""Conversation simulator backed by a Qwen2 instruct model ("Qwen/Qwen2-0.5B-Instruct").

``Qwen2Simulator`` can play either side of a chat:

* ``generate_query`` asks the model to write the next *user* turn;
* ``generate_response`` asks the model to write the next *assistant* turn.

Example::

    history = [["hi, what your name", "rhino"]]
    generated_query = bot.generate_query(history)
    print(generated_query)
    bot.generate_response("1+2*3=", history)
"""

from threading import Thread

from simulator import Simulator
from transformers import TextIteratorStreamer


class Qwen2Simulator(Simulator):
    """Simulator that drives a Qwen2 chat model.

    Uses ``self.tokenizer``, ``self.model`` and ``self.generation_kwargs``,
    which are presumably set up by ``Simulator.__init__`` (not visible
    here -- TODO confirm). ``self.generation_kwargs`` is forwarded verbatim
    to ``self.model.generate``.
    """

    def generate_query(self, history):
        """Generate the next *user* message for the given conversation.

        Args:
            history: Iterable of ``(query, response)`` pairs for the turns
                completed so far. May be empty.

        Returns:
            The decoded text the model produced for the new user turn.
        """
        inputs = ""
        if history:
            messages = []
            for query, response in history:
                messages += [
                    {"role": "user", "content": query},
                    {"role": "assistant", "content": response},
                ]
            inputs += self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
        # Open a *user* turn by hand: the template's generation prompt would
        # open an assistant turn, but here the model should speak as the user.
        inputs += "<|im_start|>user\n"
        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
        return self._generate(input_ids)

    def generate_response(self, query, history):
        """Generate the *assistant* reply to ``query`` given prior turns.

        Args:
            query: The new user message to answer.
            history: Iterable of ``(query, response)`` pairs for earlier turns.
                Pairs whose response is ``None`` (no reply yet) are skipped.

        Returns:
            The decoded assistant reply.
        """
        messages = []
        for _query, _response in history:
            if _response is None:
                # Incomplete turn: don't feed content=None to the chat template.
                # NOTE(review): assumes such a pair duplicates the pending query.
                continue
            messages += [
                {"role": "user", "content": _query},
                {"role": "assistant", "content": _response},
            ]
        messages.append({"role": "user", "content": query})
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, return_tensors="pt", add_generation_prompt=True
        ).to(self.model.device)
        return self._generate(input_ids)

    def _generate(self, input_ids):
        """Run generation and decode only the newly generated tokens."""
        prompt_length = input_ids.shape[-1]
        output_ids = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
        # Slice off the prompt so only the continuation is returned.
        return self.tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    def _stream_generate(self, input_ids):
        """Yield decoded text chunks as the model generates them.

        Generation runs in a background thread that feeds a
        ``TextIteratorStreamer``; this generator drains the streamer.
        """
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            timeout=60.0,
            skip_special_tokens=True,
        )
        # Build the kwargs in one expression: ``dict.update`` returns None.
        # ``generation_kwargs`` comes last so it takes precedence, as before.
        stream_generation_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            **self.generation_kwargs,
        }
        thread = Thread(target=self.model.generate, kwargs=stream_generation_kwargs)
        thread.start()
        yield from streamer
        # The streamer is exhausted once generation finishes; reap the thread.
        thread.join()


# NOTE(review): this loads the model at import time from a hard-coded local
# path. It stays module-level because other modules may import ``bot``.
# Presumably "Qwen/Qwen2-0.5B-Instruct" can be passed instead to load from
# the Hugging Face Hub -- depends on Simulator.__init__, TODO confirm.
bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")