jwkirchenbauer committed on
Commit 0e6d24f
1 Parent(s): 4343565

features mostly in place

Files changed (2)
  1. app.py +2 -0
  2. demo_watermark.py +82 -59
app.py CHANGED
@@ -21,6 +21,7 @@ arg_dict = {
     'run_gradio': True,
     'demo_public': False,
     'model_name_or_path': 'facebook/opt-125m',
+    # 'model_name_or_path': 'facebook/opt-2.7b',
     'prompt_max_length': None,
     'max_new_tokens': 200,
     'generation_seed': 123,
@@ -36,6 +37,7 @@ arg_dict = {
     'detection_z_threshold': 4.0,
     'select_green_tokens': True,
     'skip_model_load': False,
+    'seed_separately': True,
     }
 
 args.__dict__.update(arg_dict)
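For readers unfamiliar with the pattern on the last context line: `args.__dict__.update(arg_dict)` rewrites the parsed argparse `Namespace` in place, so the hard-coded dict wins over any CLI defaults and can introduce brand-new attributes such as `seed_separately`. A hedged, self-contained sketch (the parser flags here are illustrative, not the demo's full argument list):

```python
# Hedged sketch (not the demo's full parser): how the hard-coded arg_dict
# overrides parsed defaults. argparse.Namespace stores attributes in
# __dict__, so dict.update() both overwrites existing values and adds
# brand-new flags such as 'seed_separately'.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default="facebook/opt-125m")
parser.add_argument("--max_new_tokens", type=int, default=100)
args = parser.parse_args([])  # the Space passes no CLI args

arg_dict = {
    "max_new_tokens": 200,    # overrides the parsed default
    "seed_separately": True,  # attribute the parser never defined
}
args.__dict__.update(arg_dict)

assert args.max_new_tokens == 200 and args.seed_separately
```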
demo_watermark.py CHANGED
@@ -223,7 +223,10 @@ def generate(prompt, args, model=None, device=None, tokenizer=None):
 
     torch.manual_seed(args.generation_seed)
     output_without_watermark = generate_without_watermark(**tokd_input)
-    # torch.manual_seed(seed) # optional, but will not be the same again generally, unless delta==0.0, no-op watermark
+
+    # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
+    if args.seed_separately:
+        torch.manual_seed(args.generation_seed)
     output_with_watermark = generate_with_watermark(**tokd_input)
 
     if args.is_decoder_only_model:
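Why the new `seed_separately` branch matters: under multinomial sampling, the second `generate` call would otherwise consume a different stretch of the RNG stream than the first, so the two outputs would differ for reasons unrelated to watermarking. Re-seeding replays the same draws, attributing any divergence to the watermark bias alone. A hedged illustration of that reproducibility property (toy distribution, not the demo's model):

```python
# Hedged illustration: re-seeding the torch RNG replays the exact same
# random draws, which is what the seed_separately branch does before the
# watermarked generation so sampling noise is held fixed across calls.
import torch

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

torch.manual_seed(123)
first = torch.multinomial(probs, num_samples=5, replacement=True)

torch.manual_seed(123)  # analogous to the args.seed_separately branch
second = torch.multinomial(probs, num_samples=5, replacement=True)

assert torch.equal(first, second)  # identical draws under identical seed
```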
@@ -275,7 +278,52 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
 <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
 <p/>
 """)
+# Construct state for parameters, define updates and toggles, and register event listeners
+session_args = gr.State(value=args)
+
+with gr.Tab("Generation"):
+
+    with gr.Row():
+        prompt = gr.Textbox(label=f"Prompt", interactive=True)
+    with gr.Row():
+        generate_btn = gr.Button("Generate")
+    with gr.Row():
+        with gr.Column(scale=2):
+            output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False)
+        with gr.Column(scale=1):
+            without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
+    with gr.Row():
+        with gr.Column(scale=2):
+            output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False)
+        with gr.Column(scale=1):
+            with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
+
+    redecoded_input = gr.Textbox(visible=False)
+    truncation_warning = gr.Number(visible=False)
+    def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
+        if truncation_warning:
+            return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
+        else:
+            return orig_prompt, args
+
+    generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
+
+    # Show truncated version of prompt if truncation occurred
+    redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
+
+    # Call detection when the outputs of the generate function are updated.
+    output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+    output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
 
+with gr.Tab("Detector Only"):
+    with gr.Row():
+        detection_input = gr.Textbox(label="Text to Analyze", interactive=True)
+    with gr.Row():
+        detect_btn = gr.Button("Detect")
+    with gr.Row():
+        detection_result = gr.Textbox(label="Detection Result", interactive=False)
+    detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
+
 # Parameter selection group
 with gr.Accordion("Advanced Settings",open=False):
     with gr.Row():
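The hunk above moves the Generation and Detector Only tabs ahead of the settings and introduces `session_args`, a per-session `gr.State` copy of `args` that every widget callback below reads and rewrites. A minimal sketch of that state-management pattern, assuming the Gradio 3.x Blocks API (widget names here are illustrative, not the demo's full UI):

```python
# Hedged sketch of the per-session state pattern used in this commit.
import argparse
import gradio as gr

args = argparse.Namespace(gamma=0.25)

def update_gamma(session_state, value):
    session_state.gamma = float(value)  # rewrite one field of the session copy
    return session_state

with gr.Blocks() as demo:
    session_args = gr.State(value=args)  # per-browser-session copy of args
    gamma = gr.Slider(minimum=0.0, maximum=1.0, value=args.gamma, label="gamma")
    current = gr.Textbox(label="submitted parameters")
    # every slider edit updates the state, then echoes it back for inspection
    gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
    gamma.change(lambda s: str(s), inputs=[session_args], outputs=[current])

# demo.launch()  # uncomment to serve locally
```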
@@ -302,11 +350,29 @@
         ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
     with gr.Row():
         normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers)
+    gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help._")
+    with gr.Accordion("Actual submitted parameters:",open=False):
+        current_parameters = gr.Textbox(label="submitted parameters", value=args)
+    with gr.Accordion("Legacy Settings",open=False):
+        with gr.Row():
+            with gr.Column(scale=1):
+                seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately)
+            with gr.Column(scale=1):
+                select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens)
+
 
-# State manager
-# Construct state for parameters, define updates and toggles, and register event listeners
-session_args = gr.State(value=args)
+with gr.Accordion("A note on model capability",open=False):
+    gr.Markdown(
+        """
+        The models that can be used in this demo are limited to those that are open source as well as fit on a single commodity GPU. In particular, there are few models above 10B parameters and way fewer trained using both Instruction finetuning or RLHF that are open source that we can use.
+
+        Therefore, the model, in both its un-watermarked (normal) and watermarked state, is not generally able to respond well to the kinds of prompts that a 100B+ Instruction and RLHF tuned model such as ChatGPT, Claude, or Bard is.
+
+        We suggest you try prompts that give the model a few sentences and then allow it to 'continue' the prompt, as these weaker models are more capable in this simpler language modeling setting.
+        """
+    )
 
+# State manager logic
 def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
 def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
 def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
@@ -331,76 +397,33 @@
 def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
 def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
 def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
+def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
+def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
 
-decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
 decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
 decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
 decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
-
+
+decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
 sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
 generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
 n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args])
 max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args])
-
 gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args])
 delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args])
 ignore_repeated_bigrams.change(update_ignore_repeated_bigrams,inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args])
 normalizers.change(update_normalizers,inputs=[session_args, normalizers], outputs=[session_args])
+seed_separately.change(update_seed_separately,inputs=[session_args, seed_separately], outputs=[session_args])
+select_green_tokens.change(update_select_green_tokens,inputs=[session_args, select_green_tokens], outputs=[session_args])
 
-with gr.Tab("Generation"):
-
-    with gr.Row():
-        prompt = gr.Textbox(label=f"Prompt", interactive=True)
-    with gr.Row():
-        generate_btn = gr.Button("Generate")
-    with gr.Row():
-        with gr.Column(scale=2):
-            output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False)
-        with gr.Column(scale=1):
-            without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
-    with gr.Row():
-        with gr.Column(scale=2):
-            output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False)
-        with gr.Column(scale=1):
-            with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
-
-    redecoded_input = gr.Textbox(visible=False)
-    truncation_warning = gr.Number(visible=False)
-    def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
-        if truncation_warning:
-            return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
-        else:
-            return orig_prompt, args
-
-    generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
-
-    # Show truncated version of prompt if truncation occurred
-    redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
-
-    # Call detection when the outputs of the generate function are updated.
-    output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-    output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
-
-with gr.Tab("Detector Only"):
-    with gr.Row():
-        detection_input = gr.Textbox(label="Text to Analyze", interactive=True)
-    with gr.Row():
-        detect_btn = gr.Button("Detect")
-    with gr.Row():
-        detection_result = gr.Textbox(label="Detection Result", interactive=False)
-    detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
-
-with gr.Accordion("A note on model capability",open=False):
-    gr.Markdown(
-        """
-        The models that can be used in this demo are limited to those that are open source as well as fit on a single commodity GPU. In particular, there are few models above 10B parameters and way fewer trained using both Instruction finetuning or RLHF that are open source that we can use.
-
-        Therefore, the model, in both its un-watermarked (normal) and watermarked state, is not generally able to respond well to the kinds of prompts that a 100B+ Instruction and RLHF tuned model such as ChatGPT, Claude, or Bard is.
-
-        We suggest you try prompts that give the model a few sentences and then allow it to 'continue' the prompt, as these weaker models are more capable in this simpler language modeling setting.
-        """
-    )
+generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+
+# When the parameters change, also fire detection, since some detection params don't change the model output.
+current_parameters.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+current_parameters.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
 
+demo.queue(concurrency_count=3)
 
 if args.demo_public:
     demo.launch(share=True) # exposes app to the internet via randomly generated link
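The wiring above passes `fn=generate_partial` and `fn=detect_partial`, which this commit never defines; presumably they are the `generate`/`detect` functions with the model, device, and tokenizer pre-bound earlier in `run_gradio`. A hypothetical sketch of that binding (the placeholder objects stand in for real model state):

```python
# Hypothetical sketch: pre-binding model state with functools.partial so
# that Gradio callbacks only need to pass (prompt, args). Names assumed,
# not shown in this commit.
from functools import partial

def generate(prompt, args, model=None, device=None, tokenizer=None):
    # stand-in body; the real logic lives in demo_watermark.py
    return f"{prompt} ... (generated on {device})"

model, device, tokenizer = object(), "cpu", object()  # placeholder state
generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)

print(generate_partial("A boat sailed", args=None))
```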
 