Slider for repetition_penalty

#4
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -14,13 +14,8 @@ checkpoint = "CohereForAI/aya-101"
14
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
15
  model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map=device)
16
 
17
- #Set a the value of the repetition penalty
18
- #The higher the value, the less repetitive the generated text will be
19
- #Note that `repetition_penalty` has to be a strictly positive float
20
- repetition_penalty = 1.8
21
-
22
  @spaces.GPU
23
- def aya(text, max_new_tokens):
24
  model.to(device)
25
  inputs = tokenizer.encode(text, return_tensors="pt").to(device)
26
  outputs = model.generate(inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)
@@ -33,9 +28,10 @@ def main():
33
  gr.Markdown(description)
34
  input_text = gr.Textbox(label="🗣️Input Text")
35
  max_new_tokens_slider = gr.Slider(minimum=150, maximum=1648, step=1, value=250, label="Size of your inputs and answer")
 
36
  submit_button = gr.Button("Use🌐Aya")
37
  output_text = gr.Textbox(label="🌐Aya", interactive=False)
38
- submit_button.click(fn=aya, inputs=[input_text, max_new_tokens_slider], outputs=output_text)
39
 
40
  demo.launch()
41
 
 
14
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
15
  model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map=device)
16
 
 
 
 
 
 
17
  @spaces.GPU
18
+ def aya(text, max_new_tokens, repetition_penalty):
19
  model.to(device)
20
  inputs = tokenizer.encode(text, return_tensors="pt").to(device)
21
  outputs = model.generate(inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)
 
28
  gr.Markdown(description)
29
  input_text = gr.Textbox(label="🗣️Input Text")
30
  max_new_tokens_slider = gr.Slider(minimum=150, maximum=1648, step=1, value=250, label="Size of your inputs and answer")
31
+ repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=4.0, step=0.1, value=1.8, label="Repetition Penalty")
32
  submit_button = gr.Button("Use🌐Aya")
33
  output_text = gr.Textbox(label="🌐Aya", interactive=False)
34
+ submit_button.click(fn=aya, inputs=[input_text, max_new_tokens_slider, repetition_penalty_slider], outputs=output_text)
35
 
36
  demo.launch()
37