"""Score an image with a fine-tuned ViT-Large (patch 32, 384px) classifier.

The script restores fine-tuned weights into a ViT backbone with a linear
head. It letterboxes ``test.png`` to 384x384 and prints the softmax
probability of each of the two classes. The script prints class 0 as "AI"
and class 1 as "Human".
"""

import numpy
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import ViTFeatureExtractor, ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch32-384')


class ViTForImageClassification(nn.Module):
    """ViT backbone followed by dropout and a linear head on the first token.

    Args:
        num_labels: Number of output classes (default 2).
    """

    def __init__(self, num_labels=2):
        super().__init__()
        # Attribute names (vit / dropout / classifier) must stay unchanged:
        # they are the state_dict keys of saved checkpoints.
        self.vit = ViTModel.from_pretrained('google/vit-large-patch32-384')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Run the classifier.

        Args:
            pixel_values: Batch of normalized images as produced by the
                feature extractor.
            labels: Optional integer class targets. When given, a
                cross-entropy loss is computed.

        Returns:
            A ``SequenceClassifierOutput`` (with ``loss`` and ``logits``) when
            ``labels`` is given. Otherwise the raw logits tensor of shape
            ``(batch, num_labels)``.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the first token of the sequence (ViT's [CLS] token).
        output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(output)
        if labels is None:
            return logits
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(loss=loss, logits=logits)


def preprocess_image(image, desired_size=384):
    """Letterbox ``image`` into a white ``desired_size`` x ``desired_size`` RGB square.

    The image is scaled so that its longer side equals ``desired_size``,
    preserving aspect ratio. It is then centered on a white canvas.

    Args:
        image: A PIL image.
        desired_size: Side length of the square output, in pixels.

    Returns:
        A new RGB PIL image of size ``(desired_size, desired_size)``.
    """
    old_size = image.size
    ratio = float(desired_size) / max(old_size)
    # Clamp each side to at least 1 px. For extremely elongated inputs,
    # int(x * ratio) would otherwise be 0, which PIL cannot resize to.
    new_size = tuple(max(1, int(x * ratio)) for x in old_size)
    resized = image.resize(new_size)

    canvas = Image.new("RGB", (desired_size, desired_size), "white")
    offset = ((desired_size - new_size[0]) // 2, (desired_size - new_size[1]) // 2)
    # paste() converts the resized image to the canvas mode (RGB). An alpha
    # channel is dropped rather than composited onto the white background.
    canvas.paste(resized, offset)
    return canvas


def predict_image(image, model, feature_extractor):
    """Return softmax class probabilities for a single PIL image.

    Args:
        image: PIL image, typically already letterboxed by ``preprocess_image``.
        model: A ``ViTForImageClassification`` instance. Inputs are moved to
            the device its parameters live on.
        feature_extractor: The ViT feature extractor used to normalize pixels.

    Returns:
        A numpy array of shape ``(1, num_labels)`` with class probabilities.
    """
    # Ensure model is in eval mode (disables dropout).
    model.eval()

    # NOTE(review): ToTensor() already scales pixels to [0, 1]. Depending on
    # the installed transformers version, the feature extractor may rescale by
    # 1/255 again. Predictions are only meaningful if this matches the
    # preprocessing used when the checkpoint was trained. Confirm before
    # changing it.
    input_tensor = ToTensor()(image)
    pixel_values = torch.tensor(numpy.array(feature_extractor(input_tensor)['pixel_values']))

    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    pixel_values = pixel_values.to(device)

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(pixel_values)

    probabilities = torch.nn.functional.softmax(logits, dim=1)
    return probabilities.cpu().numpy()


if __name__ == "__main__":
    model = ViTForImageClassification(num_labels=2)
    # map_location="cpu" loads the checkpoint whichever device it was saved
    # from; the model is moved to the GPU right after.
    # NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
    # checkpoints, and consider weights_only=True on torch versions that
    # support it.
    model.load_state_dict(torch.load("./AID96k_E15_384.pth", map_location="cpu"))
    model.cuda()
    model.eval()

    # Close the source file once the letterboxed copy has been made.
    with Image.open("test.png") as raw_img:
        img = preprocess_image(raw_img)

    probs = predict_image(img, model, feature_extractor)
    print(f"AI: {probs[0][0]}")
    print(f"Human: {probs[0][1]}")