import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import open_clip
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gradio as gr
# Device setup: prefer Apple Metal (MPS) when available, else fall back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
# Data transforms: resize to 224x224 and normalize with ImageNet statistics,
# matching the pretrained ResNet18 backbone
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Load CSVs used to enrich the GPT-Neo prompts
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';')  # per-style descriptions
style_desc.columns = style_desc.columns.str.lower()
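# Expected layout, inferred from the reads above (example rows are illustrative only):
#   dataset_desc.csv -> "Artists;Style;Description"  e.g. "Claude Monet;Impressionism;French painter ..."
#   style_desc.csv   -> "Style;Description"          e.g. "Impressionism;A 19th-century movement ..."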
# Enrich the GPT-Neo prompt with artist- and style-specific descriptions
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'] == artist, 'description'].values
    style_info = style_desc.loc[style_desc['style'] == style, 'description'].values
    artist_details = artist_info[0] if len(artist_info) > 0 else "Details about the artist are not available."
    style_details = style_info[0] if len(style_info) > 0 else "Details about the style are not available."
    return f"{artist_details} This work exemplifies {style_details}."
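# Example (illustrative, given matching rows in both CSVs):
#   enrich_prompt("Claude Monet", "Impressionism")
#   -> "<artist description> This work exemplifies <style description>."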
# Metadata helper: builds label maps and the train/test split from the annotations CSV.
# Note: this is not a torch.utils.data.Dataset; see the sketch below for wrapping the
# splits into one.
class ArtDataset:
    def __init__(self, csv_file):
        self.annotations = pd.read_csv(csv_file)
        self.train_data = self.annotations[self.annotations['subset'] == 'train']
        self.test_data = self.annotations[self.annotations['subset'] == 'test']
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def get_style_and_artist_mappings(self):
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        return self.train_data, self.test_data
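# Minimal sketch of a torch Dataset wrapper for the splits above, assuming the
# annotations CSV has an image-path column (hypothetically named 'path' here);
# adjust the column name to the real schema.
class ArtImageDataset(torch.utils.data.Dataset):
    def __init__(self, frame, style_map, artist_map, transform):
        self.frame = frame.reset_index(drop=True)
        self.style_map = style_map
        self.artist_map = artist_map
        self.transform = transform

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        row = self.frame.iloc[idx]
        image = Image.open(row['path']).convert("RGB")  # 'path' column is an assumption
        return (self.transform(image),
                self.style_map[row['genre']],
                self.artist_map[row['artist']])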
# DualOutputResNet: shared ResNet18 backbone with dropout and two classification heads
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        super(DualOutputResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # strip the ImageNet classifier, keep features
        self.dropout = nn.Dropout(dropout_rate)
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output
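# Quick shape check (illustrative):
#   m = DualOutputResNet(num_styles=10, num_artists=50)
#   s, a = m(torch.randn(1, 3, 224, 224))  # s: (1, 10), a: (1, 50)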
# Load dataset metadata
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
train_data, test_data = dataset.get_train_test_split()
num_styles = len(label_map_style)
num_artists = len(label_map_artist)

# Model setup (the optimizer and scheduler are only needed if training runs in this script)
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
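# No training loop runs in this script, so the two heads would stay randomly
# initialized at inference time. Assuming a fine-tuned checkpoint exists
# (hypothetical filename), it would be loaded like this:
# model_resnet.load_state_dict(torch.load("dual_resnet18.pth", map_location=device))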
# Load CLIP. open_clip names this model 'ViT-B-32' (hyphens, not slashes) and
# create_model_and_transforms returns (model, train_preprocess, val_preprocess);
# without a `pretrained` tag the weights would be randomly initialized.
model_clip, _, preprocess_clip = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', device=device)
tokenizer_clip = open_clip.get_tokenizer('ViT-B-32')
model_clip.eval()
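# CLIP is loaded but never called in predict() below. A minimal sketch of one
# plausible use — zero-shot scoring of a PIL image against all known style names:
def clip_style_scores(image):
    image_input = preprocess_clip(image).unsqueeze(0).to(device)
    text_input = tokenizer_clip([f"a painting in the style of {s}" for s in label_map_style]).to(device)
    with torch.no_grad():
        image_features = model_clip.encode_image(image_input)
        text_features = model_clip.encode_text(text_input)
        # Cosine similarity via normalized dot products, softmaxed into pseudo-probabilities
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    return dict(zip(label_map_style, probs[0].tolist()))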
# Load GPT-Neo for description generation
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Predict style and artist with the ResNet, then generate a description with GPT-Neo
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image_tensor = data_transforms(image).unsqueeze(0).to(device)

    # ResNet prediction (inference mode: dropout off, no gradient tracking)
    model_resnet.eval()
    with torch.no_grad():
        style_logits, artist_logits = model_resnet(image_tensor)
    style_idx = torch.argmax(style_logits, dim=1).item()
    artist_idx = torch.argmax(artist_logits, dim=1).item()

    # Invert the label maps to recover the class names
    predicted_style = {v: k for k, v in label_map_style.items()}[style_idx]
    predicted_artist = {v: k for k, v in label_map_artist.items()}[artist_idx]

    # Enrich the prompt with artist/style descriptions from the CSVs
    prompt = enrich_prompt(predicted_artist, predicted_style)

    # Generate a text description with GPT-Neo
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = model_gptneo.generate(input_ids, max_length=350, num_return_sequences=1)
    description = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_style, predicted_artist, description
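# Example (hypothetical file and outputs): predict("some_painting.jpg")
# -> ("Impressionism", "Claude Monet", "<GPT-Neo continuation of the enriched prompt>")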
# Gradio interface
def gradio_interface(image):
    predicted_style, predicted_artist, description = predict(image)
    return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"

iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description."
)

if __name__ == "__main__":
    iface.launch()