Spaces:
Running
Running
from abc import ABC, abstractmethod | |
from typing import Type, TypeVar | |
import base64 | |
import os | |
import json | |
from doc2json import process_docx | |
import fitz | |
from PIL import Image | |
import io | |
import boto3 | |
from botocore.config import Config | |
import re | |
from PIL import Image | |
import io | |
import math | |
import gradio | |
# constants
# When True, stop reasons and token usage from streamed responses are printed.
log_to_console = False
# When True, supported files are sent as native "document" content blocks
# (AWS document message type) instead of being converted to text/images.
use_document_message_type = False
# Type variable bounded to LLM, for annotating LLM or subclasses of it.
LLMClass = TypeVar('LLMClass', bound='LLM')
class LLM:
    """Build chat messages for the AWS API and read its streamed replies.

    The message and event shapes used here (``role``/``content`` blocks,
    ``contentBlockDelta``, ``messageStop``, ``metadata``) are the ones this
    file exchanges with the AWS API; they look like Bedrock's Converse API
    (TODO confirm against the caller that actually issues the request).
    """

    # File extensions that may be sent as native "document" content blocks.
    _DOCUMENT_EXTENSIONS = frozenset(
        {'.pdf', '.csv', '.doc', '.docx', '.xls', '.xlsx', '.html', '.txt', '.md'}
    )
    # Image limits for Claude, as per
    # https://docs.anthropic.com/en/docs/build-with-claude/vision
    _MAX_LONG_EDGE = 1568
    _MAX_IMAGE_TOKENS = 1600
    _MAX_IMAGE_BYTES = 5 * 1024 * 1024
    # Lowest WebP quality tried before the image is shrunk instead.
    _MIN_WEBP_QUALITY = 20
    # Fixed zoom factor used when rendering PDF pages to images.
    _PDF_RENDER_SCALE = 0.6

    @staticmethod
    def create_llm(model: str) -> "LLM":
        """Return a new ``LLM`` instance.

        Args:
            model: Model identifier; currently unused by this factory.
        """
        return LLM()

    def generate_body(self, message, history):
        """Convert chat history plus the new message into a message list.

        Consecutive user entries are merged into a single user message,
        because the AWS API requires a strict user, assistant, user, ...
        sequence. The new message is merged the same way when the history
        ends with a user turn that never received an assistant reply.

        Args:
            message: Object with ``text`` and ``files`` attributes; each file
                is expected to expose a ``path`` attribute.
            history: Iterable of ``(human, assistant)`` pairs. ``human`` may
                be a string, a tuple whose first item is a file path, or a
                ``gradio.Image``; either side may be empty/``None``.

        Returns:
            A list of ``{"role": ..., "content": [...]}`` dicts.
        """
        messages = []
        for human, assi in history:
            if human:
                self._append_user_parts(messages, self._history_parts(human))
            if assi:
                messages.append({"role": "assistant", "content": [{"text": assi}]})

        new_parts = []
        if message.text:
            new_parts.append({"text": message.text})
        if message.files:
            for file in message.files:
                new_parts.extend(self._process_file(file.path))
        if new_parts:
            self._append_user_parts(messages, new_parts)
        return messages

    @staticmethod
    def _append_user_parts(messages, parts):
        """Add ``parts`` as user content, merging into a trailing user message."""
        if messages and messages[-1]["role"] == "user":
            messages[-1]["content"].extend(parts)
        else:
            messages.append({"role": "user", "content": list(parts)})

    def _history_parts(self, human):
        """Return the content parts for one user-side history entry."""
        # Plain strings are checked first; they are never tuples or images.
        if isinstance(human, str):
            return [{"text": human}]
        if isinstance(human, tuple):
            return self._process_file(human[0])
        if isinstance(human, gradio.Image):
            return self._process_file(human.value["path"])
        return [{"text": human}]

    def _process_file(self, file_path):
        """Return content parts for a file.

        Uses a native document block when ``use_document_message_type`` is on
        and the extension is supported; otherwise converts the file to text
        and/or image parts.
        """
        if use_document_message_type and self._is_supported_document_type(file_path):
            return [self._create_document_message(file_path)]
        return self._encode_file(file_path)

    def _is_supported_document_type(self, file_path):
        """Return True if the file extension (case-insensitive) is supported."""
        return os.path.splitext(file_path)[1].lower() in self._DOCUMENT_EXTENSIONS

    def _create_document_message(self, file_path):
        """Read a file and wrap its bytes in a ``document`` content block.

        The name keeps only letters, digits, whitespace, hyphens, parentheses
        and square brackets, with whitespace runs collapsed to one space, and
        is capped at 200 characters. The format is the lower-cased extension,
        matching the case-insensitive support check.
        """
        with open(file_path, 'rb') as file:
            file_content = file.read()
        file_name = re.sub(r'[^a-zA-Z0-9\s\-\(\)\[\]]', '', os.path.basename(file_path))
        # Removing characters can leave adjacent whitespace; collapse it.
        file_name = re.sub(r'\s+', ' ', file_name)[:200].strip() or "unnamed_file"
        file_extension = os.path.splitext(file_path)[1][1:].lower()  # drop the dot
        return {
            "document": {
                "name": file_name,
                "format": file_extension,
                "source": {
                    "bytes": file_content
                }
            }
        }

    def _encode_file(self, fn: str) -> list:
        """Convert a file into text and/or image content parts.

        ``.docx`` files are converted with ``process_docx``, ``.pdf`` files
        are rendered page by page, and anything else is tried as an image
        first and then falls back to a fenced text block.
        """
        lowered = fn.lower()
        if lowered.endswith(".docx"):
            return [{"text": process_docx(fn)}]
        if lowered.endswith(".pdf"):
            return self._process_pdf_img(fn)

        with open(fn, mode="rb") as f:
            content = f.read()
        try:
            # try to add as image
            return [{"image": self._encode_image(content)}]
        except Exception:
            # Deliberate best-effort: whatever cannot be handled as an image
            # is sent as text instead.
            text = content.decode('utf-8', 'replace')
        fname = os.path.basename(fn)
        return [{"text": f"``` {fname}\n{text}\n```"}]

    def _process_pdf_img(self, pdf_fn: str):
        """Render every PDF page to PNG and return label + image parts.

        Each page contributes a text part with its 1-based page number and
        the file name, followed by an image part with the PNG bytes.
        """
        # Same fixed-scale transformation for every page.
        mat = fitz.Matrix(self._PDF_RENDER_SCALE, self._PDF_RENDER_SCALE)
        message_parts = []
        pdf = fitz.open(pdf_fn)
        try:
            for page in pdf.pages():
                # Render the page to a pixmap without an alpha channel
                pix = page.get_pixmap(matrix=mat, alpha=False)
                # Convert pixmap to PIL Image, then to PNG bytes
                img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                png_buffer = io.BytesIO()
                img.save(png_buffer, format='PNG')
                # page.number is 0-based; label pages the way a reader counts
                message_parts.append({"text": f"Page {page.number + 1} of file '{pdf_fn}'"})
                message_parts.append({"image": {
                    "format": "png",
                    "source": {"bytes": png_buffer.getvalue()}
                }})
        finally:
            # Release the document even if rendering a page fails.
            pdf.close()
        return message_parts

    @staticmethod
    def _image_tokens(width, height):
        """Estimate image token cost as ``width * height / 750``."""
        return (width * height) / 750

    @staticmethod
    def _scaled(img, factor):
        """Return ``img`` resized by ``factor``, never below 1 pixel per side."""
        new_size = (
            max(1, int(img.width * factor)),
            max(1, int(img.height * factor)),
        )
        return img.resize(new_size, Image.LANCZOS)

    @staticmethod
    def _to_webp(img, quality):
        """Encode ``img`` as WebP bytes at the given quality."""
        buffer = io.BytesIO()
        img.save(buffer, format="webp", quality=quality)
        return buffer.getvalue()

    def _encode_image(self, image_data):
        """Return an image content block that fits the size limits.

        Images already within the limits and in an accepted format are passed
        through unchanged. Otherwise the image is downscaled in one resampling
        pass if needed, re-encoded as WebP, and compressed (then shrunk) until
        it fits the byte limit.

        Args:
            image_data: Raw image file bytes.

        Returns:
            ``{"format": ..., "source": {"bytes": ...}}``.

        Raises:
            ValueError: If the bytes cannot be opened as an image.
        """
        try:
            # Open the image using Pillow
            img = Image.open(io.BytesIO(image_data))
            original_format = (img.format or "").lower()
        except IOError as err:
            raise ValueError("Unknown image type") from err

        tokens = self._image_tokens(img.width, img.height)
        long_edge = max(img.width, img.height)
        format_ok = original_format in ["jpg", "jpeg", "png", "webp"]
        within_dimensions = (
            long_edge <= self._MAX_LONG_EDGE and tokens <= self._MAX_IMAGE_TOKENS
        )

        # Pass through untouched when every requirement is already met
        if format_ok and within_dimensions and len(image_data) <= self._MAX_IMAGE_BYTES:
            return {
                "format": original_format,
                "source": {"bytes": image_data}
            }

        if not within_dimensions:
            # One resampling pass straight to the target size; the ratio for
            # a limit that is not exceeded is >= 1 and so never wins the min.
            factor = min(
                self._MAX_LONG_EDGE / long_edge,
                math.sqrt(self._MAX_IMAGE_TOKENS / tokens),
            )
            img = self._scaled(img, factor)

        # Re-encode as WebP, lowering quality and then size until it fits
        quality = 95
        image_data = self._to_webp(img, quality)
        while len(image_data) > self._MAX_IMAGE_BYTES:
            if quality <= self._MIN_WEBP_QUALITY:
                # Minimum quality was still too large: shrink and start over
                img = self._scaled(img, 0.9)
                quality = 95
            else:
                quality = max(int(quality * 0.9), self._MIN_WEBP_QUALITY)
            image_data = self._to_webp(img, quality)
        return {
            "format": "webp",
            "source": {"bytes": image_data}
        }

    def read_response(self, response_stream):
        """Yield text chunks from a streamed response.

        Deltas without a ``text`` field are skipped. When ``log_to_console``
        is enabled, the stop reason and token usage are printed as they
        arrive.

        Args:
            response_stream: Iterable of event dicts.

        Yields:
            The text of each ``contentBlockDelta`` event that carries text.
        """
        for event in response_stream:
            if 'contentBlockDelta' in event:
                text = event['contentBlockDelta'].get('delta', {}).get('text')
                if text is not None:
                    yield text
            if 'messageStop' in event and log_to_console:
                print(f"\nStop reason: {event['messageStop']['stopReason']}")
            if 'metadata' in event:
                metadata = event['metadata']
                if 'usage' in metadata and log_to_console:
                    print("\nToken usage:")
                    print(f"Input tokens: {metadata['usage']['inputTokens']}")
                    print(f"Output tokens: {metadata['usage']['outputTokens']}")
                    print(f"Total tokens: {metadata['usage']['totalTokens']}")