guy-dar commited on
Commit
12c81db
·
1 Parent(s): cea7b8e

add sidebar

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -21,6 +21,13 @@ model, model_params, tokenizer = load_model(model_name)
21
  # neuron_dim = col3.text_input("Dim: ", value='0')
22
  # neurons = model_params.K_heads[int(neuron_layer), int(neuron_dim)]
23
 
 
 
 
 
 
 
 
24
  prompt = st.text_area("Prompt: ")
25
  submitted = st.button("Send!")
26
 
@@ -28,9 +35,9 @@ if submitted:
28
  with st.spinner('Wait for it..'):
29
  model, model_params, tokenizer = map(deepcopy, (model, model_params, tokenizer))
30
  decoded = speaking_probe(model, model_params, tokenizer, prompt,
31
- repetition_penalty=2., num_generations=3,
32
  min_length=1, do_sample=True,
33
- max_new_tokens=100)
34
 
35
  for text in decoded:
36
  st.code('\n'.join(textwrap.wrap(text, width=70)), language=None)
 
21
  # neuron_dim = col3.text_input("Dim: ", value='0')
22
  # neurons = model_params.K_heads[int(neuron_layer), int(neuron_dim)]
23
 
24
+ with st.sidebar:
25
+ temperature = st.slider("Temperature", min_value=0., max_value=2., value=0.5, step=0.05)
26
+ repetition_penalty = st.slider("Repetition Penalty", min_value=0., max_value=4., value=2., step=0.1)
27
+ sidebar_cols = st.columns(2)
28
+ num_generations = sidebar_cols[0].number_input("Number of Answers", min_value=1, value=3, format='%d')
29
+ max_new_tokens = sidebar_cols[1].number_input("Max Answer Length", min_value=1, value=50, format='%d')
30
+
31
  prompt = st.text_area("Prompt: ")
32
  submitted = st.button("Send!")
33
 
 
35
  with st.spinner('Wait for it..'):
36
  model, model_params, tokenizer = map(deepcopy, (model, model_params, tokenizer))
37
  decoded = speaking_probe(model, model_params, tokenizer, prompt,
38
+ repetition_penalty=repetition_penalty, num_generations=num_generations,
39
  min_length=1, do_sample=True,
40
+ max_new_tokens=max_new_tokens, temperature=temperature)
41
 
42
  for text in decoded:
43
  st.code('\n'.join(textwrap.wrap(text, width=70)), language=None)