File size: 4,327 Bytes
7cc2be8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"3xnrF3UB6ev0"
],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Model Inference"
],
"metadata": {
"id": "33C47swS80_1"
}
},
{
"cell_type": "code",
"source": [
"#@title Install Dependencies\n",
"# Use the %pip magic (not !pip) so the package is installed into the environment of the running kernel\n",
"%pip install -q transformers"
],
"metadata": {
"cellView": "form",
"id": "noaoheUjvGbd"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "NZLqjuWEtCDy"
},
"outputs": [],
"source": [
"#@title Imports\n",
"import os\n",
"from transformers import pipeline\n",
"import shutil\n",
"from PIL import Image\n",
"import torch\n",
"\n",
"# Run on the first GPU when CUDA is available; otherwise fall back to the CPU (device=-1)\n",
"# so the notebook still works on a runtime without an accelerator.\n",
"device = 0 if torch.cuda.is_available() else -1\n",
"\n",
"# Image-classification pipeline shared by the single-image and batch cells below\n",
"pipe = pipeline(\"image-classification\", model=\"shadowlilac/aesthetic-shadow\", device=device)"
]
},
{
"cell_type": "code",
"source": [
"#@title Inference\n",
"# Classify a single image, print its 'hq' score as a percentage, and display the image.\n",
"\n",
"# Input image file\n",
"single_image_file = \"image_1.png\" #@param {type:\"string\"}\n",
"\n",
"# The pipeline is called with a one-element list, so the first entry holds this image's predictions\n",
"result = pipe(images=[single_image_file])\n",
"\n",
"prediction_single = result[0]\n",
"hq_score = next(p['score'] for p in prediction_single if p['label'] == 'hq')\n",
"\n",
"# The score is a fraction (the batch cell compares it against a threshold such as 0.5),\n",
"# so scale it by 100 to make the number agree with the '%' sign.\n",
"print(f\"Prediction: {hq_score * 100:.2f}% High Quality\")\n",
"Image.open(single_image_file)"
],
"metadata": {
"cellView": "form",
"id": "r1R-L2r-0uo2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Batch Mode"
],
"metadata": {
"id": "3xnrF3UB6ev0"
}
},
{
"cell_type": "code",
"source": [
"#@title Batch parameters\n",
"# Define the paths for the input folder and output folders\n",
"input_folder = \"input_folder\" #@param {type:\"string\"}\n",
"output_folder_hq = \"output_hq_folder\" #@param {type:\"string\"}\n",
"output_folder_lq = \"output_lq_folder\" #@param {type:\"string\"}\n",
"# Threshold: images whose 'hq' score is >= this value are copied to the HQ folder, the rest to the LQ folder\n",
"batch_hq_threshold = 0.5 #@param {type:\"number\"}\n",
"# Define the batch size (must be a whole number: it is used as a range() step and a slice bound)\n",
"batch_size = 8 #@param {type:\"integer\"}"
],
"metadata": {
"cellView": "form",
"id": "VlPgrJf4wpHo"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Execute Batch Job\n",
"# Classify every image in `input_folder` and copy it into the HQ or LQ output folder\n",
"# depending on how its 'hq' score compares with `batch_hq_threshold`.\n",
"\n",
"# Create the destination folders up front; shutil.copy fails if the target directory is missing\n",
"os.makedirs(output_folder_hq, exist_ok=True)\n",
"os.makedirs(output_folder_lq, exist_ok=True)\n",
"\n",
"# List all image files in the input folder (sorted so the processing order is deterministic)\n",
"image_files = sorted(\n",
"    os.path.join(input_folder, f)\n",
"    for f in os.listdir(input_folder)\n",
"    if f.lower().endswith(('.png', '.jpg', '.jpeg'))\n",
")\n",
"\n",
"# Process images in batches\n",
"for i in range(0, len(image_files), batch_size):\n",
"    batch = image_files[i:i + batch_size]\n",
"\n",
"    # Perform classification for the batch; results line up one-to-one with the paths in `batch`\n",
"    results = pipe(images=batch)\n",
"\n",
"    for image_path, predictions in zip(batch, results):\n",
"        # Extract the score attached to the 'hq' label\n",
"        hq_score = next(p['score'] for p in predictions if p['label'] == 'hq')\n",
"\n",
"        # Determine the destination folder based on the prediction and threshold\n",
"        destination_folder = output_folder_hq if hq_score >= batch_hq_threshold else output_folder_lq\n",
"\n",
"        # Copy (not move) the image, so the input folder is left untouched\n",
"        shutil.copy(image_path, os.path.join(destination_folder, os.path.basename(image_path)))\n",
"\n",
"print(\"Classification and sorting complete.\")"
],
"metadata": {
"cellView": "form",
"id": "RG01mcYf4DvK"
},
"execution_count": null,
"outputs": []
}
]
} |