loubnabnl (HF staff) committed
Commit 7af0e86
1 Parent(s): 06fe7ef

add api endpoints and dropdown for models

Files changed (1):
  1. app.py +60 -47
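For context, this change registers a second Inference API endpoint (StarCoderBase) alongside the existing StarCoder one and routes each request through the client that matches a new "Version" dropdown. A minimal sketch of that routing pattern, using the same text_generation.Client that app.py already relies on (the clients dict and stream_completion helper are illustrative names, not part of the actual file):

import os
from text_generation import Client  # same streaming client app.py already uses

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder/"
API_URL_BASE = "https://api-inference.huggingface.co/models/bigcode/starcoderbase/"

# One streaming client per endpoint; the dropdown value selects between them.
clients = {
    "StarCoder": Client(API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"}),
    "StarCoderBase": Client(API_URL_BASE, headers={"Authorization": f"Bearer {HF_TOKEN}"}),
}

def stream_completion(prompt, version="StarCoder", **generate_kwargs):
    # Accumulate non-special tokens from the selected endpoint and yield the
    # growing output, mirroring how generate() streams text into the UI.
    output = ""
    for response in clients[version].generate_stream(prompt, **generate_kwargs):
        if not response.token.special:
            output += response.token.text
            yield output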
app.py CHANGED

@@ -11,7 +11,7 @@ from share_btn import community_icon_html, loading_icon_html, share_js, share_bt
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder/"
-
+API_URL_BASE ="https://api-inference.huggingface.co/models/bigcode/starcoderbase/"
 
 FIM_PREFIX = "<fim_prefix>"
 FIM_MIDDLE = "<fim_middle>"
@@ -77,10 +77,12 @@ client = Client(
     API_URL,
     headers={"Authorization": f"Bearer {HF_TOKEN}"},
 )
-
+client_base = Client(
+    API_URL_BASE, headers={"Authorization": f"Bearer {HF_TOKEN}"},
+)
 
 def generate(
-    prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0
+    prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, version="StarCoder",
 ):
 
     temperature = float(temperature)
@@ -106,7 +108,10 @@ def generate(
             raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
         prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
 
-    stream = client.generate_stream(prompt, **generate_kwargs)
+    if version == "StarCoder":
+        stream = client.generate_stream(prompt, **generate_kwargs)
+    else:
+        stream = client_base.generate_stream(prompt, **generate_kwargs)
 
     if fim_mode:
         output = prefix
@@ -178,48 +183,56 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
                 )
                 submit = gr.Button("Generate", variant="primary")
                 output = gr.Code(elem_id="q-output", lines=30)
-
-                with gr.Accordion("Advanced settings", open=False):
-                    with gr.Row():
-                        column_1, column_2 = gr.Column(), gr.Column()
-                        with column_1:
-                            temperature = gr.Slider(
-                                label="Temperature",
-                                value=0.2,
-                                minimum=0.0,
-                                maximum=1.0,
-                                step=0.05,
-                                interactive=True,
-                                info="Higher values produce more diverse outputs",
-                            )
-                            max_new_tokens = gr.Slider(
-                                label="Max new tokens",
-                                value=256,
-                                minimum=0,
-                                maximum=8192,
-                                step=64,
-                                interactive=True,
-                                info="The maximum numbers of new tokens",
-                            )
-                        with column_2:
-                            top_p = gr.Slider(
-                                label="Top-p (nucleus sampling)",
-                                value=0.90,
-                                minimum=0.0,
-                                maximum=1,
-                                step=0.05,
-                                interactive=True,
-                                info="Higher values sample more low-probability tokens",
-                            )
-                            repetition_penalty = gr.Slider(
-                                label="Repetition penalty",
-                                value=1.2,
-                                minimum=1.0,
-                                maximum=2.0,
-                                step=0.05,
-                                interactive=True,
-                                info="Penalize repeated tokens",
-                            )
+                with gr.Row():
+                    with gr.Column():
+                        with gr.Accordion("Advanced settings", open=False):
+                            with gr.Row():
+                                column_1, column_2 = gr.Column(), gr.Column()
+                                with column_1:
+                                    temperature = gr.Slider(
+                                        label="Temperature",
+                                        value=0.2,
+                                        minimum=0.0,
+                                        maximum=1.0,
+                                        step=0.05,
+                                        interactive=True,
+                                        info="Higher values produce more diverse outputs",
+                                    )
+                                    max_new_tokens = gr.Slider(
+                                        label="Max new tokens",
+                                        value=256,
+                                        minimum=0,
+                                        maximum=8192,
+                                        step=64,
+                                        interactive=True,
+                                        info="The maximum numbers of new tokens",
+                                    )
+                                with column_2:
+                                    top_p = gr.Slider(
+                                        label="Top-p (nucleus sampling)",
+                                        value=0.90,
+                                        minimum=0.0,
+                                        maximum=1,
+                                        step=0.05,
+                                        interactive=True,
+                                        info="Higher values sample more low-probability tokens",
+                                    )
+                                    repetition_penalty = gr.Slider(
+                                        label="Repetition penalty",
+                                        value=1.2,
+                                        minimum=1.0,
+                                        maximum=2.0,
+                                        step=0.05,
+                                        interactive=True,
+                                        info="Penalize repeated tokens",
+                                    )
+                    with gr.Column():
+                        version = gr.Dropdown(
+                            ["StarCoderBase", "StarCoder"],
+                            value="StarCoder",
+                            label="Version",
+                            info="",
+                        )
                 gr.Markdown(disclaimer)
                 with gr.Group(elem_id="share-btn-container"):
                     community_icon = gr.HTML(community_icon_html, visible=True)
@@ -238,7 +251,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
 
     submit.click(
         generate,
-        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty],
+        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
    )
    share_button.click(None, [], [], _js=share_js)
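The generate() function touched above also supports StarCoder's fill-in-the-middle mode: it splits the prompt at an indicator token and wraps the pieces in FIM sentinels before streaming from whichever endpoint the dropdown selected. A rough sketch of that prompt assembly; the FIM_SUFFIX value is the standard StarCoder sentinel and the <FILL_HERE> indicator is an assumption, since only FIM_PREFIX and FIM_MIDDLE appear in the diff context:

FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"    # standard StarCoder FIM sentinel
FIM_INDICATOR = "<FILL_HERE>"  # assumption: whatever indicator app.py defines

code = "def add(a, b):\n    <FILL_HERE>\n    return result\n"
prefix, suffix = code.split(FIM_INDICATOR)
prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"

# Either endpoint can then stream the missing middle, e.g. with the sketch above:
# for partial in stream_completion(prompt, version="StarCoderBase", max_new_tokens=64):
#     print(partial)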