dvruette committed
Commit 2a046f2
1 Parent(s): 7fb272d

initial commit

README.md CHANGED
@@ -1,13 +1,19 @@
  ---
  title: Concept Guidance
- emoji: 🐠
+ emoji: 💆
  colorFrom: purple
  colorTo: indigo
  sdk: gradio
  sdk_version: 4.18.0
- app_file: app.py
+ app_file: main.py
  pinned: false
  license: mit
+ models: ["meta-llama/Llama-2-7b-chat-hf", "mistralai/Mistral-7B-Instruct-v0.1"]
+ datasets: ["OpenAssistant/oasst1", "dvruette/toxic-completions", "truthfulqa"]
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # A Language Model's Guide Through Latent Space
+
+ An interactive demo accompanying the paper "A Language Model's Guide Through Latent Space".
+
+ arXiv: [COMING SOON]
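
To run the Space locally (a sketch inferred from the files in this commit, not documented in the README itself): install the pinned dependencies with `pip install -r requirements.txt`, then launch `python main.py`; `main.py` accepts a `--share` flag for a public Gradio link. Both 7B models are downloaded and loaded at startup, so a CUDA GPU with enough memory for one of them (the other is kept on CPU) is assumed.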
main.py ADDED
@@ -0,0 +1,268 @@
import argparse
import logging
import time
from threading import Thread

import torch
import gradio as gr
from concept_guidance.chat_template import DEFAULT_CHAT_TEMPLATE
from concept_guidance.patching import patch_model, load_weights
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer, Conversation

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU inference

# Per-model defaults: guidance scales and layer ranges differ between the two base models.
MODEL_CONFIGS = {
    "Llama-2-7b-chat-hf": {
        "identifier": "meta-llama/Llama-2-7b-chat-hf",
        "dtype": torch.float16 if device.type == "cuda" else torch.float32,
        "guidance_interval": [-16.0, 16.0],
        "default_guidance_scale": 8.0,
        "min_guidance_layer": 16,
        "max_guidance_layer": 32,
        "default_concept": "humor",
        "concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
    },
    "Mistral-7B-Instruct-v0.1": {
        "identifier": "mistralai/Mistral-7B-Instruct-v0.1",
        "dtype": torch.bfloat16 if device.type == "cuda" else torch.float32,
        "guidance_interval": [-128.0, 128.0],
        "default_guidance_scale": 48.0,
        "min_guidance_layer": 8,
        "max_guidance_layer": 32,
        "default_concept": "humor",
        "concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
    },
}


def load_concept_vectors(model, concepts):
    return {concept: load_weights(f"trained_concepts/{model}/{concept}.safetensors") for concept in concepts}


def load_model(model_name):
    config = MODEL_CONFIGS[model_name]
    model = AutoModelForCausalLM.from_pretrained(config["identifier"], torch_dtype=config["dtype"])
    tokenizer = AutoTokenizer.from_pretrained(config["identifier"])
    if tokenizer.chat_template is None:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return model, tokenizer


# All models and concept vectors are loaded once at startup.
CONCEPTS = ["humor", "creativity", "quality", "truthfulness", "compliance"]
CONCEPT_VECTORS = {model_name: load_concept_vectors(model_name, CONCEPTS) for model_name in MODEL_CONFIGS}
MODELS = {model_name: load_model(model_name) for model_name in MODEL_CONFIGS}


def history_to_conversation(history):
    conversation = Conversation()
    for prompt, completion in history:
        conversation.add_message({"role": "user", "content": prompt})
        if completion is not None:
            conversation.add_message({"role": "assistant", "content": completion})
    return conversation


def set_defaults(model_name):
    config = MODEL_CONFIGS[model_name]
    return (
        model_name,
        gr.update(choices=config["concepts"], value=config["concepts"][0]),
        gr.update(minimum=config["guidance_interval"][0], maximum=config["guidance_interval"][1], value=config["default_guidance_scale"]),
        gr.update(value=config["min_guidance_layer"]),
        gr.update(value=config["max_guidance_layer"]),
    )


def add_user_prompt(user_message, history):
    if history is None:
        history = []
    history.append([user_message, None])
    return history


@torch.no_grad()
def generate_completion(
    history,
    model_name,
    concept,
    guidance_scale=4.0,
    min_guidance_layer=16,
    max_guidance_layer=32,
    temperature=0.0,
    repetition_penalty=1.2,
    length_penalty=1.2,
):
    start_time = time.time()
    logger.info(f" --- Starting completion ({model_name}, {concept=}, {guidance_scale=}, {min_guidance_layer=}, {temperature=})")
    logger.info(" User: " + repr(history[-1][0]))

    # Move all other models to CPU to free GPU memory for the selected one.
    for name, (model, _) in MODELS.items():
        if name != model_name:
            model.to("cpu")
    torch.cuda.empty_cache()

    # Select the requested model and move it onto the device.
    model, tokenizer = MODELS[model_name]
    model = model.to(device, non_blocking=True)

    # Apply the concept vector to the selected layer range (the sliders are 1-indexed).
    concept_vector = CONCEPT_VECTORS[model_name][concept]
    guidance_layers = list(range(int(min_guidance_layer) - 1, int(max_guidance_layer)))
    patch_model(model, concept_vector, guidance_scale=guidance_scale, guidance_layers=guidance_layers)
    pipe = pipeline("conversational", model=model, tokenizer=tokenizer, device=device)

    conversation = history_to_conversation(history)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        max_new_tokens=512,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        streamer=streamer,
        temperature=temperature,
        do_sample=(temperature > 0),
    )
    # Run generation in a background thread and stream tokens into the last chat turn.
    thread = Thread(target=pipe, args=(conversation,), kwargs=generation_kwargs, daemon=True)
    thread.start()

    history[-1][1] = ""
    for token in streamer:
        history[-1][1] += token
        yield history
    logger.info(" Assistant: " + repr(history[-1][1]))

    time_taken = time.time() - start_time
    logger.info(f" --- Completed (took {time_taken:.1f}s)")
    return history


class ConceptGuidanceUI:
    def __init__(self):
        model_names = list(MODEL_CONFIGS.keys())
        default_model = model_names[0]
        default_config = MODEL_CONFIGS[default_model]
        default_concepts = default_config["concepts"]

        saved_input = gr.State("")

        with gr.Row(elem_id="concept-guidance-container"):
            with gr.Column(scale=1, min_width=256):
                model_dropdown = gr.Dropdown(model_names, value=default_model, label="Model")
                concept_dropdown = gr.Dropdown(default_concepts, value=default_concepts[0], label="Concept")
                guidance_scale = gr.Slider(*default_config["guidance_interval"], value=default_config["default_guidance_scale"], label="Guidance Scale")
                min_guidance_layer = gr.Slider(1.0, 32.0, value=16.0, step=1.0, label="First Guidance Layer")
                max_guidance_layer = gr.Slider(1.0, 32.0, value=32.0, step=1.0, label="Last Guidance Layer")
                temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Temperature")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.01, label="Repetition Penalty")
                length_penalty = gr.Slider(0.0, 2.0, value=1.2, step=0.01, label="Length Penalty")

            with gr.Column(scale=3, min_width=512):
                chatbot = gr.Chatbot(scale=1, height=200)

                with gr.Row():
                    self.retry_btn = gr.Button("🔄 Retry", size="sm")
                    self.undo_btn = gr.Button("↩️ Undo", size="sm")
                    self.clear_btn = gr.Button("🗑️ Clear", size="sm")

                with gr.Group():
                    with gr.Row():
                        prompt_field = gr.Textbox(placeholder="Type a message...", show_label=False, label="Message", scale=7, container=False)
                        self.submit_btn = gr.Button("Submit", variant="primary", scale=1, min_width=150)
                        self.stop_btn = gr.Button("Stop", variant="secondary", scale=1, min_width=150, visible=False)

        generation_args = [
            model_dropdown,
            concept_dropdown,
            guidance_scale,
            min_guidance_layer,
            max_guidance_layer,
            temperature,
            repetition_penalty,
            length_penalty,
        ]

        # Changing the model resets the concept choices and guidance defaults.
        model_dropdown.change(set_defaults, [model_dropdown], [model_dropdown, concept_dropdown, guidance_scale, min_guidance_layer, max_guidance_layer], queue=False)

        submit_triggers = [prompt_field.submit, self.submit_btn.click]
        submit_event = gr.on(
            submit_triggers, self.clear_and_save_input, [prompt_field], [prompt_field, saved_input], queue=False
        ).then(
            add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
        ).then(
            generate_completion,
            [chatbot] + generation_args,
            [chatbot],
            concurrency_limit=1,
        )
        self.setup_stop_events(submit_triggers, submit_event)

        retry_triggers = [self.retry_btn.click]
        retry_event = gr.on(
            retry_triggers, self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
        ).then(
            add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
        ).then(
            generate_completion,
            [chatbot] + generation_args,
            [chatbot],
            concurrency_limit=1,
        )
        self.setup_stop_events(retry_triggers, retry_event)

        self.undo_btn.click(
            self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
        ).then(
            lambda x: x, [saved_input], [prompt_field]
        )
        self.clear_btn.click(lambda: [None, None], None, [chatbot, saved_input], queue=False)

    def clear_and_save_input(self, message):
        return "", message

    def delete_prev_message(self, history):
        message, _ = history.pop()
        return history, message or ""

    def setup_stop_events(self, event_triggers, event_to_cancel):
        # Swap the Submit button for a Stop button while generation is running.
        if self.submit_btn:
            for event_trigger in event_triggers:
                event_trigger(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    ),
                    None,
                    [self.submit_btn, self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            event_to_cancel.then(
                lambda: (gr.Button(visible=True), gr.Button(visible=False)),
                None,
                [self.submit_btn, self.stop_btn],
                show_api=False,
                queue=False,
            )

            self.stop_btn.click(
                None,
                None,
                None,
                cancels=event_to_cancel,
                show_api=False,
            )


css = """
#concept-guidance-container {
    flex-grow: 1;
}
""".strip()

with gr.Blocks(title="Concept Guidance", fill_height=True, css=css) as demo:
    ConceptGuidanceUI()

demo.queue()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    demo.launch(share=args.share)
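
For readers who want the guidance mechanism without the Gradio UI, the sketch below strips main.py down to its core calls (load_weights, patch_model, and a plain generate). It is a minimal sketch, not part of this commit: the prompt is made up, and the scale and layer values are simply the Mistral defaults from MODEL_CONFIGS above.

# Standalone sketch (not in this commit): one guided completion.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from concept_guidance.patching import patch_model, load_weights

model_id = "mistralai/Mistral-7B-Instruct-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Same vector files and call signature as main.py; scale/layers are the Mistral defaults.
concept_vector = load_weights("trained_concepts/Mistral-7B-Instruct-v0.1/humor.safetensors")
patch_model(model, concept_vector, guidance_scale=48.0, guidance_layers=list(range(7, 32)))

# Illustrative prompt; any chat message works.
messages = [{"role": "user", "content": "How was your day?"}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=128, repetition_penalty=1.2)
print(tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True))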
requirements.txt ADDED
@@ -0,0 +1,6 @@
torch==2.1.2
transformers==4.37.2
datasets==2.16.1
accelerate==0.25.0
safetensors==0.4.2
concept-guidance @ git+https://github.com/dvruette/concept-guidance.git
trained_concepts/Llama-2-7b-chat-hf/compliance.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-chat-hf/creativity.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-chat-hf/humor.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-chat-hf/quality.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-chat-hf/truthfulness.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-hf/compliance.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-hf/creativity.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-hf/humor.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-hf/quality.safetensors ADDED (binary, 525 kB)
trained_concepts/Llama-2-7b-hf/truthfulness.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-Instruct-v0.1/compliance.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-Instruct-v0.1/creativity.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-Instruct-v0.1/humor.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-Instruct-v0.1/quality.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-Instruct-v0.1/truthfulness.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-v0.1/compliance.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-v0.1/creativity.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-v0.1/humor.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-v0.1/quality.safetensors ADDED (binary, 525 kB)
trained_concepts/Mistral-7B-v0.1/truthfulness.safetensors ADDED (binary, 525 kB)
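
Each of the twenty files above is a ~525 kB probe for one (model, concept) pair, laid out as trained_concepts/<model>/<concept>.safetensors, which is exactly what load_concept_vectors in main.py expects. Probes for the base models (Llama-2-7b-hf, Mistral-7B-v0.1) ship too, even though the demo only registers the chat/instruct variants in MODEL_CONFIGS. To peek inside one file, a quick inspection sketch (the tensor names depend on how the probes were saved, so print them rather than assuming):

from safetensors import safe_open

# Inspection sketch: list tensor names and shapes in one concept file.
with safe_open("trained_concepts/Llama-2-7b-chat-hf/humor.safetensors", framework="pt") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))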