Update app.py
Browse files
app.py
CHANGED
@@ -9,9 +9,9 @@ import logging
|
|
9 |
logging.basicConfig(level=logging.DEBUG)
|
10 |
|
11 |
# Load the pre-trained model and feature extractor
|
12 |
-
model_name = "JoshuaKelleyDs/quickdraw-
|
13 |
logging.info("Loading image processor and model...")
|
14 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name
|
15 |
model = AutoModelForImageClassification.from_pretrained(model_name)
|
16 |
|
17 |
# Define the prediction function
|
@@ -27,13 +27,8 @@ def predict(image):
|
|
27 |
logging.debug("Converting to NumPy array...")
|
28 |
image = np.array(image).astype('uint8')
|
29 |
logging.debug("Converting NumPy array to PIL image...")
|
30 |
-
image = Image.fromarray(image, 'RGBA').convert('
|
31 |
logging.debug("Image converted successfully.")
|
32 |
-
|
33 |
-
# Ensure the image has the correct number of dimensions
|
34 |
-
if len(image.size) == 2:
|
35 |
-
image = np.expand_dims(image, axis=-1)
|
36 |
-
logging.debug("Added dimension to image to match model input requirements.")
|
37 |
|
38 |
logging.info("Processing image...")
|
39 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
|
|
9 |
logging.basicConfig(level=logging.DEBUG)
|
10 |
|
11 |
# Load the pre-trained model and feature extractor
|
12 |
+
model_name = "JoshuaKelleyDs/quickdraw-MobileVITV2-2.0-Finetune"
|
13 |
logging.info("Loading image processor and model...")
|
14 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
15 |
model = AutoModelForImageClassification.from_pretrained(model_name)
|
16 |
|
17 |
# Define the prediction function
|
|
|
27 |
logging.debug("Converting to NumPy array...")
|
28 |
image = np.array(image).astype('uint8')
|
29 |
logging.debug("Converting NumPy array to PIL image...")
|
30 |
+
image = Image.fromarray(image, 'RGBA').convert('RGB')
|
31 |
logging.debug("Image converted successfully.")
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
logging.info("Processing image...")
|
34 |
inputs = feature_extractor(images=image, return_tensors="pt")
|