Tawkat commited on
Commit
5423661
1 Parent(s): d5b0eac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -13
app.py CHANGED
@@ -15,20 +15,20 @@ def format_prompt(message, history):
15
  return prompt
16
 
17
  def generate(
18
- prompt, history, max_tokens=2000,
19
  ):
20
- '''temperature = float(temperature)
21
  if temperature < 1e-2:
22
  temperature = 1e-2
23
- top_p = float(top_p)'''
24
 
25
  generate_kwargs = dict(
26
- #temperature=temperature,
27
- max_tokens=max_tokens,
28
- #top_p=top_p,
29
- #repetition_penalty=repetition_penalty,
30
- #do_sample=True,
31
- #seed=42,
32
  )
33
 
34
  formatted_prompt = format_prompt(prompt, history)
@@ -44,15 +44,41 @@ def generate(
44
 
45
  additional_inputs=[
46
  gr.Slider(
47
- label="Max tokens",
48
- value=2000,
 
 
 
 
 
 
 
 
 
49
  minimum=0,
50
- maximum=2048,
51
  step=64,
52
  interactive=True,
53
  info="The maximum numbers of new tokens",
54
  ),
55
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  ]
57
 
58
 
 
15
  return prompt
16
 
17
  def generate(
18
+ prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
19
  ):
20
+ temperature = float(temperature)
21
  if temperature < 1e-2:
22
  temperature = 1e-2
23
+ top_p = float(top_p)
24
 
25
  generate_kwargs = dict(
26
+ temperature=temperature,
27
+ max_new_tokens=max_new_tokens,
28
+ top_p=top_p,
29
+ repetition_penalty=repetition_penalty,
30
+ do_sample=True,
31
+ seed=42,
32
  )
33
 
34
  formatted_prompt = format_prompt(prompt, history)
 
44
 
45
  additional_inputs=[
46
  gr.Slider(
47
+ label="Temperature",
48
+ value=0.9,
49
+ minimum=0.0,
50
+ maximum=1.0,
51
+ step=0.05,
52
+ interactive=True,
53
+ info="Higher values produce more diverse outputs",
54
+ ),
55
+ gr.Slider(
56
+ label="Max new tokens",
57
+ value=256,
58
  minimum=0,
59
+ maximum=1048,
60
  step=64,
61
  interactive=True,
62
  info="The maximum numbers of new tokens",
63
  ),
64
+ gr.Slider(
65
+ label="Top-p (nucleus sampling)",
66
+ value=0.90,
67
+ minimum=0.0,
68
+ maximum=1,
69
+ step=0.05,
70
+ interactive=True,
71
+ info="Higher values sample more low-probability tokens",
72
+ ),
73
+ gr.Slider(
74
+ label="Repetition penalty",
75
+ value=1.2,
76
+ minimum=1.0,
77
+ maximum=2.0,
78
+ step=0.05,
79
+ interactive=True,
80
+ info="Penalize repeated tokens",
81
+ )
82
  ]
83
 
84