|
import copy |
|
|
|
|
|
|
|
# Default generation parameters for API requests.
#
# NOTE: the *type* of each value matters. The `default()` helper in this module
# coerces a supplied value to `type(default)` and falls back to the default when
# the conversion is lossy (e.g. 1.5 -> int 1 is rejected). Use float literals
# (0.0, 1.0) for float-valued knobs so fractional request values are kept.
# NOTE(review): presumably these entries are the `default` argument passed to
# `default()` by the request handlers; confirm against callers.
default_req_params = {
    'max_new_tokens': 16,
    'auto_max_new_tokens': False,
    'max_tokens_second': 0,
    'temperature': 1.0,
    'top_p': 1.0,
    'top_k': 1,
    'repetition_penalty': 1.18,
    'repetition_penalty_range': 0,
    'encoder_repetition_penalty': 1.0,
    'suffix': None,
    'stream': False,
    'echo': False,
    'seed': -1,
    'truncation_length': 2048,
    'add_bos_token': True,
    'do_sample': True,
    'typical_p': 1.0,
    'epsilon_cutoff': 0.0,
    'eta_cutoff': 0.0,
    'tfs': 1.0,
    'top_a': 0.0,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0.0,
    'length_penalty': 1.0,
    'early_stopping': False,
    'mirostat_mode': 0,
    'mirostat_tau': 5.0,
    'mirostat_eta': 0.1,
    # Float (was int 1): with an int default, a float such as 1.5 fails the
    # lossless round-trip check in `default()` and would be discarded.
    'guidance_scale': 1.0,
    'negative_prompt': '',
    'ban_eos_token': False,
    'custom_token_bans': '',
    'skip_special_tokens': True,
    'custom_stopping_strings': '',
}
|
|
|
|
|
def get_default_req_params():
    """Return an independent copy of ``default_req_params``.

    The copy is deep, so callers may freely modify the returned dict (and any
    nested values) without affecting the shared module-level defaults.
    """
    fresh_params = copy.deepcopy(default_req_params)
    return fresh_params
|
|
|
|
|
def default(dic, key, default):
    """Look up ``dic[key]``, coercing it to the type of *default*.

    Little helper to fall back to a default when an argument is missing,
    ``None``, or of the wrong type, while still accepting values that are
    merely a different spelling of a valid one (e.g. ``1`` for a float
    default ``1.0``, or ``'5'`` for an int default).

    Args:
        dic: Mapping to read from (anything with a ``.get`` method).
        key: Key to look up in *dic*.
        default: Fallback value; ``type(default)`` is the required result type.

    Returns:
        ``dic[key]`` unchanged if it is already an instance of
        ``type(default)``; otherwise ``type(default)(dic[key])`` if that
        conversion succeeds and converting back reproduces the original
        value; otherwise *default*. A missing key yields *default*.

    Notes:
        The type test uses ``isinstance``, so ``bool`` values pass unchanged
        for ``int`` defaults. When *default* is ``None``, every non-``None``
        value is rejected (``NoneType`` accepts no arguments), so the result
        is always ``None``.
    """
    val = dic.get(key, default)
    if isinstance(val, type(default)):
        return val

    # Maybe it's just something like 1 instead of 1.0: try a conversion.
    try:
        converted = type(default)(val)
        # Accept only lossless conversions: converting back to the original
        # type must reproduce the original value (1 -> 1.0 is kept, while
        # 1.5 -> 1 is rejected). The comparison stays inside the try because
        # it may raise for exotic types.
        if type(val)(converted) == val:
            return converted
    except Exception:
        # Deliberate best effort: any conversion failure (int(None),
        # int('abc'), ...) falls back to the default. Unlike the former bare
        # `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
        pass

    return default
|
|
|
|
|
def clamp(value, minvalue, maxvalue):
    """Restrict *value* to the closed range ``[minvalue, maxvalue]``.

    The upper bound is applied first and the lower bound second, so when
    ``minvalue > maxvalue`` the lower bound wins. A NaN *value* compares
    false against both bounds and therefore yields *minvalue*.
    """
    # Same comparisons, in the same order, as max(minvalue, min(value, maxvalue)).
    capped = maxvalue if maxvalue < value else value
    return capped if capped > minvalue else minvalue
|
|