Aesthetic-Anime-Art / README.md
starsnatched's picture
Update README.md
2e5d567 verified
|
raw
history blame
2.05 kB
---
license: apache-2.0
pipeline_tag: image-classification
tags:
- aesthetic
---
# Note: the input image must have 3 `RGB` channels — `RGBA` (or other-mode) images will not work unless converted first, e.g. `Image.open(path).convert("RGB")`.
## Usage
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class CNN(nn.Module):
    """Small two-convolution binary image classifier.

    Pipeline: conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU
    -> 2x2 max-pool -> flatten -> fc(hidden_size) -> ReLU -> fc(2).

    The 3x3 convolutions use padding=1 and keep the spatial size. The two
    2x2 pools divide it by 4, so a 768x768 input becomes a 32x192x192 feature
    map. That is why ``fc1`` takes ``32 * 192 * 192`` features. Inputs whose
    pooled feature map has a different size are rejected by ``fc1`` with a
    shape error.

    Args:
        hidden_size: Width of the hidden fully connected layer.
    """

    def __init__(self, hidden_size=512):
        super().__init__()
        # Attribute names are kept stable so saved state_dicts still load.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 192 * 192, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 2)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, 2)``.

        ``x`` is expected to be ``(batch, 3, H, W)``, where H and W reduce to
        192 after two 2x2 poolings (e.g. 768x768, as produced by the README's
        transform).
        """
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        # Flatten each sample separately and keep the batch dimension.
        # With view(-1, N), a wrongly sized batch could be silently regrouped
        # into rows that mix several samples. Here fc1 raises instead.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Preprocessing applied to every input image:
#   1. resize to 768x768 (the spatial size the CNN's fc1 is sized for),
#   2. convert to a float tensor with values in [0, 1],
#   3. normalize each of the 3 channels with mean 0.5 / std 0.5 -> [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(768, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# The classifier, moved to the selected device and cast to float16.
model = CNN().to(device).half()

# Loss and optimizer. In this snippet the optimizer is only used to restore
# the optimizer state stored in the checkpoint; the loss is not used for
# inference.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2.5e-5)
def infer(model, image_path):
    """Classify one image file and return the predicted class index.

    The image is converted to RGB before preprocessing. RGBA, grayscale and
    palette images are therefore accepted; the network's first convolution
    expects exactly 3 channels.

    Args:
        model: The classifier (e.g. ``CNN``). It must have at least one
            parameter, because its device/dtype are read from them. It should
            return logits of shape ``(1, num_classes)`` for one image.
        image_path: Path to the image file to classify.

    Returns:
        int: Index of the highest-scoring class (in this README, 0 = bad
        image, 1 = good image).
    """
    model.eval()
    # Feed the input on the model's own device and in its own dtype. This
    # matches a half-precision GPU model and also a float32 CPU model.
    param = next(model.parameters())
    # Use a context manager so the underlying file handle is closed.
    with Image.open(image_path) as img:
        tensor = transform(img.convert("RGB"))
    batch = tensor.unsqueeze(0).to(device=param.device, dtype=param.dtype)
    with torch.no_grad():
        logits = model(batch)
    # argmax over the class dimension; batch size is 1, so .item() is safe.
    return torch.argmax(logits, dim=1).item()
# Restore the trained weights. ``map_location`` remaps saved tensors onto the
# selected device, so a checkpoint saved on a GPU also loads on a CPU-only
# machine.
# SECURITY: torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load('half_precision_model_checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Presumably training metadata; not used by the inference below.
epoch = checkpoint['epoch']
loss = checkpoint['loss']

image_path = 'good.jpg'
predicted_class = infer(model, image_path)

# Map the class index to a human-readable label. Unknown indices print
# nothing, as in the original if/elif chain.
CLASS_NAMES = {0: 'Bad Image', 1: 'Good Image'}
label = CLASS_NAMES.get(int(predicted_class))
if label is not None:
    print(f'Predicted class: {label}')
```