VictorSanh commited on
Commit
8a29e64
·
1 Parent(s): 1552af8

update the generation args

Browse files
Files changed (1) hide show
  1. README.md +8 -2
README.md CHANGED
@@ -92,7 +92,10 @@ inputs = processor(prompts, return_tensors="pt").to(device)
92
  # --single sample mode
93
  # inputs = processor(prompts[0], return_tensors="pt").to(device)
94
 
95
- generated_ids = model.generate(**inputs, max_length=100)
 
 
 
96
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
97
  for i, t in enumerate(generated_text):
98
  print(f"{i}:\n{t}\n")
@@ -132,9 +135,12 @@ prompts = [
132
  inputs = processor(prompts, add_end_of_utterance_token=False, return_tensors="pt").to(device)
133
  # --single sample mode
134
  # inputs = processor(prompts[0], return_tensors="pt").to(device)
 
 
135
  exit_condition = processor.tokenizer("<end_of_utterance>", add_special_tokens=False).input_ids
 
136
 
137
- generated_ids = model.generate(**inputs, eos_token_id=exit_condition, max_length=100)
138
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
139
  for i, t in enumerate(generated_text):
140
  print(f"{i}:\n{t}\n")
 
92
  # --single sample mode
93
  # inputs = processor(prompts[0], return_tensors="pt").to(device)
94
 
95
+ # Generation args
96
+ bad_words_ids = tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
97
+
98
+ generated_ids = model.generate(**inputs, bad_words_ids=bad_words_ids, max_length=100)
99
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
100
  for i, t in enumerate(generated_text):
101
  print(f"{i}:\n{t}\n")
 
135
  inputs = processor(prompts, add_end_of_utterance_token=False, return_tensors="pt").to(device)
136
  # --single sample mode
137
  # inputs = processor(prompts[0], return_tensors="pt").to(device)
138
+
139
+ # Generation args
140
  exit_condition = processor.tokenizer("<end_of_utterance>", add_special_tokens=False).input_ids
141
+ bad_words_ids = tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
142
 
143
+ generated_ids = model.generate(**inputs, eos_token_id=exit_condition, bad_words_ids=bad_words_ids, max_length=100)
144
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
145
  for i, t in enumerate(generated_text):
146
  print(f"{i}:\n{t}\n")