starsnatched commited on
Commit
2e5d567
1 Parent(s): 322007d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +66 -0
README.md CHANGED
@@ -1,3 +1,69 @@
1
  ---
2
  license: apache-2.0
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ pipeline_tag: image-classification
4
+ tags:
5
+ - aesthetic
6
  ---
7
+
8
+ # THE INPUT IMAGE MUST HAVE `RGB` CHANNELS. IT WILL NOT WORK WITH `RGBA` CHANNELS!
9
+
10
+ ## Usage
11
+ ```python
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ import torchvision.transforms as transforms
16
+ from PIL import Image
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ class CNN(nn.Module):
21
+ def __init__(self, hidden_size=512):
22
+ super(CNN, self).__init__()
23
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
24
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
25
+ self.fc1 = nn.Linear(32 * 192 * 192, hidden_size)
26
+ self.fc2 = nn.Linear(hidden_size, 2)
27
+
28
+ def forward(self, x):
29
+ x = torch.relu(self.conv1(x))
30
+ x = torch.max_pool2d(x, kernel_size=2, stride=2)
31
+ x = torch.relu(self.conv2(x))
32
+ x = torch.max_pool2d(x, kernel_size=2, stride=2)
33
+ x = x.view(-1, 32 * 192 * 192)
34
+ x = torch.relu(self.fc1(x))
35
+ x = self.fc2(x)
36
+ return x
37
+
38
+ model = CNN().to(device).half()
39
+ criterion = nn.CrossEntropyLoss()
40
+ optimizer = optim.Adam(model.parameters(), lr=2.5e-5)
41
+
42
+ transform = transforms.Compose([
43
+ transforms.Resize((768, 768)),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
46
+ ])
47
+
48
+ def infer(model, image_path):
49
+ model.eval()
50
+ image = Image.open(image_path)
51
+ image = transform(image).unsqueeze(0).to(device).half()
52
+ with torch.no_grad():
53
+ output = model(image)
54
+ predicted_class = torch.argmax(output).item()
55
+ return predicted_class
56
+
57
+ checkpoint = torch.load('half_precision_model_checkpoint.pth')
58
+ model.load_state_dict(checkpoint['model_state_dict'])
59
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
60
+ epoch = checkpoint['epoch']
61
+ loss = checkpoint['loss']
62
+
63
+ image_path = 'good.jpg'
64
+ predicted_class = infer(model, image_path)
65
+ if int(predicted_class) == 0:
66
+ print('Predicted class: Bad Image')
67
+ elif int(predicted_class) == 1:
68
+ print('Predicted class: Good Image')
69
+ ```