import argparse
import logging
from threading import Thread
import time
import torch
import gradio as gr
import spaces
from concept_guidance.chat_template import DEFAULT_CHAT_TEMPLATE
from concept_guidance.patching import patch_model, load_weights
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer, Conversation
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# This Space targets CUDA directly; switch to the commented line for local CPU debugging.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda")
# Comment models in or out below to choose which ones are served.
# RAM requirements: ~16GB x #models (+ ~4GB overhead)
# VRAM requirements: ~16GB
# if using int8: ~8GB VRAM x #models, low RAM requirements
MODEL_CONFIGS = {
"Llama-2-7b-chat-hf": {
"identifier": "meta-llama/Llama-2-7b-chat-hf",
"dtype": torch.float16 if device.type == "cuda" else torch.float32,
"load_in_8bit": False,
"guidance_interval": [-16.0, 16.0],
"default_guidance_scale": 8.0,
"min_guidance_layer": 16,
"max_guidance_layer": 32,
"default_concept": "humor",
"concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
},
# "Mistral-7B-Instruct-v0.1": {
# "identifier": "mistralai/Mistral-7B-Instruct-v0.1",
# "dtype": torch.bfloat16 if device.type == "cuda" else torch.float32,
# "load_in_8bit": False,
# "guidance_interval": [-128.0, 128.0],
# "default_guidance_scale": 48.0,
# "min_guidance_layer": 8,
# "max_guidance_layer": 32,
# "default_concept": "humor",
# "concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
# },
}
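
# Concept vectors are trained offline (one per model/concept pair) and shipped with
# this app as safetensors files under trained_concepts/.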
def load_concept_vectors(model, concepts):
return {concept: load_weights(f"trained_concepts/{model}/{concept}.safetensors") for concept in concepts}
def load_model(model_name):
config = MODEL_CONFIGS[model_name]
model = AutoModelForCausalLM.from_pretrained(config["identifier"], torch_dtype=config["dtype"], load_in_8bit=config["load_in_8bit"])
tokenizer = AutoTokenizer.from_pretrained(config["identifier"])
if tokenizer.chat_template is None:
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
return model, tokenizer
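
# Load every configured model and its concept vectors eagerly at import time,
# so requests only pay for moving weights between CPU and GPU.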
CONCEPTS = ["humor", "creativity", "quality", "truthfulness", "compliance"]
CONCEPT_VECTORS = {model_name: load_concept_vectors(model_name, CONCEPTS) for model_name in MODEL_CONFIGS}
MODELS = {model_name: load_model(model_name) for model_name in MODEL_CONFIGS}
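
# Convert Gradio's [[user, assistant], ...] history pairs into a transformers Conversation.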
def history_to_conversation(history):
conversation = Conversation()
for prompt, completion in history:
conversation.add_message({"role": "user", "content": prompt})
if completion is not None:
conversation.add_message({"role": "assistant", "content": completion})
return conversation
def set_defaults(model_name):
config = MODEL_CONFIGS[model_name]
return (
model_name,
gr.update(choices=config["concepts"], value=config["concepts"][0]),
gr.update(minimum=config["guidance_interval"][0], maximum=config["guidance_interval"][1], value=config["default_guidance_scale"]),
gr.update(value=config["min_guidance_layer"]),
gr.update(value=config["max_guidance_layer"]),
)
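
# Append the new user turn with an empty assistant slot; generate_completion fills it in.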
def add_user_prompt(user_message, history):
if history is None:
history = []
history.append([user_message, None])
return history
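
# spaces.GPU requests a ZeroGPU slot for the duration of the call on Hugging Face Spaces.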
@spaces.GPU
@torch.no_grad()
def generate_completion(
history,
model_name,
concept,
guidance_scale=4.0,
min_guidance_layer=16,
max_guidance_layer=32,
temperature=0.0,
repetition_penalty=1.2,
length_penalty=1.2,
):
start_time = time.time()
logger.info(f" --- Starting completion ({model_name}, {concept=}, {guidance_scale=}, {min_guidance_layer=}, {temperature=})")
logger.info(" User: " + repr(history[-1][0]))
# move all other models to CPU
for name, (model, _) in MODELS.items():
if name != model_name:
config = MODEL_CONFIGS[name]
if not config["load_in_8bit"]:
model.to("cpu")
torch.cuda.empty_cache()
# load the model
config = MODEL_CONFIGS[model_name]
model, tokenizer = MODELS[model_name]
if not config["load_in_8bit"]:
model.to(device, non_blocking=True)
concept_vector = CONCEPT_VECTORS[model_name][concept]
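    # Sliders are 1-indexed layers; convert to a 0-indexed layer range for patch_model.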
guidance_layers = list(range(int(min_guidance_layer) - 1, int(max_guidance_layer)))
patch_model(model, concept_vector, guidance_scale=guidance_scale, guidance_layers=guidance_layers)
pipe = pipeline("conversational", model=model, tokenizer=tokenizer, device=(device if not config["load_in_8bit"] else None))
conversation = history_to_conversation(history)
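    # Run generation in a background thread and stream decoded tokens back to the UI.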
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        max_new_tokens=512,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        streamer=streamer,
        do_sample=(temperature > 0),
    )
    # Only pass temperature when sampling: temperature=0.0 means greedy decoding,
    # and transformers warns if it is passed alongside do_sample=False.
    if temperature > 0:
        generation_kwargs["temperature"] = temperature
thread = Thread(target=pipe, args=(conversation,), kwargs=generation_kwargs, daemon=True)
thread.start()
history[-1][1] = ""
for token in streamer:
history[-1][1] += token
yield history
logger.info(" Assistant: " + repr(history[-1][1]))
time_taken = time.time() - start_time
logger.info(f" --- Completed (took {time_taken:.1f}s)")
return history
class ConceptGuidanceUI:
def __init__(self):
model_names = list(MODEL_CONFIGS.keys())
default_model = model_names[0]
default_config = MODEL_CONFIGS[default_model]
default_concepts = default_config["concepts"]
default_concept = default_config["default_concept"]
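        # saved_input keeps the last submitted prompt so Retry/Undo can restore it.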
saved_input = gr.State("")
with gr.Row(elem_id="concept-guidance-container"):
with gr.Column(scale=1, min_width=256):
model_dropdown = gr.Dropdown(model_names, value=default_model, label="Model")
concept_dropdown = gr.Dropdown(default_concepts, value=default_concept, label="Concept")
guidance_scale = gr.Slider(*default_config["guidance_interval"], value=default_config["default_guidance_scale"], label="Guidance Scale")
min_guidance_layer = gr.Slider(1.0, 32.0, value=16.0, step=1.0, label="First Guidance Layer")
max_guidance_layer = gr.Slider(1.0, 32.0, value=32.0, step=1.0, label="Last Guidance Layer")
temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Temperature")
repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.01, label="Repetition Penalty")
length_penalty = gr.Slider(0.0, 2.0, value=1.2, step=0.01, label="Length Penalty")
with gr.Column(scale=3, min_width=512):
chatbot = gr.Chatbot(scale=1, height=200)
with gr.Row():
self.retry_btn = gr.Button("🔄 Retry", size="sm")
self.undo_btn = gr.Button("↩️ Undo", size="sm")
self.clear_btn = gr.Button("🗑️ Clear", size="sm")
with gr.Group():
with gr.Row():
prompt_field = gr.Textbox(placeholder="Type a message...", show_label=False, label="Message", scale=7, container=False)
self.submit_btn = gr.Button("Submit", variant="primary", scale=1, min_width=150)
self.stop_btn = gr.Button("Stop", variant="secondary", scale=1, min_width=150, visible=False)
generation_args = [
model_dropdown,
concept_dropdown,
guidance_scale,
min_guidance_layer,
max_guidance_layer,
temperature,
repetition_penalty,
length_penalty,
]
model_dropdown.change(set_defaults, [model_dropdown], [model_dropdown, concept_dropdown, guidance_scale, min_guidance_layer, max_guidance_layer], queue=False)
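        # Submit: clear the textbox, append the user turn, then stream the completion.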
submit_triggers = [prompt_field.submit, self.submit_btn.click]
submit_event = gr.on(
submit_triggers, self.clear_and_save_input, [prompt_field], [prompt_field, saved_input], queue=False
).then(
add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
).then(
generate_completion,
[chatbot] + generation_args,
[chatbot],
concurrency_limit=1,
)
self.setup_stop_events(submit_triggers, submit_event)
retry_triggers = [self.retry_btn.click]
retry_event = gr.on(
retry_triggers, self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
).then(
add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
).then(
generate_completion,
[chatbot] + generation_args,
[chatbot],
concurrency_limit=1,
)
self.setup_stop_events(retry_triggers, retry_event)
self.undo_btn.click(
self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
).then(
lambda x: x, [saved_input], [prompt_field]
)
self.clear_btn.click(lambda: [None, None], None, [chatbot, saved_input], queue=False)
def clear_and_save_input(self, message):
return "", message
    def delete_prev_message(self, history):
        # Guard against an empty chat so Undo/Retry on a fresh session doesn't raise.
        if not history:
            return history, ""
        message, _ = history.pop()
        return history, message or ""
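
    # Swap the Submit/Stop buttons while a generation runs and let Stop cancel it.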
def setup_stop_events(self, event_triggers, event_to_cancel):
if self.submit_btn:
for event_trigger in event_triggers:
event_trigger(
lambda: (
gr.Button(visible=False),
gr.Button(visible=True),
),
None,
[self.submit_btn, self.stop_btn],
show_api=False,
queue=False,
)
event_to_cancel.then(
lambda: (gr.Button(visible=True), gr.Button(visible=False)),
None,
[self.submit_btn, self.stop_btn],
show_api=False,
queue=False,
)
self.stop_btn.click(
None,
None,
None,
cancels=event_to_cancel,
show_api=False,
)
css = """
#concept-guidance-container {
flex-grow: 1;
}
""".strip()
with gr.Blocks(title="Concept Guidance", fill_height=True, css=css) as demo:
ConceptGuidanceUI()
demo.queue()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
args = parser.parse_args()
demo.launch(share=args.share)