lhoestq HF staff committed on
Commit
451395b
·
1 Parent(s): 2d4d597

revert batching

Browse files
Files changed (2) hide show
  1. generate.py +7 -29
  2. gradio_app.py +4 -4
generate.py CHANGED
@@ -3,7 +3,6 @@ import json
3
  import logging
4
  import regex
5
  import time
6
- from itertools import chain, islice
7
  from pathlib import Path
8
  from typing import Annotated, Iterator
9
 
@@ -23,16 +22,14 @@ logger = logging.getLogger(__name__)
23
 
24
 
25
  logger.warning("Loading model...")
 
 
26
  if torch.backends.mps.is_available():
27
  device = "mps"
28
- model_id = "Qwen/Qwen1.5-0.5B-Chat"
29
- batch_size = 1 # batching generates duplicates
30
  else:
31
  device = "cuda"
32
- model_id = "google/gemma-2b-it"
33
- batch_size = 1 # batching generates duplicates
34
-
35
- model = models.transformers(model_id, device=device)
36
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_id)
38
  sampler = PenalizedMultinomialSampler()
@@ -98,24 +95,6 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
98
  {{ prompt }}
99
  """
100
 
101
-
102
- def stream_json_objects_from_batched_tokens_generator(batched_tokens_generator: Iterator[list[str]], json_field: str) -> Iterator[dict]:
103
- first_batch = next(batched_tokens_generator)
104
- batch_size = len(first_batch)
105
- streams = [""] * batch_size
106
- skips = [0] * batch_size
107
- for tokens_batch in chain([first_batch], batched_tokens_generator):
108
- for stream_idx, token in enumerate(tokens_batch):
109
- streams[stream_idx] += token
110
- if '"' in token or "}" in token:
111
- try:
112
- for stream_sample in islice(ijson.items(StringIteratorIO(streams[stream_idx].__iter__()), json_field + ".item", buf_size=1), skips[stream_idx], None):
113
- yield stream_sample
114
- skips[stream_idx] = +1
115
- except ijson.IncompleteJSONError:
116
- pass
117
-
118
-
119
  def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
120
  filename = Path(filename).stem
121
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
@@ -155,8 +134,7 @@ def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int,
155
  tokenize=False,
156
  add_generation_prompt=True
157
  )
158
- batched_samples_generator_tokens = samples_generator.stream([text] * batch_size, rng=rng)
159
- json_field = list(Dataset.model_fields)[0]
160
- for _, sample in zip(range(size), stream_json_objects_from_batched_tokens_generator(batched_samples_generator_tokens, json_field=json_field)):
161
  yield json.dumps(sample, ensure_ascii=False) + "\n"
162
- logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
 
3
  import logging
4
  import regex
5
  import time
 
6
  from pathlib import Path
7
  from typing import Annotated, Iterator
8
 
 
22
 
23
 
24
  logger.warning("Loading model...")
25
+ model_id = "google/gemma-2b-it"
26
+ # model_id = "Qwen/Qwen1.5-0.5B-Chat"
27
  if torch.backends.mps.is_available():
28
  device = "mps"
29
+ model = models.transformers(model_id, device=device)
 
30
  else:
31
  device = "cuda"
32
+ model = models.transformers(model_id, device=device)
 
 
 
33
 
34
  tokenizer = AutoTokenizer.from_pretrained(model_id)
35
  sampler = PenalizedMultinomialSampler()
 
95
  {{ prompt }}
96
  """
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
99
  filename = Path(filename).stem
100
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
 
134
  tokenize=False,
135
  add_generation_prompt=True
136
  )
137
+ samples_generator_tokens = samples_generator.stream(text, rng=rng)
138
+ for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
 
139
  yield json.dumps(sample, ensure_ascii=False) + "\n"
140
+ logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
gradio_app.py CHANGED
@@ -6,11 +6,11 @@ import io
6
  import pandas as pd
7
  import spaces
8
 
9
- from generate import model_id, stream_jsonl_file, batch_size
10
 
11
- MAX_SIZE = 20 * batch_size
12
  DEFAULT_SEED = 42
13
- DEFAULT_SIZE = 5 * batch_size
14
 
15
  @spaces.GPU(duration=120)
16
  def stream_output(query: str, continue_content: str = ""):
@@ -87,4 +87,4 @@ with gr.Blocks() as demo:
87
  generate_more_button.click(stream_more_output, filename_comp, outputs)
88
 
89
 
90
- demo.launch()
 
6
  import pandas as pd
7
  import spaces
8
 
9
+ from generate import model_id, stream_jsonl_file
10
 
11
+ MAX_SIZE = 20
12
  DEFAULT_SEED = 42
13
+ DEFAULT_SIZE = 3
14
 
15
  @spaces.GPU(duration=120)
16
  def stream_output(query: str, continue_content: str = ""):
 
87
  generate_more_button.click(stream_more_output, filename_comp, outputs)
88
 
89
 
90
+ demo.launch()