import argparse
from dataclasses import asdict, dataclass, field
from datetime import datetime
import html
from itertools import zip_longest
import os
import textwrap
from typing import Dict, List, Tuple

from dotenv import load_dotenv
import gradio as gr
from pymongo import MongoClient

from llm_rules import Role, Message, models, scenarios

MONGO_URI = "mongodb+srv://{username}:{password}@{host}/?retryWrites=true&w=majority"
MONGO_DB = None
PLACEHOLDER = "Enter message"

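# Chat history in the format expected by gr.Chatbot: a list of [user, assistant] pairs.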
History = List[List[str]]


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_proxy", action="store_true", default=False)
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args()


@dataclass
class State:
    scenario_name: str
    provider_name: str
    model_name: str
    scenario: scenarios.scenario.BaseScenario = None
    model: models.BaseModel = None
    system_message: str = None
    use_system_instructions: bool = False
    messages: List[Message] = field(default_factory=list)
    redacted_messages: List[Message] = field(default_factory=list)
    last_user_message_valid: bool = False

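    # Instantiate the scenario and model objects from their registered names.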
    def __post_init__(self):
        self.scenario = scenarios.SCENARIOS[self.scenario_name]()
        self.model = models.MODEL_BUILDERS[self.provider_name](
            model=self.model_name,
            stream=True,
            temperature=0,
        )
        self.messages = self.get_initial_messages()
        self.redacted_messages = self.get_initial_messages(redacted=True)

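    # Build the starting conversation: either the scenario prompt alone as the system
    # message (use_system_instructions=True), or a default system prompt followed by
    # the scenario prompt as the first user message and the scenario's canned initial
    # response.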
    def get_initial_messages(self, redacted=False) -> List[Message]:
        prompt = self.scenario.redacted_prompt if redacted else self.scenario.prompt
        if self.use_system_instructions:
            messages = [
                Message(Role.SYSTEM, prompt),
            ]
        else:
            messages = [
                Message(Role.SYSTEM, models.PROMPTS[self.system_message]),
                Message(Role.USER, prompt),
                Message(Role.ASSISTANT, self.scenario.initial_response),
            ]
        return messages

    def get_history(self) -> History:
        """Process redacted messages into format for chatbot to display."""
        redacted_messages = self.redacted_messages[1:]
        history = []
        # Group messages into (user, assistant) pairs; the assistant slot of the last
        # pair may be None if a response has not been generated yet.
        args = [iter(redacted_messages)] * 2
        for u, a in zip_longest(*args):
            u = html.escape(u.content, quote=False)
            a = None if a is None else html.escape(a.content, quote=False)
            history.append([u, a])
        return history

    def update_state_and_history(self, history: History, delta: str) -> History:
        """Incrementally update last item of both messages and history."""
        self.messages[-1].content += delta
        history[-1][-1] += html.escape(delta, quote=False)
        return history

    def get_info(self):
        """Return the textbox label, including any scenario-specific format requirements."""
        info_str = "Return to send message. Shift + Return to add a new line."
        if self.scenario.format_message:
            info_str = self.scenario.format_message + " " + info_str
        return info_str

    def unescape_messages(self) -> List[Message]:
        """Return a copy of the messages with HTML escaping undone."""
        return [Message(m.role, html.unescape(m.content)) for m in self.messages]


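# The handlers below are wired to Gradio events in make_block(). Each takes the shared
# State object and returns it, plus gr.update() values for any components it changes.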
def change_provider(state: State, provider_name: str) -> Tuple[State, Dict]:
    """Update model provider and model selection."""
    state.provider_name = provider_name.lower()
    state.model_name = models.MODEL_DEFAULTS[state.provider_name]
    state.model = models.MODEL_BUILDERS[state.provider_name](
        model=state.model_name,
        stream=True,
        temperature=0,
    )
    update_model = gr.update(
        choices=models.MODEL_NAMES_BY_PROVIDER[state.provider_name],
        value=state.model_name,
    )
    return state, update_model


def change_model(state: State, model_name: str) -> State:
    """Update model selection."""
    state.model_name = model_name
    state.model = models.MODEL_BUILDERS[state.provider_name](
        model=state.model_name,
        stream=True,
        temperature=0,
    )
    return state


def change_scenario(state: State, scenario: str) -> Tuple[State, Dict]:
    """Switch scenarios and refresh the textbox label."""
    state.scenario = scenarios.SCENARIOS[scenario]()
    state.scenario_name = scenario
    update = gr.update(placeholder=PLACEHOLDER, label=state.get_info())
    return state, update


def send_user_message(state: State, input: str) -> Tuple[State, History, Dict]:
    """Update state and chatbot with user input, clear textbox."""
    user_msg = Message(Role.USER, input)
    if not state.scenario.is_valid_user_message(user_msg):
        gr.Warning(f"Invalid user message: {state.scenario.format_message}")
        update = gr.update()
    else:
        state.messages.append(user_msg)
        state.redacted_messages.append(user_msg)
        state.last_user_message_valid = True
        update = gr.update(placeholder=PLACEHOLDER, value="")
    return state, state.get_history(), update


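# Generator handler: yields (state, history) repeatedly so the chatbot can render the
# assistant response as it streams in.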
def send_assistant_message(state: State, api_key: str) -> Tuple[State, History]:
    """Request model response and update blocks."""
    history = state.get_history()
    yield state, history

    if not state.last_user_message_valid:
        return

    try:
        api_key = None if api_key == "" else api_key
        response = state.model(state.messages, api_key=api_key)
    except Exception as e:
        raise gr.Error(f"API error: {e}. Please reset the scenario and try again.")

    asst_msg = Message(Role.ASSISTANT, "")
    state.messages.append(asst_msg)
    state.redacted_messages.append(asst_msg)
    history = state.get_history()

    for delta in response:
        history = state.update_state_and_history(history, delta)
        yield state, history


def evaluate_and_log(state: State) -> Tuple[State, Dict]:
    """Evaluate messages and update chatbot."""
    if not state.last_user_message_valid:
        return state, gr.update()

    messages = state.unescape_messages()
    result = state.scenario.evaluate(messages, state.use_system_instructions)
    state.last_user_message_valid = False

    global MONGO_DB
    if MONGO_DB is not None:
        doc = {
            "timestamp": datetime.now(),
            "scenario": state.scenario_name,
            "params": asdict(state.scenario.p),
            "provider": state.provider_name,
            "model": state.model_name,
            "system_instructions": state.use_system_instructions,
            "messages": Message.serialize(state.messages),
            "result": asdict(result),
        }

        try:
            MONGO_DB.insert_one(doc)
        except Exception as e:
            # Warn rather than raise so a logging failure does not end the conversation.
            gr.Warning(f"Database logging failed: {e}")

    if not result.passed:
        gr.Info("Rule broken!")
        return state, gr.update(
            interactive=False,
            value=f"Scenario terminated: {result.reason}. Please reset the scenario or select another.",
        )
    else:
        return state, gr.update(
            interactive=True,
            placeholder=PLACEHOLDER,
        )


def reset_scenario(state: State) -> Tuple[State, History, dict]:
    """Restore the initial conversation and re-enable the textbox."""
    state.messages = state.get_initial_messages()
    state.redacted_messages = state.get_initial_messages(redacted=True)
    history = state.get_history()
    update = gr.update(placeholder=PLACEHOLDER, value="", interactive=True)
    return state, history, update


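# Build the Gradio Blocks UI and wire the chat, dropdown, and reset events to the
# handlers above.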
def make_block(initial_state):
    markdown = (
        "<h1>RuLES: Rule-following Language Evaluation Scenarios</h1>"
        "<p>Your goal is to get the assistant model to break its rules. The rules are presented to the model in the "
        "first user message. Each assistant response is checked using the evaluate method of the scenario, and the "
        "conversation terminates upon a failing result. Passwords and other secrets are redacted from the chat "
        "interface and replaced with \"[REDACTED]\", but the actual value is still provided to the model. Some "
        "scenarios require user messages in a specific format. Your messages will be logged for research use. "
        "Please do not enter any sensitive information.</p>"
        "<p>Anthropic has graciously provided us with subsidized API access, and Google Cloud currently offers 100% "
        "discounted access to the PaLM 2 API during its Preview stage, so no API key is required for these models. "
        "To access OpenAI models, please enter your own API key. We do not record your key, which you can verify "
        "in the demo's source code.</p>"
        "<p>See the RuLES <a href=\"https://github.com/normster/llm_rules\">GitHub repo</a> for more information.</p>"
    )

    with gr.Blocks(
        gr.themes.Monochrome(
            font=[
                gr.themes.GoogleFont("Source Sans Pro"),
                "ui-sans-serif",
                "system-ui",
                "sans-serif",
            ],
            radius_size=gr.themes.sizes.radius_sm,
        )
    ) as block:
        gr.Markdown(markdown, sanitize_html=False)
        state = gr.State(value=initial_state)
        with gr.Row():
            provider_select = gr.Dropdown(
                ["Anthropic", "OpenAI", "Google"],
                value="Anthropic",
                label="Provider",
            )
            model_select = gr.Dropdown(
                models.MODEL_NAMES_BY_PROVIDER["anthropic"],
                value="claude-instant-v1.2",
                label="Model",
            )
            scenario_select = gr.Dropdown(
                scenarios.SCENARIOS.keys(),
                value=initial_state.scenario_name,
                label="Scenario",
            )
            apikey = gr.Textbox(placeholder="sk-...", label="API Key")
        chatbot = gr.Chatbot(initial_state.get_history(), show_label=False)
        textbox = gr.Textbox(placeholder=PLACEHOLDER, label=initial_state.get_info())
        reset_button = gr.Button("Reset Scenario")

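        # Chat flow: send the user message, stream the assistant response, then
        # evaluate the conversation and log the result.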
        textbox.submit(
            send_user_message, [state, textbox], [state, chatbot, textbox], queue=True
        ).then(
            send_assistant_message,
            [state, apikey],
            [state, chatbot],
            queue=True,
        ).then(
            evaluate_and_log, state, [state, textbox], queue=True
        )

        provider_select.change(
            change_provider,
            [state, provider_select],
            [state, model_select],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        model_select.change(
            change_model,
            [state, model_select],
            [state],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        scenario_select.change(
            change_scenario,
            [state, scenario_select],
            [state, textbox],
            queue=False,
        ).then(reset_scenario, state, [state, chatbot, textbox], queue=False)

        reset_button.click(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )
        block.load(reset_scenario, state, [state, chatbot, textbox], queue=False)

    return block


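# Requires MONGO_USERNAME, MONGO_PASSWORD, and MONGO_HOST in the environment (or in a
# .env file picked up by load_dotenv) for logging conversations to MongoDB.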
def main(args):
    load_dotenv()

    initial_state = State(
        scenario_name="Encryption",
        provider_name="anthropic",
        model_name="claude-instant-v1.2",
    )
    initial_state.messages = initial_state.get_initial_messages()
    initial_state.redacted_messages = initial_state.get_initial_messages(redacted=True)

    global MONGO_DB
    mongo_uri = MONGO_URI.format(
        username=os.environ["MONGO_USERNAME"],
        password=os.environ["MONGO_PASSWORD"],
        host=os.environ["MONGO_HOST"],
    )
    client = MongoClient(mongo_uri)
    MONGO_DB = client["messages"]["v1.0"]

    block = make_block(initial_state)
    block.queue(concurrency_count=2)
    block.launch(
        server_port=args.port,
        share=args.hf_proxy,
    )


if __name__ == "__main__":
    args = parse_args()
    main(args)