---
license: apache-2.0
pipeline_tag: image-classification
tags:
- aesthetic
---
|
|
|
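# Input images must have 3 `RGB` channels

Images with an alpha channel (`RGBA`) or in any other non-RGB mode will not work as-is; convert them first, e.g. `Image.open(path).convert("RGB")`.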
|
|
|
## Usage

```python
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
|
|
# Run on a CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
class CNN(nn.Module):
    """Small two-stage convolutional network producing 2 class logits.

    Layout: conv(3->16, 3x3) + ReLU + 2x2 max-pool, conv(16->32, 3x3) + ReLU
    + 2x2 max-pool, then a hidden fully connected layer and a 2-way output
    layer. The flattened feature size ``32 * 192 * 192`` assumes 768x768
    inputs: each stride-2 pool halves height and width (768 -> 384 -> 192).

    Layer attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``) must stay
    unchanged so saved ``state_dict`` checkpoints keep loading.

    Args:
        hidden_size: Width of the hidden fully connected layer ``fc1``.
    """

    def __init__(self, hidden_size=512):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # 32 channels from conv2 at 192x192 after two stride-2 pools of 768x768.
        self.fc1 = nn.Linear(32 * 192 * 192, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 2)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch; the size of ``fc1`` assumes shape
                ``(batch, 3, 768, 768)``.

        Returns:
            Tensor of shape ``(batch, 2)`` with unnormalized class scores.

        Raises:
            RuntimeError: If the per-sample feature size does not match
                ``fc1`` (e.g. wrong spatial input size).
        """
        x = nn.functional.max_pool2d(torch.relu(self.conv1(x)), kernel_size=2, stride=2)
        x = nn.functional.max_pool2d(torch.relu(self.conv2(x)), kernel_size=2, stride=2)
        # Flatten per sample, keeping the batch dimension. The former
        # ``view(-1, 32 * 192 * 192)`` could silently merge mis-sized inputs
        # across samples; here a wrong size fails loudly in ``fc1``.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
|
|
|
# Build the classifier on the selected device, then cast it to float16.
# NOTE(review): float16 convolutions may be unsupported or slow on CPU in some
# PyTorch versions — confirm behavior when CUDA is unavailable.
model = CNN()
model = model.to(device)
model = model.half()

# Loss function; not referenced by the inference code in this snippet.
criterion = nn.CrossEntropyLoss()

# Optimizer whose saved state is restored from the checkpoint; inference
# itself never steps it.
optimizer = optim.Adam(params=model.parameters(), lr=2.5e-5)
|
|
|
# Preprocessing pipeline: resize to the fixed 768x768 size that the model's
# fc1 layer is sized for, convert to a float tensor, then shift/scale each of
# the 3 channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((768, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
|
|
|
def infer(model, image_path):
    """Classify a single image file with ``model``.

    The image is converted to RGB (any alpha channel is discarded), so RGBA,
    palette and grayscale files are accepted. It is then preprocessed with
    the module-level ``transform`` and moved to the device and dtype of the
    model's parameters.

    Args:
        model: The classifier; it is switched to evaluation mode as a side
            effect.
        image_path: Path to the image, as accepted by ``PIL.Image.open``.

    Returns:
        int: Index of the highest-scoring class (this card maps 0 to a bad
        image and 1 to a good image).
    """
    model.eval()
    # Close the file promptly; ``convert`` returns a fully loaded copy, so the
    # pixels remain usable after the ``with`` block exits.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # Follow the model's own placement and precision (float16 here) instead of
    # assuming them.
    param = next(model.parameters())
    batch = transform(rgb).unsqueeze(0).to(device=param.device, dtype=param.dtype)
    with torch.no_grad():
        logits = model(batch)
    # logits has shape (1, num_classes); pick the class per sample.
    return logits.argmax(dim=1).item()
|
|
|
# Restore the trained state. ``map_location`` remaps tensors that were saved
# on a GPU onto ``device``, so the checkpoint also loads on CPU-only machines.
# NOTE(review): ``torch.load`` unpickles the file; only load checkpoints from
# trusted sources (consider ``weights_only=True`` where your PyTorch supports it).
checkpoint = torch.load('half_precision_model_checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Optimizer state is only needed to resume training; inference ignores it.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Predicted class index -> human-readable label.
class_names = {0: 'Bad Image', 1: 'Good Image'}

image_path = 'good.jpg'
predicted_class = infer(model, image_path)
label = class_names.get(int(predicted_class))
if label is not None:
    print(f'Predicted class: {label}')
|
```