Spaces: Running on Zero

chiayewken committed
Commit · 596d336
1 Parent(s): 9442fde

Add qwen2-vl streaming inference

Browse files:
- app.py +129 -38
- run_demo.py +0 -97
app.py CHANGED
@@ -1,13 +1,25 @@
+import hashlib
 import os
-from threading import Thread
-from typing import Iterator
+from pathlib import Path
+from typing import Iterator, Optional, List, Union

 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-from run_demo import ZeroShotChatTemplate
+from PIL import Image
+from pydantic import BaseModel
+from swift.llm import (
+    ModelType,
+    get_model_tokenizer,
+    get_default_template_type,
+    get_template,
+    inference,
+    inference_stream,
+)
+from transformers import (
+    Qwen2VLForConditionalGeneration,
+    PreTrainedTokenizer,
+)

 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -29,15 +41,110 @@ As a derivate work of [Llama-3-8b-chat](https://huggingface.co/meta-llama/Meta-L
 this demo is governed by the original [license](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) and [acceptable use policy](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/USE_POLICY.md).
 """

+
+def save_image(image: Image.Image, folder: str) -> str:
+    image_hash = hashlib.md5(image.tobytes()).hexdigest()
+    path = Path(folder, f"{image_hash}.png")
+    path.parent.mkdir(exist_ok=True, parents=True)
+    if not path.exists():
+        image.save(path)
+    return str(path)
+
+
+def resize_image(image: Image.Image, max_size: int) -> Image.Image:
+    # Same as modeling.py resize_image
+    width, height = image.size
+    if width <= max_size and height <= max_size:
+        return image
+    if width > height:
+        new_width = max_size
+        new_height = round(height * max_size / width)
+    else:
+        new_height = max_size
+        new_width = round(width * max_size / height)
+    return image.resize((new_width, new_height), Image.LANCZOS)
+
+
+class EvalModel(BaseModel, arbitrary_types_allowed=True):
+    engine: str
+    timeout: int = 60
+    temperature: float = 0.0
+    max_output_tokens: int = 512
+
+    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
+        raise NotImplementedError
+
+    def run_many(self, inputs: List[Union[str, Image.Image]], num: int) -> List[str]:
+        raise NotImplementedError
+
+
+class SwiftQwenModel(EvalModel):
+    # https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Multi-Modal/qwen2-vl-best-practice.md
+    path: str = ""
+    model: Optional[Qwen2VLForConditionalGeneration] = None
+    tokenizer: Optional[PreTrainedTokenizer] = None
+    engine: str = ModelType.qwen2_vl_7b_instruct
+    image_size: int = 768
+    image_dir: str = "data/qwen_images"
+
+    def load(self):
+        if self.model is None or self.tokenizer is None:
+            self.model, self.tokenizer = get_model_tokenizer(
+                self.engine,
+                torch.bfloat16,
+                model_kwargs={"device_map": "auto"},
+                model_id_or_path=self.path or None,
+            )
+
+    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
+        self.load()
+        template_type = get_default_template_type(self.engine)
+        self.model.generation_config.max_new_tokens = self.max_output_tokens
+        template = get_template(template_type, self.tokenizer)
+
+        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
+        content = []
+        for x in inputs:
+            if isinstance(x, Image.Image):
+                path = save_image(resize_image(x, self.image_size), self.image_dir)
+                content.append(f"<img>{path}</img>")
+        content.append(text)
+
+        query = "".join(content)
+        response, history = inference(self.model, template, query)
+        return response
+
+    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
+        self.load()
+        template_type = get_default_template_type(self.engine)
+        self.model.generation_config.max_new_tokens = self.max_output_tokens
+        template = get_template(template_type, self.tokenizer)
+
+        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
+        content = []
+        for x in inputs:
+            if isinstance(x, Image.Image):
+                path = save_image(resize_image(x, self.image_size), self.image_dir)
+                content.append(f"<img>{path}</img>")
+        content.append(text)
+
+        query = "".join(content)
+        generator = inference_stream(self.model, template, query)
+        print_idx = 0
+        print(f"query: {query}\nresponse: ", end="")
+        for response, history in generator:
+            delta = response[print_idx:]
+            print(delta, end="", flush=True)
+            print_idx = len(response)
+            yield delta
+
+
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


 if torch.cuda.is_available():
-    model_id = ...
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
+    model = SwiftQwenModel()


 @spaces.GPU
@@ -51,32 +158,8 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    demo = ZeroShotChatTemplate()
-    prompt = demo.make_prompt(message)
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
     outputs = []
-    for text in streamer:
+    for text in model.run_stream(inputs=[message]):
         outputs.append(text)
         yield "".join(outputs)

@@ -123,9 +206,15 @@ chat_interface = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        [
-
-
+        [
+            "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?"
+        ],
+        [
+            "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
+        ],
+        [
+            "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?"
+        ],
     ],
     cache_examples=False,
     type="messages",
@@ -133,7 +222,9 @@ chat_interface = gr.ChatInterface(

 with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
     gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    gr.DuplicateButton(
+        value="Duplicate Space for private use", elem_id="duplicate-button"
+    )
     chat_interface.render()
     gr.Markdown(LICENSE)

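For context, a minimal usage sketch of the new SwiftQwenModel outside Gradio. This is a sketch only, assuming a CUDA GPU, the swift/transformers stack imported above, and a hypothetical local file example.jpg; the Qwen2-VL weights are fetched on first load.

    from PIL import Image

    model = SwiftQwenModel()  # defaults to ModelType.qwen2_vl_7b_instruct

    # Streamed text-only query, consumed the same way generate() does in app.py.
    chunks = []
    for delta in model.run_stream(inputs=["Describe a tennis ball."]):
        chunks.append(delta)
    print("".join(chunks))

    # Mixed image + text query: the image is resized, cached under image_dir,
    # and referenced in the prompt via an <img>...</img> tag.
    image = Image.open("example.jpg")  # hypothetical local file
    print(model.run(inputs=[image, "What is shown in this picture?"]))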
run_demo.py DELETED
@@ -1,97 +0,0 @@
-import re
-from typing import Optional, List
-
-import vllm
-from fire import Fire
-from pydantic import BaseModel
-from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM
-
-
-class ZeroShotChatTemplate:
-    # This is the default template used in llama-factory for training
-    texts: List[str] = []
-
-    @staticmethod
-    def make_prompt(prompt: str) -> str:
-        return f"Human: {prompt}\nAssistant: "
-
-    @staticmethod
-    def get_stopping_words() -> List[str]:
-        return ["Human:"]
-
-    @staticmethod
-    def extract_answer(text: str) -> str:
-        filtered = "".join([char for char in text if char.isdigit() or char == " "])
-        if not filtered.strip():
-            return text
-        return re.findall(pattern=r"\d+", string=filtered)[-1]
-
-
-class VLLMModel(BaseModel, arbitrary_types_allowed=True):
-    path_model: str
-    model: vllm.LLM = None
-    tokenizer: Optional[PreTrainedTokenizer] = None
-    max_input_length: int = 512
-    max_output_length: int = 512
-    stopping_words: Optional[List[str]] = None
-
-    def load(self):
-        if self.model is None:
-            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)
-
-    def format_prompt(self, prompt: str) -> str:
-        self.load()
-        prompt = prompt.rstrip(" ")  # Llama is sensitive (eg "Answer:" vs "Answer: ")
-        return prompt
-
-    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
-        if self.stopping_words:
-            kwargs.update(stop=self.stopping_words)
-        params = vllm.SamplingParams(
-            temperature=0.5 if do_sample else 0.0,
-            max_tokens=self.max_output_length,
-            **kwargs,
-        )
-
-        outputs = dict(sampling_params=params, use_tqdm=False)
-        return outputs
-
-    def run(self, prompt: str) -> str:
-        prompt = self.format_prompt(prompt)
-        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
-        pred = outputs[0].outputs[0].text
-        pred = pred.split("<|endoftext|>")[0]
-        return pred
-
-
-def upload_to_hub(path: str, repo_id: str):
-    tokenizer = AutoTokenizer.from_pretrained(path)
-    model = AutoModelForCausalLM.from_pretrained(path)
-    model.push_to_hub(repo_id)
-    tokenizer.push_to_hub(repo_id)
-
-
-def main(
-    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
-    **kwargs,
-):
-    model = VLLMModel(**kwargs)
-    demo = ZeroShotChatTemplate()
-    model.stopping_words = demo.get_stopping_words()
-
-    prompt = demo.make_prompt(question)
-    raw_outputs = model.run(prompt)
-    pred = demo.extract_answer(raw_outputs)
-    print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))
-
-
-"""
-p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo
-p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo
-"""
-
-
-if __name__ == "__main__":
-    Fire()
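For reference, a small sketch of the behavior this deletion removes, using ZeroShotChatTemplate exactly as defined above; the sample strings are illustrative only.

    demo = ZeroShotChatTemplate()

    prompt = demo.make_prompt("How many tennis balls does Roger have now?")
    # -> "Human: How many tennis balls does Roger have now?\nAssistant: "

    answer = demo.extract_answer("Roger starts with 5, buys 2 * 3 = 6 more, so 11.")
    # -> "11": the text is filtered down to digits and spaces, and the last number wins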