Nils Durner committed on
Commit
910dbfd
·
1 Parent(s): 8489337

Claude-3 and Mistral support

Browse files
Files changed (3) hide show
  1. app.py +21 -37
  2. llm.py +134 -0
  3. requirements.txt +3 -3
app.py CHANGED
@@ -5,10 +5,9 @@ import boto3
5
 
6
  from doc2json import process_docx
7
  from settings_mgr import generate_download_settings_js, generate_upload_settings_js
 
8
 
9
  dump_controls = False
10
- log_to_console = False
11
-
12
 
13
  def add_text(history, text):
14
  history = history + [(text, None)]
@@ -59,26 +58,10 @@ def process_values_js():
59
  }
60
  """
61
 
62
- def bot(message, history, aws_access, aws_secret, aws_token, temperature, max_tokens, model, region):
63
  try:
64
- prompt = "\n\n"
65
- for human, assi in history:
66
- if prompt is not None:
67
- prompt += f"Human: {human}\n\n"
68
- if assi is not None:
69
- prompt += f"Assistant: {assi}\n\n"
70
- if message:
71
- prompt += f"Human: {message}\n\n"
72
- prompt += f"Assistant:"
73
-
74
- if log_to_console:
75
- print(f"br_prompt: {str(prompt)}")
76
-
77
- body = json.dumps({
78
- "prompt": prompt,
79
- "max_tokens_to_sample": max_tokens,
80
- "temperature": temperature,
81
- })
82
 
83
  sess = boto3.Session(
84
  aws_access_key_id=aws_access,
@@ -87,21 +70,18 @@ def bot(message, history, aws_access, aws_secret, aws_token, temperature, max_to
87
  region_name=region)
88
  br = sess.client(service_name="bedrock-runtime")
89
 
90
- response = br.invoke_model(body=body, modelId=f"anthropic.{model}",
91
  accept="application/json", contentType="application/json")
92
  response_body = json.loads(response.get('body').read())
93
- br_result = response_body.get('completion')
94
 
95
  history[-1][1] = br_result
96
- if log_to_console:
97
- print(f"br_result: {str(history)}")
98
 
99
  except Exception as e:
100
  raise gr.Error(f"Error: {str(e)}")
101
 
102
  return "", history
103
 
104
-
105
  def import_history(history, file):
106
  with open(file.name, mode="rb") as f:
107
  content = f.read()
@@ -117,16 +97,18 @@ def import_history(history, file):
117
  return history
118
 
119
  with gr.Blocks() as demo:
120
- gr.Markdown("# Amazon™️ Bedrock™️ Chat™️ (Nils' Version™️) feat. Anthropic™️ Claude-2™️")
121
 
122
  with gr.Accordion("Settings"):
123
  aws_access = gr.Textbox(label="AWS Access Key", elem_id="aws_access")
124
  aws_secret = gr.Textbox(label="AWS Secret Key", elem_id="aws_secret")
125
  aws_token = gr.Textbox(label="AWS Session Token", elem_id="aws_token")
126
- model = gr.Dropdown(label="Model", value="claude-v2:1", allow_custom_value=True, elem_id="model",
127
- choices=["claude-v2:1", "claude-v2"])
128
- region = gr.Dropdown(label="Region", value="eu-central-1", allow_custom_value=True, elem_id="region",
129
- choices=["eu-central-1", "us-east-1", "us-west-1"])
 
 
130
  temp = gr.Slider(0, 1, label="Temperature", elem_id="temp", value=1)
131
  max_tokens = gr.Slider(1, 200000, label="Max. Tokens", elem_id="max_tokens", value=4000)
132
  save_button = gr.Button("Save Settings")
@@ -136,7 +118,7 @@ with gr.Blocks() as demo:
136
 
137
  load_button.click(load_settings, js="""
138
  () => {
139
- let elems = ['#aws_access textarea', '#aws_secret textarea', '#aws_token textarea', '#temp input', '#max_tokens input', '#model', '#region'];
140
  elems.forEach(elem => {
141
  let item = document.querySelector(elem);
142
  let event = new InputEvent('input', { bubbles: true });
@@ -146,11 +128,12 @@ with gr.Blocks() as demo:
146
  }
147
  """)
148
 
149
- save_button.click(save_settings, [aws_access, aws_secret, aws_token, temp, max_tokens, model, region], js="""
150
- (acc, sec, tok, prompt, temp, ntok, model, region) => {
151
  localStorage.setItem('aws_access', acc);
152
  localStorage.setItem('aws_secret', sec);
153
  localStorage.setItem('aws_token', tok);
 
154
  localStorage.setItem('temp', document.querySelector('#temp input').value);
155
  localStorage.setItem('max_tokens', document.querySelector('#max_tokens input').value);
156
  localStorage.setItem('model', model);
@@ -161,11 +144,12 @@ with gr.Blocks() as demo:
161
  control_ids = [('aws_access', '#aws_access textarea'),
162
  ('aws_secret', '#aws_secret textarea'),
163
  ('aws_token', '#aws_token textarea'),
 
164
  ('temp', '#temp input'),
165
  ('max_tokens', '#max_tokens input'),
166
  ('model', '#model'),
167
  ('region', '#region')]
168
- controls = [aws_access, aws_secret, aws_token, temp, max_tokens, model, region]
169
 
170
  dl_settings_button.click(None, controls, js=generate_download_settings_js("amz_chat_settings.bin", control_ids))
171
  ul_settings_button.click(None, None, None, js=generate_upload_settings_js(control_ids))
@@ -187,7 +171,7 @@ with gr.Blocks() as demo:
187
  )
188
  submit_btn = gr.Button("🚀 Send", scale=0)
189
  submit_click = submit_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
190
- bot, [txt, chatbot, aws_access, aws_secret, aws_token, temp, max_tokens, model, region], [txt, chatbot],
191
  )
192
  submit_click.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
193
 
@@ -256,7 +240,7 @@ with gr.Blocks() as demo:
256
  import_button.upload(import_history, inputs=[chatbot, import_button], outputs=[chatbot])
257
 
258
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
259
- bot, [txt, chatbot, aws_access, aws_secret, aws_token, temp, max_tokens, model, region], [txt, chatbot],
260
  )
261
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
262
  file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False, postprocess=False)
 
5
 
6
  from doc2json import process_docx
7
  from settings_mgr import generate_download_settings_js, generate_upload_settings_js
8
+ from llm import LLM
9
 
10
  dump_controls = False
 
 
11
 
12
  def add_text(history, text):
13
  history = history + [(text, None)]
 
58
  }
59
  """
60
 
61
+ def bot(message, history, aws_access, aws_secret, aws_token, system_prompt, temperature, max_tokens, model: str, region):
62
  try:
63
+ llm = LLM.create_llm(model)
64
+ body = llm.generate_body(message, history, system_prompt, temperature, max_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  sess = boto3.Session(
67
  aws_access_key_id=aws_access,
 
70
  region_name=region)
71
  br = sess.client(service_name="bedrock-runtime")
72
 
73
+ response = br.invoke_model(body=body, modelId=f"{model}",
74
  accept="application/json", contentType="application/json")
75
  response_body = json.loads(response.get('body').read())
76
+ br_result = llm.read_response(response_body)
77
 
78
  history[-1][1] = br_result
 
 
79
 
80
  except Exception as e:
81
  raise gr.Error(f"Error: {str(e)}")
82
 
83
  return "", history
84
 
 
85
  def import_history(history, file):
86
  with open(file.name, mode="rb") as f:
87
  content = f.read()
 
97
  return history
98
 
99
  with gr.Blocks() as demo:
100
+ gr.Markdown("# Amazon™️ Bedrock™️ Chat™️ (Nils' Version™️) feat. Mistral™️ AI & Anthropic™️ Claude™️")
101
 
102
  with gr.Accordion("Settings"):
103
  aws_access = gr.Textbox(label="AWS Access Key", elem_id="aws_access")
104
  aws_secret = gr.Textbox(label="AWS Secret Key", elem_id="aws_secret")
105
  aws_token = gr.Textbox(label="AWS Session Token", elem_id="aws_token")
106
+ model = gr.Dropdown(label="Model", value="anthropic.claude-3-sonnet-20240229-v1:0", allow_custom_value=True, elem_id="model",
107
+ choices=["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-v2:1", "anthropic.claude-v2",
108
+ "mistral.mistral-7b-instruct-v0:2", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0"])
109
+ system_prompt = gr.TextArea("You are a helpful AI.", label="System Prompt", lines=3, max_lines=250, elem_id="system_prompt")
110
+ region = gr.Dropdown(label="Region", value="eu-west-3", allow_custom_value=True, elem_id="region",
111
+ choices=["eu-central-1", "eu-west-3", "us-east-1", "us-west-1"])
112
  temp = gr.Slider(0, 1, label="Temperature", elem_id="temp", value=1)
113
  max_tokens = gr.Slider(1, 200000, label="Max. Tokens", elem_id="max_tokens", value=4000)
114
  save_button = gr.Button("Save Settings")
 
118
 
119
  load_button.click(load_settings, js="""
120
  () => {
121
+ let elems = ['#aws_access textarea', '#aws_secret textarea', '#aws_token textarea', '#system_prompt textarea', '#temp input', '#max_tokens input', '#model', '#region'];
122
  elems.forEach(elem => {
123
  let item = document.querySelector(elem);
124
  let event = new InputEvent('input', { bubbles: true });
 
128
  }
129
  """)
130
 
131
+ save_button.click(save_settings, [aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region], js="""
132
+ (acc, sec, tok, system_prompt, temp, ntok, model, region) => {
133
  localStorage.setItem('aws_access', acc);
134
  localStorage.setItem('aws_secret', sec);
135
  localStorage.setItem('aws_token', tok);
136
+ localStorage.setItem('system_prompt', system_prompt);
137
  localStorage.setItem('temp', document.querySelector('#temp input').value);
138
  localStorage.setItem('max_tokens', document.querySelector('#max_tokens input').value);
139
  localStorage.setItem('model', model);
 
144
  control_ids = [('aws_access', '#aws_access textarea'),
145
  ('aws_secret', '#aws_secret textarea'),
146
  ('aws_token', '#aws_token textarea'),
147
+ ('system_prompt', '#system_prompt textarea'),
148
  ('temp', '#temp input'),
149
  ('max_tokens', '#max_tokens input'),
150
  ('model', '#model'),
151
  ('region', '#region')]
152
+ controls = [aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region]
153
 
154
  dl_settings_button.click(None, controls, js=generate_download_settings_js("amz_chat_settings.bin", control_ids))
155
  ul_settings_button.click(None, None, None, js=generate_upload_settings_js(control_ids))
 
171
  )
172
  submit_btn = gr.Button("🚀 Send", scale=0)
173
  submit_click = submit_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
174
+ bot, [txt, chatbot, aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region], [txt, chatbot],
175
  )
176
  submit_click.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
177
 
 
240
  import_button.upload(import_history, inputs=[chatbot, import_button], outputs=[chatbot])
241
 
242
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
243
+ bot, [txt, chatbot, aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region], [txt, chatbot],
244
  )
245
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
246
  file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False, postprocess=False)
llm.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Type, TypeVar
3
+ import base64
4
+ import json
5
+
6
+ # constants
7
+ image_embed_prefix = "🖼️🆙 "
8
+ log_to_console = False
9
+
10
+ def encode_image(image_data):
11
+ """Generates a prefix for image base64 data in the required format for the
12
+ four known image formats: png, jpeg, gif, and webp.
13
+
14
+ Args:
15
+ image_data: The image data, encoded in base64.
16
+
17
+ Returns:
18
+ An object encoding the image
19
+ """
20
+
21
+ # Get the first few bytes of the image data.
22
+ magic_number = image_data[:4]
23
+
24
+ # Check the magic number to determine the image type.
25
+ if magic_number.startswith(b'\x89PNG'):
26
+ image_type = 'png'
27
+ elif magic_number.startswith(b'\xFF\xD8'):
28
+ image_type = 'jpeg'
29
+ elif magic_number.startswith(b'GIF89a'):
30
+ image_type = 'gif'
31
+ elif magic_number.startswith(b'RIFF'):
32
+ if image_data[8:12] == b'WEBP':
33
+ image_type = 'webp'
34
+ else:
35
+ # Unknown image type.
36
+ raise Exception("Unknown image type")
37
+ else:
38
+ # Unknown image type.
39
+ raise Exception("Unknown image type")
40
+
41
+ return {"type": "base64",
42
+ "media_type": "image/" + image_type,
43
+ "data": base64.b64encode(image_data).decode('utf-8')}
44
+
45
+ LLMClass = TypeVar('LLMClass', bound='LLM')
46
+ class LLM(ABC):
47
+ @abstractmethod
48
+ def generate_body(message, history, system_prompt, temperature, max_tokens):
49
+ pass
50
+
51
+ @abstractmethod
52
+ def read_response(message, history, system_prompt, temperature, max_tokens):
53
+ pass
54
+
55
+ @staticmethod
56
+ def create_llm(model: str) -> Type[LLMClass]:
57
+ if model.startswith("anthropic.claude"):
58
+ return Claude()
59
+ elif model.startswith("mistral."):
60
+ return Mistral()
61
+ else:
62
+ raise ValueError(f"Unsupported model: {model}")
63
+
64
+ class Claude(LLM):
65
+ @staticmethod
66
+ def generate_body(message, history, system_prompt, temperature, max_tokens):
67
+ history_claude_format = []
68
+ user_msg_parts = []
69
+ for human, assi in history:
70
+ if human is not None:
71
+ if human.startswith(image_embed_prefix):
72
+ with open(human.lstrip(image_embed_prefix), mode="rb") as f:
73
+ content = f.read()
74
+ user_msg_parts.append({"type": "image",
75
+ "source": encode_image(content)})
76
+ else:
77
+ user_msg_parts.append({"type": "text", "text": human})
78
+
79
+ if assi is not None:
80
+ if user_msg_parts:
81
+ history_claude_format.append({"role": "user", "content": user_msg_parts})
82
+ user_msg_parts = []
83
+
84
+ history_claude_format.append({"role": "assistant", "content": assi})
85
+
86
+ if message:
87
+ user_msg_parts.append({"type": "text", "text": human})
88
+
89
+ if user_msg_parts:
90
+ history_claude_format.append({"role": "user", "content": user_msg_parts})
91
+
92
+ if log_to_console:
93
+ print(f"br_prompt: {str(history_claude_format)}")
94
+
95
+ body = json.dumps({
96
+ "anthropic_version": "bedrock-2023-05-31",
97
+ "system": system_prompt,
98
+ "max_tokens": max_tokens,
99
+ "temperature": temperature,
100
+ "messages": history_claude_format
101
+ })
102
+
103
+ return body
104
+
105
+ @staticmethod
106
+ def read_response(response_body) -> Type[str]:
107
+ return response_body.get('content')[0].get('text')
108
+
109
class Mistral(LLM):
    """Adapter for Mistral models via the Bedrock text-prompt API."""

    @staticmethod
    def generate_body(message, history, system_prompt, temperature, max_tokens):
        """Build the Mistral instruct-format prompt body.

        Flattens the (human, assistant) history pairs into the
        "<s>[INST] ... [/INST] ... </s>" instruct format.

        NOTE(review): the Mistral prompt API used here has no system-role
        field, so system_prompt is currently ignored by this adapter.

        Returns:
            The JSON request body as a str.
        """
        prompt = "<s>"
        for human, assi in history:
            # Guard on the turn's content, not on `prompt` (which is always
            # a non-None str) -- otherwise a None human turn would be
            # rendered as the literal string "None".
            if human is not None:
                prompt += f"[INST] {human} [/INST]\n"
            if assi is not None:
                prompt += f"{assi}</s>\n"
        if message:
            prompt += f"[INST] {message} [/INST]"

        if log_to_console:
            print(f"br_prompt: {str(prompt)}")

        body = json.dumps({
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        })

        return body

    @staticmethod
    def read_response(response_body) -> str:
        """Return the completion text from a parsed Mistral response."""
        return response_body.get('outputs')[0].get('text')
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio
2
- langchain
3
- boto3
4
  lxml
 
1
+ gradio
2
+ langchain
3
+ boto3>1.34.54
4
  lxml