ndurner committed on
Commit
4e362cd
·
1 Parent(s): 0211c96

resize images if too large

Browse files
Files changed (1) hide show
  1. llm.py +63 -20
llm.py CHANGED
@@ -10,6 +10,9 @@ import io
10
  import boto3
11
  from botocore.config import Config
12
  import re
 
 
 
13
 
14
  # constants
15
  log_to_console = False
@@ -140,31 +143,71 @@ class LLM:
140
  return message_parts
141
 
142
  def _encode_image(self, image_data):
143
- # Get the first few bytes of the image data.
144
- magic_number = image_data[:4]
145
-
146
- # Check the magic number to determine the image type.
147
- if magic_number.startswith(b'\x89PNG'):
148
- image_type = 'png'
149
- elif magic_number.startswith(b'\xFF\xD8'):
150
- image_type = 'jpeg'
151
- elif magic_number.startswith(b'GIF89a'):
152
- image_type = 'gif'
153
- elif magic_number.startswith(b'RIFF'):
154
- if image_data[8:12] == b'WEBP':
155
- image_type = 'webp'
156
- else:
157
- # Unknown image type.
158
- raise Exception("Unknown image type")
159
- else:
160
- # Unknown image type.
161
  raise Exception("Unknown image type")
 
 
 
 
162
 
163
- return {
164
- "format": image_type,
 
 
 
 
 
165
  "source": {"bytes": image_data}
166
  }
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def read_response(self, response_stream):
169
  for event in response_stream:
170
  if 'contentBlockDelta' in event:
 
10
  import boto3
11
  from botocore.config import Config
12
  import re
13
+ from PIL import Image
14
+ import io
15
+ import math
16
 
17
  # constants
18
  log_to_console = False
 
143
  return message_parts
144
 
145
  def _encode_image(self, image_data):
146
+ try:
147
+ # Open the image using Pillow
148
+ img = Image.open(io.BytesIO(image_data))
149
+ original_format = img.format.lower()
150
+ except IOError:
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  raise Exception("Unknown image type")
152
+
153
+ # check if within the limits for Claude as per https://docs.anthropic.com/en/docs/build-with-claude/vision
154
+ def calculate_tokens(width, height):
155
+ return (width * height) / 750
156
 
157
+ tokens = calculate_tokens(img.width, img.height)
158
+ long_edge = max(img.width, img.height)
159
+
160
+ # Check if the image already meets all requirements
161
+ if long_edge <= 1568 and tokens <= 1600 and len(image_data) <= 5 * 1024 * 1024:
162
+ return {
163
+ "format": original_format,
164
  "source": {"bytes": image_data}
165
  }
166
 
167
+ # If we need to modify the image, proceed with resizing and/or compression
168
+ while long_edge > 1568 or tokens > 1600:
169
+ if long_edge > 1568:
170
+ scale_factor = max(1568 / long_edge, 0.9)
171
+ else:
172
+ scale_factor = max(math.sqrt(1600 / tokens), 0.9)
173
+
174
+ new_width = int(img.width * scale_factor)
175
+ new_height = int(img.height * scale_factor)
176
+
177
+ img = img.resize((new_width, new_height), Image.LANCZOS)
178
+
179
+ long_edge = max(new_width, new_height)
180
+ tokens = calculate_tokens(new_width, new_height)
181
+
182
+ # Try to save in original format first
183
+ buffer = io.BytesIO()
184
+ img.save(buffer, format=original_format, quality=95)
185
+ image_data = buffer.getvalue()
186
+
187
+ # If the image is still too large, switch to WebP and compress
188
+ if len(image_data) > 5 * 1024 * 1024:
189
+ format_to_use = "webp"
190
+ quality = 95
191
+ while len(image_data) > 5 * 1024 * 1024:
192
+ quality = max(int(quality * 0.9), 20)
193
+ buffer = io.BytesIO()
194
+ img.save(buffer, format=format_to_use, quality=quality)
195
+ image_data = buffer.getvalue()
196
+ if quality == 20:
197
+ # If we've reached quality 20 and it's still too large, resize
198
+ scale_factor = 0.9
199
+ new_width = int(img.width * scale_factor)
200
+ new_height = int(img.height * scale_factor)
201
+ img = img.resize((new_width, new_height), Image.LANCZOS)
202
+ quality = 95 # Reset quality for the resized image
203
+ else:
204
+ format_to_use = original_format
205
+
206
+ return {
207
+ "format": format_to_use,
208
+ "source": {"bytes": image_data}
209
+ }
210
+
211
  def read_response(self, response_stream):
212
  for event in response_stream:
213
  if 'contentBlockDelta' in event: