chiayewken committed
Commit 596d336 · 1 Parent(s): 9442fde

Add qwen2-vl streaming inference

Files changed (2)
  1. app.py +129 -38
  2. run_demo.py +0 -97
app.py CHANGED
@@ -1,13 +1,25 @@
+import hashlib
 import os
-from threading import Thread
-from typing import Iterator
+from pathlib import Path
+from typing import Iterator, Optional, List, Union
 
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-from run_demo import ZeroShotChatTemplate
+from PIL import Image
+from pydantic import BaseModel
+from swift.llm import (
+    ModelType,
+    get_model_tokenizer,
+    get_default_template_type,
+    get_template,
+    inference,
+    inference_stream,
+)
+from transformers import (
+    Qwen2VLForConditionalGeneration,
+    PreTrainedTokenizer,
+)
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -29,15 +41,110 @@ As a derivate work of [Llama-3-8b-chat](https://huggingface.co/meta-llama/Meta-L
 this demo is governed by the original [license](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) and [acceptable use policy](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/USE_POLICY.md).
 """
 
+
+def save_image(image: Image.Image, folder: str) -> str:
+    image_hash = hashlib.md5(image.tobytes()).hexdigest()
+    path = Path(folder, f"{image_hash}.png")
+    path.parent.mkdir(exist_ok=True, parents=True)
+    if not path.exists():
+        image.save(path)
+    return str(path)
+
+
+def resize_image(image: Image.Image, max_size: int) -> Image.Image:
+    # Same as modeling.py resize_image
+    width, height = image.size
+    if width <= max_size and height <= max_size:
+        return image
+    if width > height:
+        new_width = max_size
+        new_height = round(height * max_size / width)
+    else:
+        new_height = max_size
+        new_width = round(width * max_size / height)
+    return image.resize((new_width, new_height), Image.LANCZOS)
+
+
+class EvalModel(BaseModel, arbitrary_types_allowed=True):
+    engine: str
+    timeout: int = 60
+    temperature: float = 0.0
+    max_output_tokens: int = 512
+
+    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
+        raise NotImplementedError
+
+    def run_many(self, inputs: List[Union[str, Image.Image]], num: int) -> List[str]:
+        raise NotImplementedError
+
+
+class SwiftQwenModel(EvalModel):
+    # https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Multi-Modal/qwen2-vl-best-practice.md
+    path: str = ""
+    model: Optional[Qwen2VLForConditionalGeneration] = None
+    tokenizer: Optional[PreTrainedTokenizer] = None
+    engine: str = ModelType.qwen2_vl_7b_instruct
+    image_size: int = 768
+    image_dir: str = "data/qwen_images"
+
+    def load(self):
+        if self.model is None or self.tokenizer is None:
+            self.model, self.tokenizer = get_model_tokenizer(
+                self.engine,
+                torch.bfloat16,
+                model_kwargs={"device_map": "auto"},
+                model_id_or_path=self.path or None,
+            )
+
+    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
+        self.load()
+        template_type = get_default_template_type(self.engine)
+        self.model.generation_config.max_new_tokens = self.max_output_tokens
+        template = get_template(template_type, self.tokenizer)
+
+        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
+        content = []
+        for x in inputs:
+            if isinstance(x, Image.Image):
+                path = save_image(resize_image(x, self.image_size), self.image_dir)
+                content.append(f"<img>{path}</img>")
+        content.append(text)
+
+        query = "".join(content)
+        response, history = inference(self.model, template, query)
+        return response
+
+    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
+        self.load()
+        template_type = get_default_template_type(self.engine)
+        self.model.generation_config.max_new_tokens = self.max_output_tokens
+        template = get_template(template_type, self.tokenizer)
+
+        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
+        content = []
+        for x in inputs:
+            if isinstance(x, Image.Image):
+                path = save_image(resize_image(x, self.image_size), self.image_dir)
+                content.append(f"<img>{path}</img>")
+        content.append(text)
+
+        query = "".join(content)
+        generator = inference_stream(self.model, template, query)
+        print_idx = 0
+        print(f"query: {query}\nresponse: ", end="")
+        for response, history in generator:
+            delta = response[print_idx:]
+            print(delta, end="", flush=True)
+            print_idx = len(response)
+            yield delta
+
+
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
 
 if torch.cuda.is_available():
-    model_id = "chiayewken/llama3-8b-gsm8k-rpo"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
+    model = SwiftQwenModel()
 
 
 @spaces.GPU
@@ -51,32 +158,8 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    demo = ZeroShotChatTemplate()
-    prompt = demo.make_prompt(message)
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
     outputs = []
-    for text in streamer:
+    for text in model.run_stream(inputs=[message]):
         outputs.append(text)
         yield "".join(outputs)
 
@@ -123,9 +206,15 @@ chat_interface = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        ["Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?"],
-        ["Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"],
-        ["Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?"],
+        [
+            "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?"
+        ],
+        [
+            "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
+        ],
+        [
+            "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?"
+        ],
     ],
     cache_examples=False,
     type="messages",
@@ -133,7 +222,9 @@ chat_interface = gr.ChatInterface(
 
 with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
     gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    gr.DuplicateButton(
+        value="Duplicate Space for private use", elem_id="duplicate-button"
+    )
     chat_interface.render()
     gr.Markdown(LICENSE)
 
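For reference, the streaming path added above can be exercised outside of Gradio. The sketch below is illustrative only and not part of the commit: it assumes a CUDA GPU, the ms-swift dependency, that `SwiftQwenModel` from app.py is importable (or copied locally), and a hypothetical local image path.

```python
from PIL import Image

from app import SwiftQwenModel  # assumption: app.py can be imported without launching the demo

# Uses the defaults from the commit: ModelType.qwen2_vl_7b_instruct, bfloat16, device_map="auto".
model = SwiftQwenModel()

# Inputs may mix PIL images and strings; images are resized, cached to disk,
# and referenced in the query as <img>...</img> tags (see run_stream above).
image = Image.open("example.png")  # hypothetical image path
question = "Describe what is shown in this image."

# run_stream yields incremental text deltas; app.py accumulates them and
# re-yields the running string to gr.ChatInterface.
pieces = []
for delta in model.run_stream(inputs=[image, question]):
    pieces.append(delta)
print("".join(pieces))
```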
run_demo.py DELETED
@@ -1,97 +0,0 @@
-import re
-from typing import Optional, List
-
-import vllm
-from fire import Fire
-from pydantic import BaseModel
-from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM
-
-
-class ZeroShotChatTemplate:
-    # This is the default template used in llama-factory for training
-    texts: List[str] = []
-
-    @staticmethod
-    def make_prompt(prompt: str) -> str:
-        return f"Human: {prompt}\nAssistant: "
-
-    @staticmethod
-    def get_stopping_words() -> List[str]:
-        return ["Human:"]
-
-    @staticmethod
-    def extract_answer(text: str) -> str:
-        filtered = "".join([char for char in text if char.isdigit() or char == " "])
-        if not filtered.strip():
-            return text
-        return re.findall(pattern=r"\d+", string=filtered)[-1]
-
-
-class VLLMModel(BaseModel, arbitrary_types_allowed=True):
-    path_model: str
-    model: vllm.LLM = None
-    tokenizer: Optional[PreTrainedTokenizer] = None
-    max_input_length: int = 512
-    max_output_length: int = 512
-    stopping_words: Optional[List[str]] = None
-
-    def load(self):
-        if self.model is None:
-            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)
-
-    def format_prompt(self, prompt: str) -> str:
-        self.load()
-        prompt = prompt.rstrip(" ")  # Llama is sensitive (eg "Answer:" vs "Answer: ")
-        return prompt
-
-    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
-        if self.stopping_words:
-            kwargs.update(stop=self.stopping_words)
-        params = vllm.SamplingParams(
-            temperature=0.5 if do_sample else 0.0,
-            max_tokens=self.max_output_length,
-            **kwargs,
-        )
-
-        outputs = dict(sampling_params=params, use_tqdm=False)
-        return outputs
-
-    def run(self, prompt: str) -> str:
-        prompt = self.format_prompt(prompt)
-        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
-        pred = outputs[0].outputs[0].text
-        pred = pred.split("<|endoftext|>")[0]
-        return pred
-
-
-def upload_to_hub(path: str, repo_id: str):
-    tokenizer = AutoTokenizer.from_pretrained(path)
-    model = AutoModelForCausalLM.from_pretrained(path)
-    model.push_to_hub(repo_id)
-    tokenizer.push_to_hub(repo_id)
-
-
-def main(
-    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
-    **kwargs,
-):
-    model = VLLMModel(**kwargs)
-    demo = ZeroShotChatTemplate()
-    model.stopping_words = demo.get_stopping_words()
-
-    prompt = demo.make_prompt(question)
-    raw_outputs = model.run(prompt)
-    pred = demo.extract_answer(raw_outputs)
-    print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))
-
-
-"""
-p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo
-p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo
-"""
-
-
-if __name__ == "__main__":
-    Fire()