File size: 4,780 Bytes
7cc2be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1f96ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cc2be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1f96ef
7cc2be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "3xnrF3UB6ev0"
      ],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Model Inference"
      ],
      "metadata": {
        "id": "33C47swS80_1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Install Dependencies\n",
        "# Use %pip (not !pip) so the install targets the active kernel's environment\n",
        "%pip install -q transformers"
      ],
      "metadata": {
        "cellView": "form",
        "id": "noaoheUjvGbd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title [Run this if using Nvidia Ampere or newer. This will significantly speed up the process]\n",
        "import torch\n",
        "# Enable TensorFloat-32 kernels for both matmul and cuDNN backends\n",
        "for tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):\n",
        "    tf32_backend.allow_tf32 = True"
      ],
      "metadata": {
        "cellView": "form",
        "id": "MkGgqW87eUsQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "NZLqjuWEtCDy"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "import os\n",
        "import shutil\n",
        "\n",
        "import torch\n",
        "from PIL import Image\n",
        "from transformers import pipeline\n",
        "\n",
        "# Fall back to CPU (device=-1) when no CUDA device is available,\n",
        "# so this cell also runs on CPU-only runtimes instead of crashing.\n",
        "device = 0 if torch.cuda.is_available() else -1\n",
        "pipe = pipeline(\"image-classification\", model=\"shadowlilac/aesthetic-shadow\", device=device)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Inference\n",
        "\n",
        "# Input image file\n",
        "single_image_file = \"image_1.png\" #@param {type:\"string\"}\n",
        "\n",
        "result = pipe(images=[single_image_file])\n",
        "\n",
        "prediction_single = result[0]\n",
        "# The pipeline returns probabilities in [0, 1]; convert to a percentage\n",
        "# before printing (previously 0.87 was shown as \"0.87% High Quality\").\n",
        "hq_score = [p for p in prediction_single if p['label'] == 'hq'][0]['score']\n",
        "print(\"Prediction: \" + str(round(hq_score * 100, 2)) + \"% High Quality\")\n",
        "Image.open(single_image_file)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "r1R-L2r-0uo2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Batch Mode"
      ],
      "metadata": {
        "id": "3xnrF3UB6ev0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Batch parameters\n",
        "# Paths for the input folder and the two output folders\n",
        "input_folder = \"input_folder\" #@param {type:\"string\"}\n",
        "output_folder_hq = \"output_hq_folder\" #@param {type:\"string\"}\n",
        "output_folder_lq = \"output_lq_folder\" #@param {type:\"string\"}\n",
        "# Threshold: images with an 'hq' score at or above this go to the HQ folder\n",
        "batch_hq_threshold = 0.5 #@param {type:\"number\"}\n",
        "# Number of images classified per pipeline call\n",
        "batch_size = 8 #@param {type:\"number\"}"
      ],
      "metadata": {
        "cellView": "form",
        "id": "VlPgrJf4wpHo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Execute Batch Job\n",
        "\n",
        "# Ensure the destination folders exist before copying into them;\n",
        "# shutil.copy raises FileNotFoundError otherwise.\n",
        "os.makedirs(output_folder_hq, exist_ok=True)\n",
        "os.makedirs(output_folder_lq, exist_ok=True)\n",
        "\n",
        "# List all image files in the input folder\n",
        "image_files = [os.path.join(input_folder, f) for f in os.listdir(input_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
        "\n",
        "# Process images in batches\n",
        "for i in range(0, len(image_files), batch_size):\n",
        "    batch = image_files[i:i + batch_size]\n",
        "\n",
        "    # Perform classification for the batch\n",
        "    results = pipe(images=batch)\n",
        "\n",
        "    for idx, result in enumerate(results):\n",
        "        # Extract the 'hq' probability for this image\n",
        "        hq_score = [p for p in result if p['label'] == 'hq'][0]['score']\n",
        "\n",
        "        # Determine the destination folder based on the prediction and threshold\n",
        "        destination_folder = output_folder_hq if hq_score >= batch_hq_threshold else output_folder_lq\n",
        "\n",
        "        # Copy the image to the appropriate folder\n",
        "        shutil.copy(batch[idx], os.path.join(destination_folder, os.path.basename(batch[idx])))\n",
        "\n",
        "print(\"Classification and sorting complete.\")"
      ],
      "metadata": {
        "cellView": "form",
        "id": "RG01mcYf4DvK"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}