lhoestq HF staff committed on
Commit
72a89db
1 Parent(s): 61755fe

lower columns temperature

Browse files
Files changed (2) hide show
  1. generate.py +5 -2
  2. samplers.py +1 -0
generate.py CHANGED
@@ -2,6 +2,7 @@
2
  import json
3
  import logging
4
  import time
 
5
  from typing import Annotated, Iterator
6
 
7
  import ijson
@@ -31,6 +32,7 @@ else:
31
 
32
  tokenizer = AutoTokenizer.from_pretrained(model_id)
33
  sampler = PenalizedMultinomialSampler()
 
34
  empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
35
  sampler.set_max_repeats(empty_tokens, 1)
36
 
@@ -56,7 +58,7 @@ samples_generator_template = generate.json(model, Dataset, sampler=sampler)
56
  class Columns(BaseModel):
57
  columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
58
 
59
- columns_generator = generate.json(model, Columns, sampler=sampler)
60
 
61
  def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
62
  fsm=samples_generator_template.fsm
@@ -89,7 +91,8 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
89
  {{ prompt }}
90
  """
91
 
92
- def stream_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
 
93
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
94
  _start = time.time()
95
  rng = torch.Generator(device=model.device)
 
2
  import json
3
  import logging
4
  import time
5
+ from pathlib import Path
6
  from typing import Annotated, Iterator
7
 
8
  import ijson
 
32
 
33
  tokenizer = AutoTokenizer.from_pretrained(model_id)
34
  sampler = PenalizedMultinomialSampler()
35
+ low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
36
  empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
37
  sampler.set_max_repeats(empty_tokens, 1)
38
 
 
58
  class Columns(BaseModel):
59
  columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
60
 
61
+ columns_generator = generate.json(model, Columns, sampler=low_temperature_sampler)
62
 
63
  def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
64
  fsm=samples_generator_template.fsm
 
91
  {{ prompt }}
92
  """
93
 
94
+ def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
95
+ filename = Path(filename).stem
96
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
97
  _start = time.time()
98
  rng = torch.Generator(device=model.device)
samplers.py CHANGED
@@ -6,6 +6,7 @@ from outlines.samplers import MultinomialSampler
6
 
7
  logger = logging.getLogger(__name__)
8
 
 
9
  class PenalizedMultinomialSampler(MultinomialSampler):
10
 
11
  def __init__(self, **kwargs):
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
+
10
  class PenalizedMultinomialSampler(MultinomialSampler):
11
 
12
  def __init__(self, **kwargs):