import torch
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import open_clip
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gradio as gr
# Device setup
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
# Data transformation
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
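# The mean/std values above are the standard ImageNet statistics that the
# pretrained ResNet18 backbone below expects.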
# Load datasets for enriched prompts
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';') # CSV containing style-specific descriptions
style_desc.columns = style_desc.columns.str.lower()
# Function to enrich prompts with custom data
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'] == artist, 'description'].values
    style_info = style_desc.loc[style_desc['style'] == style, 'description'].values
    artist_details = artist_info[0] if len(artist_info) > 0 else "Details about the artist are not available."
    style_details = style_info[0] if len(style_info) > 0 else "Details about the style are not available."
    return f"{artist_details} This work exemplifies {style_details}."
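# Example with hypothetical CSV contents:
#   enrich_prompt("Claude Monet", "Impressionism")
#   -> "French painter and a founder of Impressionism... This work exemplifies
#       a style characterized by visible brushstrokes and shifting light."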
# Custom dataset for ResNet18
class ArtDataset:
    def __init__(self, csv_file):
        self.annotations = pd.read_csv(csv_file)
        self.train_data = self.annotations[self.annotations['subset'] == 'train']
        self.test_data = self.annotations[self.annotations['subset'] == 'test']
        # The 'genre' column of the CSV holds the style labels
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def get_style_and_artist_mappings(self):
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        return self.train_data, self.test_data
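# Sketch of how the mappings are consumed (assumes 'genre', 'artist', and
# 'subset' columns in the CSV):
#   styles, artists = dataset.get_style_and_artist_mappings()
#   styles["Impressionism"]  # -> integer class index used by the style head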
# DualOutputResNet model with Dropout
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # expose pooled features to the two heads
        self.dropout = nn.Dropout(dropout_rate)
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output
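# Shape sanity check (sketch): each head returns one logit row per image.
#   m = DualOutputResNet(num_styles=10, num_artists=20)
#   s, a = m(torch.randn(1, 3, 224, 224))
#   assert s.shape == (1, 10) and a.shape == (1, 20)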
# Load dataset
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
train_data, test_data = dataset.get_train_test_split()
num_styles = len(label_map_style)
num_artists = len(label_map_artist)
# Model setup
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
# Optimizer and LR scheduler (only needed if the model is fine-tuned; the
# Gradio app below performs inference only)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
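# NOTE: no trained checkpoint is loaded, so both classifier heads start with
# random weights. If a fine-tuned state dict exists (hypothetical filename):
#   model_resnet.load_state_dict(torch.load("dual_resnet18.pth", map_location=device))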
# Load GPT-Neo and CLIP
# open_clip returns (model, train_preprocess, eval_preprocess); keep the eval
# transform. Pretrained weights are assumed to be the OpenAI release.
model_clip, _, preprocess_clip = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', device=device)
tokenizer_clip = open_clip.get_tokenizer('ViT-B-32')
model_clip.eval()  # CLIP is loaded here but not used by predict() below
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
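# GPT-Neo 1.3B is roughly 5 GB in float32; on constrained hardware,
# "EleutherAI/gpt-neo-125m" is a drop-in replacement for quick tests.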
# Generate prediction using ResNet and CLIP
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image_tensor = data_transforms(image).unsqueeze(0).to(device)

    # Predict style and artist with the ResNet heads (inference mode)
    model_resnet.eval()
    with torch.no_grad():
        style_logits, artist_logits = model_resnet(image_tensor)
    style_idx = torch.argmax(style_logits, dim=1).item()
    artist_idx = torch.argmax(artist_logits, dim=1).item()

    # Invert the label maps to recover the predicted names
    idx_to_style = {idx: style for style, idx in label_map_style.items()}
    idx_to_artist = {idx: artist for artist, idx in label_map_artist.items()}
    predicted_style = idx_to_style[style_idx]
    predicted_artist = idx_to_artist[artist_idx]

    # Enrich the prompt with artist- and style-specific descriptions
    prompt = enrich_prompt(predicted_artist, predicted_style)

    # Generate a text description using GPT-Neo
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(**inputs, max_length=350, num_return_sequences=1)
    description = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_style, predicted_artist, description
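# Example call (hypothetical path):
#   style, artist, text = predict("samples/starry_night.jpg")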
# Gradio interface
def gradio_interface(image):
    predicted_style, predicted_artist, description = predict(image)
    return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description."
)
if __name__ == "__main__":
    iface.launch()