Spaces:
Sleeping
Sleeping
new style chatbot, PDF support (taken from OAI chatbot)
Browse files- app.py +8 -77
- llm.py +100 -17
- requirements.txt +3 -2
app.py
CHANGED
@@ -3,49 +3,12 @@ import json
|
|
3 |
import os
|
4 |
import boto3
|
5 |
|
6 |
-
from doc2json import process_docx
|
7 |
from settings_mgr import generate_download_settings_js, generate_upload_settings_js
|
8 |
-
from llm import LLM, log_to_console
|
9 |
from botocore.config import Config
|
10 |
|
11 |
dump_controls = False
|
12 |
|
13 |
-
def add_text(history, text):
|
14 |
-
if text:
|
15 |
-
history = history + [(text, None)]
|
16 |
-
return history, gr.Textbox(value="", interactive=False)
|
17 |
-
|
18 |
-
|
19 |
-
def add_file(history, file):
|
20 |
-
if file.name.endswith(".docx"):
|
21 |
-
content = process_docx(file.name)
|
22 |
-
else:
|
23 |
-
with open(file.name, mode="rb") as f:
|
24 |
-
content = f.read()
|
25 |
-
|
26 |
-
if isinstance(content, bytes):
|
27 |
-
content = content.decode('utf-8', 'replace')
|
28 |
-
else:
|
29 |
-
content = str(content)
|
30 |
-
|
31 |
-
fn = os.path.basename(file.name)
|
32 |
-
history = history + [(f'```{fn}\n{content}\n```', None)]
|
33 |
-
|
34 |
-
return history
|
35 |
-
|
36 |
-
def add_img(history, files):
|
37 |
-
for file in files:
|
38 |
-
if log_to_console:
|
39 |
-
print(f"add_img {file.name}")
|
40 |
-
history = history + [(image_embed_prefix + file.name, None)]
|
41 |
-
|
42 |
-
gr.Info(f"Image added as {file.name}")
|
43 |
-
|
44 |
-
return history
|
45 |
-
|
46 |
-
def submit_text(txt_value):
|
47 |
-
return add_text([chatbot, txt_value], [chatbot, txt_value])
|
48 |
-
|
49 |
def undo(history):
|
50 |
history.pop()
|
51 |
return history
|
@@ -92,14 +55,12 @@ def bot(message, history, aws_access, aws_secret, aws_token, system_prompt, temp
|
|
92 |
response = br.invoke_model(body=body, modelId=f"{model}",
|
93 |
accept="application/json", contentType="application/json")
|
94 |
response_body = json.loads(response.get('body').read())
|
95 |
-
|
96 |
-
|
97 |
-
history[-1][1] = br_result
|
98 |
|
99 |
except Exception as e:
|
100 |
raise gr.Error(f"Error: {str(e)}")
|
101 |
|
102 |
-
return
|
103 |
|
104 |
def import_history(history, file):
|
105 |
with open(file.name, mode="rb") as f:
|
@@ -186,34 +147,11 @@ with gr.Blocks() as demo:
|
|
186 |
dl_settings_button.click(None, controls, js=generate_download_settings_js("amz_chat_settings.bin", control_ids))
|
187 |
ul_settings_button.click(None, None, None, js=generate_upload_settings_js(control_ids))
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
)
|
195 |
-
|
196 |
-
with gr.Row():
|
197 |
-
txt = gr.TextArea(
|
198 |
-
scale=4,
|
199 |
-
show_label=False,
|
200 |
-
placeholder="Enter text and press enter, or upload a file",
|
201 |
-
container=False,
|
202 |
-
lines=3,
|
203 |
-
)
|
204 |
-
submit_btn = gr.Button("🚀 Send", scale=0)
|
205 |
-
submit_click = submit_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
206 |
-
bot, [txt, chatbot, aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region], [txt, chatbot],
|
207 |
-
)
|
208 |
-
submit_click.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
209 |
-
|
210 |
-
with gr.Row():
|
211 |
-
btn = gr.UploadButton("📁 Upload", size="sm")
|
212 |
-
img_btn = gr.UploadButton("🖼️ Upload", size="sm", file_count="multiple", file_types=["image"])
|
213 |
-
undo_btn = gr.Button("↩️ Undo")
|
214 |
-
undo_btn.click(undo, inputs=[chatbot], outputs=[chatbot])
|
215 |
-
|
216 |
-
clear = gr.ClearButton(chatbot, value="🗑️ Clear")
|
217 |
|
218 |
if dump_controls:
|
219 |
with gr.Row():
|
@@ -273,11 +211,4 @@ with gr.Blocks() as demo:
|
|
273 |
""")
|
274 |
import_button.upload(import_history, inputs=[chatbot, import_button], outputs=[chatbot, system_prompt])
|
275 |
|
276 |
-
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
277 |
-
bot, [txt, chatbot, aws_access, aws_secret, aws_token, system_prompt, temp, max_tokens, model, region], [txt, chatbot],
|
278 |
-
)
|
279 |
-
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
280 |
-
file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False, postprocess=False)
|
281 |
-
img_msg = img_btn.upload(add_img, [chatbot, img_btn], [chatbot], queue=False, postprocess=False)
|
282 |
-
|
283 |
demo.queue().launch()
|
|
|
3 |
import os
|
4 |
import boto3
|
5 |
|
|
|
6 |
from settings_mgr import generate_download_settings_js, generate_upload_settings_js
|
7 |
+
from llm import LLM, log_to_console
|
8 |
from botocore.config import Config
|
9 |
|
10 |
dump_controls = False
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def undo(history):
    """Remove the most recent entry from the chat history.

    The list is modified in place and the same object is returned, so it can
    be handed straight back to the caller as the updated history.

    Args:
        history: Mutable list of chat entries.

    Returns:
        The same ``history`` list, with its last entry removed if it had one.
    """
    # Popping an empty list raises IndexError; undoing with nothing to undo
    # is treated as a no-op instead of an error.
    if history:
        history.pop()
    return history
|
|
|
55 |
response = br.invoke_model(body=body, modelId=f"{model}",
|
56 |
accept="application/json", contentType="application/json")
|
57 |
response_body = json.loads(response.get('body').read())
|
58 |
+
result = llm.read_response(response_body)
|
|
|
|
|
59 |
|
60 |
except Exception as e:
|
61 |
raise gr.Error(f"Error: {str(e)}")
|
62 |
|
63 |
+
return result
|
64 |
|
65 |
def import_history(history, file):
|
66 |
with open(file.name, mode="rb") as f:
|
|
|
147 |
dl_settings_button.click(None, controls, js=generate_download_settings_js("amz_chat_settings.bin", control_ids))
|
148 |
ul_settings_button.click(None, None, None, js=generate_upload_settings_js(control_ids))
|
149 |
|
150 |
+
chat = gr.ChatInterface(fn=bot, multimodal=True, additional_inputs=controls, retry_btn = None, autofocus = False)
|
151 |
+
chat.textbox.file_count = "multiple"
|
152 |
+
chatbot = chat.chatbot
|
153 |
+
chatbot.show_copy_button = True
|
154 |
+
chatbot.height = 350
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
if dump_controls:
|
157 |
with gr.Row():
|
|
|
211 |
""")
|
212 |
import_button.upload(import_history, inputs=[chatbot, import_button], outputs=[chatbot, system_prompt])
|
213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
demo.queue().launch()
|
llm.py
CHANGED
@@ -1,12 +1,51 @@
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
from typing import Type, TypeVar
|
3 |
import base64
|
|
|
4 |
import json
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# constants
|
7 |
-
image_embed_prefix = "🖼️🆙 "
|
8 |
log_to_console = False
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def encode_image(image_data):
|
11 |
"""Generates a prefix for image base64 data in the required format for the
|
12 |
four known image formats: png, jpeg, gif, and webp.
|
@@ -42,6 +81,38 @@ def encode_image(image_data):
|
|
42 |
"media_type": "image/" + image_type,
|
43 |
"data": base64.b64encode(image_data).decode('utf-8')}
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
LLMClass = TypeVar('LLMClass', bound='LLM')
|
46 |
class LLM(ABC):
|
47 |
@abstractmethod
|
@@ -68,26 +139,25 @@ class Claude(LLM):
|
|
68 |
user_msg_parts = []
|
69 |
for human, assi in history:
|
70 |
if human:
|
71 |
-
if human
|
72 |
-
|
73 |
-
content = f.read()
|
74 |
-
user_msg_parts.append({"type": "image",
|
75 |
-
"source": encode_image(content)})
|
76 |
else:
|
77 |
user_msg_parts.append({"type": "text", "text": human})
|
78 |
|
79 |
-
if assi:
|
80 |
if user_msg_parts:
|
81 |
history_claude_format.append({"role": "user", "content": user_msg_parts})
|
82 |
user_msg_parts = []
|
83 |
|
84 |
history_claude_format.append({"role": "assistant", "content": assi})
|
85 |
|
86 |
-
if message:
|
87 |
-
user_msg_parts.append({"type": "text", "text":
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
91 |
|
92 |
if log_to_console:
|
93 |
print(f"br_prompt: {str(history_claude_format)}")
|
@@ -111,12 +181,25 @@ class Mistral(LLM):
|
|
111 |
def generate_body(message, history, system_prompt, temperature, max_tokens):
|
112 |
prompt = "<s>"
|
113 |
for human, assi in history:
|
114 |
-
if
|
115 |
-
|
|
|
|
|
|
|
|
|
116 |
if assi is not None:
|
117 |
-
prompt += f"{assi}</s
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
if log_to_console:
|
122 |
print(f"br_prompt: {str(prompt)}")
|
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
from typing import Type, TypeVar
|
3 |
import base64
|
4 |
+
import os
|
5 |
import json
|
6 |
+
from doc2json import process_docx
|
7 |
+
import fitz
|
8 |
+
from PIL import Image
|
9 |
+
import io
|
10 |
|
11 |
# constants
|
|
|
12 |
log_to_console = False
|
13 |
|
14 |
+
def process_pdf_img(pdf_fn: str):
    """Render each page of a PDF to a PNG image and return message parts.

    For every page two parts are appended, in order: a text part that labels
    the page, and an image part holding the base64-encoded PNG rendering.

    Args:
        pdf_fn: Path of the PDF file to open.

    Returns:
        A list of dicts, two per page: ``{"type": "text", "text": ...}``
        followed by ``{"type": "image", "source": {"type": "base64",
        "media_type": "image/png", "data": ...}}``.
    """
    message_parts = []

    # The render scale is identical for every page, so build the matrix once
    # instead of once per loop iteration.
    mat = fitz.Matrix(0.6, 0.6)

    pdf = fitz.open(pdf_fn)
    try:
        for page in pdf.pages():
            # Render the page to a pixmap without an alpha channel so the
            # sample buffer is plain RGB.
            pix = page.get_pixmap(matrix=mat, alpha=False)

            # Convert the pixmap to a PIL image, then serialise it as PNG.
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            png_buffer = io.BytesIO()
            img.save(png_buffer, format='PNG')

            base64_encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

            # PyMuPDF's page.number is zero-based; add 1 so the label reads
            # "Page 1" for the first page.
            message_parts.append({
                "type": "text",
                "text": f"Page {page.number + 1} of file '{pdf_fn}'"
            })
            message_parts.append({"type": "image", "source": {"type": "base64",
                                                              "media_type": "image/png",
                                                              "data": base64_encoded}})
    finally:
        # Always release the document, even if rendering a page fails.
        pdf.close()

    return message_parts
|
48 |
+
|
49 |
def encode_image(image_data):
|
50 |
"""Generates a prefix for image base64 data in the required format for the
|
51 |
four known image formats: png, jpeg, gif, and webp.
|
|
|
81 |
"media_type": "image/" + image_type,
|
82 |
"data": base64.b64encode(image_data).decode('utf-8')}
|
83 |
|
84 |
+
def encode_file(fn: str) -> list:
    """Convert a file on disk into a list of message parts.

    Dispatch is by file name and content:

    * ``.docx`` -> one text part holding the result of ``process_docx``.
    * ``.pdf``  -> the parts returned by ``process_pdf_img``.
    * anything else -> the raw bytes are first offered to ``encode_image``;
      if that fails, the bytes are decoded as UTF-8 (undecodable bytes are
      replaced) and wrapped in a fenced block labelled with the base name.

    Extension matching is case-insensitive, so ``REPORT.PDF`` is handled the
    same way as ``report.pdf``.

    Args:
        fn: Path of the file to encode.

    Returns:
        A list of ``{"type": "text", ...}`` / ``{"type": "image", ...}`` dicts.
    """
    user_msg_parts = []
    lowered = fn.lower()

    if lowered.endswith(".docx"):
        user_msg_parts.append({"type": "text", "text": process_docx(fn)})
        return user_msg_parts

    if lowered.endswith(".pdf"):
        user_msg_parts.extend(process_pdf_img(fn))
        return user_msg_parts

    # Binary mode always yields bytes, so no str/bytes branching is needed.
    with open(fn, mode="rb") as f:
        content = f.read()

    try:
        # Try to add the data as an image first.
        image_source = encode_image(content)
    except Exception:
        # Not an image, so fall back to text. The exception type raised by
        # encode_image is not visible here, hence the broad catch. Unlike a
        # bare ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
        text = content.decode('utf-8', 'replace')
        fname = os.path.basename(fn)
        user_msg_parts.append({"type": "text", "text": f"``` {fname}\n{text}\n```"})
    else:
        user_msg_parts.append({"type": "image",
                               "source": image_source})

    return user_msg_parts
|
115 |
+
|
116 |
LLMClass = TypeVar('LLMClass', bound='LLM')
|
117 |
class LLM(ABC):
|
118 |
@abstractmethod
|
|
|
139 |
user_msg_parts = []
|
140 |
for human, assi in history:
|
141 |
if human:
|
142 |
+
if type(human) is tuple:
|
143 |
+
user_msg_parts.extend(encode_file(human[0]))
|
|
|
|
|
|
|
144 |
else:
|
145 |
user_msg_parts.append({"type": "text", "text": human})
|
146 |
|
147 |
+
if assi is not None:
|
148 |
if user_msg_parts:
|
149 |
history_claude_format.append({"role": "user", "content": user_msg_parts})
|
150 |
user_msg_parts = []
|
151 |
|
152 |
history_claude_format.append({"role": "assistant", "content": assi})
|
153 |
|
154 |
+
if message['text']:
|
155 |
+
user_msg_parts.append({"type": "text", "text": message['text']})
|
156 |
+
if message['files']:
|
157 |
+
for file in message['files']:
|
158 |
+
user_msg_parts.extend(encode_file(file['path']))
|
159 |
+
history_claude_format.append({"role": "user", "content": user_msg_parts})
|
160 |
+
user_msg_parts = []
|
161 |
|
162 |
if log_to_console:
|
163 |
print(f"br_prompt: {str(history_claude_format)}")
|
|
|
181 |
def generate_body(message, history, system_prompt, temperature, max_tokens):
|
182 |
prompt = "<s>"
|
183 |
for human, assi in history:
|
184 |
+
if human:
|
185 |
+
if type(human) is tuple:
|
186 |
+
prompt += f"[INST] {encode_file(human[0])} [/INST]"
|
187 |
+
else:
|
188 |
+
prompt += f"[INST] {human} [/INST]"
|
189 |
+
|
190 |
if assi is not None:
|
191 |
+
prompt += f"{assi}</s>"
|
192 |
+
|
193 |
+
if message['text'] or message['files']:
|
194 |
+
prompt += "[INST] "
|
195 |
+
|
196 |
+
if message['text']:
|
197 |
+
prompt += message['text']
|
198 |
+
if message['files']:
|
199 |
+
for file in message['files']:
|
200 |
+
prompt += f"{encode_file(file['path'])}\n"
|
201 |
+
|
202 |
+
prompt += " [/INST]"
|
203 |
|
204 |
if log_to_console:
|
205 |
print(f"br_prompt: {str(prompt)}")
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
gradio>=4.1
|
2 |
langchain
|
3 |
boto3>1.34.54
|
4 |
-
lxml
|
|
|
|
1 |
+
gradio>=4.38.1
|
2 |
langchain
|
3 |
boto3>1.34.54
|
4 |
+
lxml
|
5 |
+
PyMuPDF
|