ndurner committed on
Commit
d2d66c1
·
1 Parent(s): a30711d

new style chatbot, PDF support (taken from OAI chatbot)

Browse files
Files changed (3) hide show
  1. app.py +8 -77
  2. llm.py +100 -17
  3. requirements.txt +3 -2
app.py CHANGED
@@ -3,49 +3,12 @@ import json
3
  import os
4
  import boto3
5
 
6
- from doc2json import process_docx
7
  from settings_mgr import generate_download_settings_js, generate_upload_settings_js
8
- from llm import LLM, log_to_console, image_embed_prefix
9
  from botocore.config import Config
10
 
11
  dump_controls = False
12
 
13
- def add_text(history, text):
14
- if text:
15
- history = history + [(text, None)]
16
- return history, gr.Textbox(value="", interactive=False)
17
-
18
-
19
- def add_file(history, file):
20
- if file.name.endswith(".docx"):
21
- content = process_docx(file.name)
22
- else:
23
- with open(file.name, mode="rb") as f:
24
- content = f.read()
25
-
26
- if isinstance(content, bytes):
27
- content = content.decode('utf-8', 'replace')
28
- else:
29
- content = str(content)
30
-
31
- fn = os.path.basename(file.name)
32
- history = history + [(f'```{fn}\n{content}\n```', None)]
33
-
34
- return history
35
-
36
- def add_img(history, files):
37
- for file in files:
38
- if log_to_console:
39
- print(f"add_img {file.name}")
40
- history = history + [(image_embed_prefix + file.name, None)]
41
-
42
- gr.Info(f"Image added as {file.name}")
43
-
44
- return history
45
-
46
- def submit_text(txt_value):
47
- return add_text([chatbot, txt_value], [chatbot, txt_value])
48
-
49
  def undo(history):
50
  history.pop()
51
  return history
@@ -92,14 +55,12 @@ def bot(message, history, aws_access, aws_secret, aws_token, system_prompt, temp
92
  response = br.invoke_model(body=body, modelId=f"{model}",
93
  accept="application/json", contentType="application/json")
94
  response_body = json.loads(response.get('body').read())
95
- br_result = llm.read_response(response_body)
96
-
97
- history[-1][1] = br_result
98
 
99
  except Exception as e:
100
  raise gr.Error(f"Error: {str(e)}")
101
 
102
- return "", history
103
 
104
  def import_history(history, file):
105
  with open(file.name, mode="rb") as f:
@@ -186,34 +147,11 @@ with gr.Blocks() as demo:
186
  dl_settings_button.click(None, controls, js=generate_download_settings_js("amz_chat_settings.bin", control_ids))
187
  ul_settings_button.click(None, None, None, js=generate_upload_settings_js(control_ids))
188
 
189
- chatbot = gr.Chatbot(
190
- [],
191
- elem_id="chatbot",
192
- show_copy_button=True,
193
- height=350
194
- )
195
-
196
- with gr.Row():
197
- txt = gr.TextArea(
198
- scale=4,
199
- show_label=False,
200
- placeholder="Enter text and press enter, or upload a file",
201
- container=False,
202
- lines=3,
203
- )
204
- submit_btn = gr.Button("🚀 Send", scale=0)
205
- submit_click = submit_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
206
- bot, [txt, chatbot, aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region], [txt, chatbot],
207
- )
208
- submit_click.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
209
-
210
- with gr.Row():
211
- btn = gr.UploadButton("📁 Upload", size="sm")
212
- img_btn = gr.UploadButton("🖼️ Upload", size="sm", file_count="multiple", file_types=["image"])
213
- undo_btn = gr.Button("↩️ Undo")
214
- undo_btn.click(undo, inputs=[chatbot], outputs=[chatbot])
215
-
216
- clear = gr.ClearButton(chatbot, value="🗑️ Clear")
217
 
218
  if dump_controls:
219
  with gr.Row():
@@ -273,11 +211,4 @@ with gr.Blocks() as demo:
273
  """)
274
  import_button.upload(import_history, inputs=[chatbot, import_button], outputs=[chatbot, system_prompt])
275
 
276
- txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
277
- bot, [txt, chatbot, aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region], [txt, chatbot],
278
- )
279
- txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
280
- file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False, postprocess=False)
281
- img_msg = img_btn.upload(add_img, [chatbot, img_btn], [chatbot], queue=False, postprocess=False)
282
-
283
  demo.queue().launch()
 
3
  import os
4
  import boto3
5
 
 
6
  from settings_mgr import generate_download_settings_js, generate_upload_settings_js
7
+ from llm import LLM, log_to_console
8
  from botocore.config import Config
9
 
10
  dump_controls = False
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def undo(history):
13
  history.pop()
14
  return history
 
55
  response = br.invoke_model(body=body, modelId=f"{model}",
56
  accept="application/json", contentType="application/json")
57
  response_body = json.loads(response.get('body').read())
58
+ result = llm.read_response(response_body)
 
 
59
 
60
  except Exception as e:
61
  raise gr.Error(f"Error: {str(e)}")
62
 
63
+ return result
64
 
65
  def import_history(history, file):
66
  with open(file.name, mode="rb") as f:
 
147
  dl_settings_button.click(None, controls, js=generate_download_settings_js("amz_chat_settings.bin", control_ids))
148
  ul_settings_button.click(None, None, None, js=generate_upload_settings_js(control_ids))
149
 
150
+ chat = gr.ChatInterface(fn=bot, multimodal=True, additional_inputs=controls, retry_btn = None, autofocus = False)
151
+ chat.textbox.file_count = "multiple"
152
+ chatbot = chat.chatbot
153
+ chatbot.show_copy_button = True
154
+ chatbot.height = 350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  if dump_controls:
157
  with gr.Row():
 
211
  """)
212
  import_button.upload(import_history, inputs=[chatbot, import_button], outputs=[chatbot, system_prompt])
213
 
 
 
 
 
 
 
 
214
  demo.queue().launch()
llm.py CHANGED
@@ -1,12 +1,51 @@
1
  from abc import ABC, abstractmethod
2
  from typing import Type, TypeVar
3
  import base64
 
4
  import json
 
 
 
 
5
 
6
  # constants
7
- image_embed_prefix = "🖼️🆙 "
8
  log_to_console = False
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def encode_image(image_data):
11
  """Generates a prefix for image base64 data in the required format for the
12
  four known image formats: png, jpeg, gif, and webp.
@@ -42,6 +81,38 @@ def encode_image(image_data):
42
  "media_type": "image/" + image_type,
43
  "data": base64.b64encode(image_data).decode('utf-8')}
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  LLMClass = TypeVar('LLMClass', bound='LLM')
46
  class LLM(ABC):
47
  @abstractmethod
@@ -68,26 +139,25 @@ class Claude(LLM):
68
  user_msg_parts = []
69
  for human, assi in history:
70
  if human:
71
- if human.startswith(image_embed_prefix):
72
- with open(human.lstrip(image_embed_prefix), mode="rb") as f:
73
- content = f.read()
74
- user_msg_parts.append({"type": "image",
75
- "source": encode_image(content)})
76
  else:
77
  user_msg_parts.append({"type": "text", "text": human})
78
 
79
- if assi:
80
  if user_msg_parts:
81
  history_claude_format.append({"role": "user", "content": user_msg_parts})
82
  user_msg_parts = []
83
 
84
  history_claude_format.append({"role": "assistant", "content": assi})
85
 
86
- if message:
87
- user_msg_parts.append({"type": "text", "text": human})
88
-
89
- if user_msg_parts:
90
- history_claude_format.append({"role": "user", "content": user_msg_parts})
 
 
91
 
92
  if log_to_console:
93
  print(f"br_prompt: {str(history_claude_format)}")
@@ -111,12 +181,25 @@ class Mistral(LLM):
111
  def generate_body(message, history, system_prompt, temperature, max_tokens):
112
  prompt = "<s>"
113
  for human, assi in history:
114
- if prompt is not None:
115
- prompt += f"[INST] {human} [/INST]\n"
 
 
 
 
116
  if assi is not None:
117
- prompt += f"{assi}</s>\n"
118
- if message:
119
- prompt += f"[INST] {message} [/INST]"
 
 
 
 
 
 
 
 
 
120
 
121
  if log_to_console:
122
  print(f"br_prompt: {str(prompt)}")
 
1
  from abc import ABC, abstractmethod
2
  from typing import Type, TypeVar
3
  import base64
4
+ import os
5
  import json
6
+ from doc2json import process_docx
7
+ import fitz
8
+ from PIL import Image
9
+ import io
10
 
11
  # constants
 
12
  log_to_console = False
13
 
14
+ def process_pdf_img(pdf_fn: str):
15
+ pdf = fitz.open(pdf_fn)
16
+ message_parts = []
17
+
18
+ for page in pdf.pages():
19
+ # Create a transformation matrix for rendering at the calculated scale
20
+ mat = fitz.Matrix(0.6, 0.6)
21
+
22
+ # Render the page to a pixmap
23
+ pix = page.get_pixmap(matrix=mat, alpha=False)
24
+
25
+ # Convert pixmap to PIL Image
26
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
27
+
28
+ # Convert PIL Image to bytes
29
+ img_byte_arr = io.BytesIO()
30
+ img.save(img_byte_arr, format='PNG')
31
+ img_byte_arr = img_byte_arr.getvalue()
32
+
33
+ # Encode image to base64
34
+ base64_encoded = base64.b64encode(img_byte_arr).decode('utf-8')
35
+
36
+ # Append the message part
37
+ message_parts.append({
38
+ "type": "text",
39
+ "text": f"Page {page.number} of file '{pdf_fn}'"
40
+ })
41
+ message_parts.append({"type": "image", "source": {"type": "base64",
42
+ "media_type": "image/png",
43
+ "data": base64_encoded}})
44
+
45
+ pdf.close()
46
+
47
+ return message_parts
48
+
49
  def encode_image(image_data):
50
  """Generates a prefix for image base64 data in the required format for the
51
  four known image formats: png, jpeg, gif, and webp.
 
81
  "media_type": "image/" + image_type,
82
  "data": base64.b64encode(image_data).decode('utf-8')}
83
 
84
+ def encode_file(fn: str) -> list:
85
+ user_msg_parts = []
86
+
87
+ if fn.endswith(".docx"):
88
+ user_msg_parts.append({"type": "text", "text": process_docx(fn)})
89
+ elif fn.endswith(".pdf"):
90
+ user_msg_parts.extend(process_pdf_img(fn))
91
+ else:
92
+ with open(fn, mode="rb") as f:
93
+ content = f.read()
94
+
95
+ isImage = False
96
+ if isinstance(content, bytes):
97
+ try:
98
+ # try to add as image
99
+ content = encode_image(content)
100
+ isImage = True
101
+ except:
102
+ # not an image, try text
103
+ content = content.decode('utf-8', 'replace')
104
+ else:
105
+ content = str(content)
106
+
107
+ if isImage:
108
+ user_msg_parts.append({"type": "image",
109
+ "source": content})
110
+ else:
111
+ fname = os.path.basename(fn)
112
+ user_msg_parts.append({"type": "text", "text": f"``` {fname}\n{content}\n```"})
113
+
114
+ return user_msg_parts
115
+
116
  LLMClass = TypeVar('LLMClass', bound='LLM')
117
  class LLM(ABC):
118
  @abstractmethod
 
139
  user_msg_parts = []
140
  for human, assi in history:
141
  if human:
142
+ if type(human) is tuple:
143
+ user_msg_parts.extend(encode_file(human[0]))
 
 
 
144
  else:
145
  user_msg_parts.append({"type": "text", "text": human})
146
 
147
+ if assi is not None:
148
  if user_msg_parts:
149
  history_claude_format.append({"role": "user", "content": user_msg_parts})
150
  user_msg_parts = []
151
 
152
  history_claude_format.append({"role": "assistant", "content": assi})
153
 
154
+ if message['text']:
155
+ user_msg_parts.append({"type": "text", "text": message['text']})
156
+ if message['files']:
157
+ for file in message['files']:
158
+ user_msg_parts.extend(encode_file(file['path']))
159
+ history_claude_format.append({"role": "user", "content": user_msg_parts})
160
+ user_msg_parts = []
161
 
162
  if log_to_console:
163
  print(f"br_prompt: {str(history_claude_format)}")
 
181
  def generate_body(message, history, system_prompt, temperature, max_tokens):
182
  prompt = "<s>"
183
  for human, assi in history:
184
+ if human:
185
+ if type(human) is tuple:
186
+ prompt += f"[INST] {encode_file(human[0])} [/INST]"
187
+ else:
188
+ prompt += f"[INST] {human} [/INST]"
189
+
190
  if assi is not None:
191
+ prompt += f"{assi}</s>"
192
+
193
+ if message['text'] or message['files']:
194
+ prompt += "[INST] "
195
+
196
+ if message['text']:
197
+ prompt += message['text']
198
+ if message['files']:
199
+ for file in message['files']:
200
+ prompt += f"{encode_file(file['path'])}\n"
201
+
202
+ prompt += " [/INST]"
203
 
204
  if log_to_console:
205
  print(f"br_prompt: {str(prompt)}")
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- gradio>=4.1
2
  langchain
3
  boto3>1.34.54
4
- lxml
 
 
1
+ gradio>=4.38.1
2
  langchain
3
  boto3>1.34.54
4
+ lxml
5
+ PyMuPDF