Spaces:
Running
on
Zero
Running
on
Zero
lower columns temperature
Browse files- generate.py +5 -2
- samplers.py +1 -0
generate.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
import json
|
3 |
import logging
|
4 |
import time
|
|
|
5 |
from typing import Annotated, Iterator
|
6 |
|
7 |
import ijson
|
@@ -31,6 +32,7 @@ else:
|
|
31 |
|
32 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
33 |
sampler = PenalizedMultinomialSampler()
|
|
|
34 |
empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
|
35 |
sampler.set_max_repeats(empty_tokens, 1)
|
36 |
|
@@ -56,7 +58,7 @@ samples_generator_template = generate.json(model, Dataset, sampler=sampler)
|
|
56 |
class Columns(BaseModel):
|
57 |
columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
|
58 |
|
59 |
-
columns_generator = generate.json(model, Columns, sampler=
|
60 |
|
61 |
def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
|
62 |
fsm=samples_generator_template.fsm
|
@@ -89,7 +91,8 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
|
|
89 |
{{ prompt }}
|
90 |
"""
|
91 |
|
92 |
-
def
|
|
|
93 |
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
|
94 |
_start = time.time()
|
95 |
rng = torch.Generator(device=model.device)
|
|
|
2 |
import json
|
3 |
import logging
|
4 |
import time
|
5 |
+
from pathlib import Path
|
6 |
from typing import Annotated, Iterator
|
7 |
|
8 |
import ijson
|
|
|
32 |
|
33 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
34 |
sampler = PenalizedMultinomialSampler()
|
35 |
+
low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
|
36 |
empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
|
37 |
sampler.set_max_repeats(empty_tokens, 1)
|
38 |
|
|
|
58 |
class Columns(BaseModel):
|
59 |
columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
|
60 |
|
61 |
+
columns_generator = generate.json(model, Columns, sampler=low_temperature_sampler)
|
62 |
|
63 |
def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
|
64 |
fsm=samples_generator_template.fsm
|
|
|
91 |
{{ prompt }}
|
92 |
"""
|
93 |
|
94 |
+
def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
|
95 |
+
filename = Path(filename).stem
|
96 |
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
|
97 |
_start = time.time()
|
98 |
rng = torch.Generator(device=model.device)
|
samplers.py
CHANGED
@@ -6,6 +6,7 @@ from outlines.samplers import MultinomialSampler
|
|
6 |
|
7 |
logger = logging.getLogger(__name__)
|
8 |
|
|
|
9 |
class PenalizedMultinomialSampler(MultinomialSampler):
|
10 |
|
11 |
def __init__(self, **kwargs):
|
|
|
6 |
|
7 |
logger = logging.getLogger(__name__)
|
8 |
|
9 |
+
|
10 |
class PenalizedMultinomialSampler(MultinomialSampler):
|
11 |
|
12 |
def __init__(self, **kwargs):
|