import torch
import torch.nn as nn
import pandas as pd
import gradio as gr
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics import classification_report
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Device setup (Apple Silicon GPU via MPS, with CPU fallback)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Data transformation: ImageNet input size and normalization statistics
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load datasets for enriched prompts
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';',
                           usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()

style_desc = pd.read_csv("style_desc.csv", delimiter=';')
style_desc.columns = style_desc.columns.str.lower()

# Enrich the generation prompt with artist and style descriptions from the CSVs
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'] == artist, 'description'].values
    style_info = style_desc.loc[style_desc['style'] == style, 'description'].values
    artist_details = artist_info[0] if len(artist_info) > 0 else "Details about the artist are not available."
    style_details = style_info[0] if len(style_info) > 0 else "Details about the style are not available."
    return f"{artist_details} This work exemplifies {style_details}."

# Custom dataset wrapper for the ResNet-18 classifier
class ArtDataset:
    def __init__(self, csv_file):
        self.annotations = pd.read_csv(csv_file)
        self.train_data = self.annotations[self.annotations['subset'] == 'train']
        self.test_data = self.annotations[self.annotations['subset'] == 'test']
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def get_style_and_artist_mappings(self):
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        return self.train_data, self.test_data

# ResNet-18 backbone with dropout and two classification heads (style and artist)
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        super(DualOutputResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # strip the original ImageNet classifier
        self.dropout = nn.Dropout(dropout_rate)
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output

# Load dataset
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
train_data, test_data = dataset.get_train_test_split()
num_styles = len(label_map_style)
num_artists = len(label_map_artist)

# Model setup; the optimizer and scheduler are exercised by the training sketch below
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

# Load the multilingual CLIP SentenceTransformer (kept for embedding
# experiments; not used in the current pipeline)
clip_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1').to(device)
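# ---------------------------------------------------------------------------
# Hedged sketch: the optimizer and scheduler above are never stepped anywhere
# in this script, and classification_report is imported but unused, so the
# following shows one plausible way to wire them together. This is a sketch
# under assumptions, not the author's training procedure: it ASSUMES
# cleaned_classes.csv has an image-path column named 'path' (hypothetical;
# rename to match the real column) alongside 'genre' and 'artist'.
# ---------------------------------------------------------------------------
from torch.utils.data import DataLoader, Dataset


class ArtImageDataset(Dataset):
    """Wraps a dataframe split as (image_tensor, style_label, artist_label)."""

    def __init__(self, frame, transform, style_map, artist_map):
        self.frame = frame.reset_index(drop=True)
        self.transform = transform
        self.style_map = style_map
        self.artist_map = artist_map

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        row = self.frame.iloc[idx]
        image = Image.open(row['path']).convert("RGB")  # 'path' column is assumed
        return (
            self.transform(image),
            self.style_map[row['genre']],
            self.artist_map[row['artist']],
        )


def train_one_epoch(loader):
    criterion = nn.CrossEntropyLoss()
    model_resnet.train()
    total_loss = 0.0
    for images, style_labels, artist_labels in loader:
        images = images.to(device)
        style_labels = style_labels.to(device)
        artist_labels = artist_labels.to(device)
        optimizer.zero_grad()
        style_out, artist_out = model_resnet(images)
        # Unweighted sum of the two head losses; the weighting is a free choice.
        loss = criterion(style_out, style_labels) + criterion(artist_out, artist_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(loader)
    scheduler.step(avg_loss)  # ReduceLROnPlateau steps on a monitored metric
    return avg_loss


def evaluate_styles(loader):
    """Per-class precision/recall for the style head via classification_report."""
    model_resnet.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, style_labels, _ in loader:
            style_out, _ = model_resnet(images.to(device))
            y_pred.extend(style_out.argmax(dim=1).cpu().tolist())
            y_true.extend(style_labels.tolist())
    print(classification_report(y_true, y_pred))


# Example wiring (not executed here):
#     train_loader = DataLoader(
#         ArtImageDataset(train_data, data_transforms, label_map_style, label_map_artist),
#         batch_size=32, shuffle=True)
#     for epoch in range(10):
#         print(f"epoch {epoch}: loss={train_one_epoch(train_loader):.4f}")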
# Load GPT-Neo and set the padding token
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-Neo has no pad token by default
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)


def generate_description(image):
    # Classify the image with the dual-head ResNet
    image_resnet = data_transforms(image).unsqueeze(0).to(device)
    model_resnet.eval()
    with torch.no_grad():
        outputs_style, outputs_artist = model_resnet(image_resnet)
        _, predicted_style_idx = torch.max(outputs_style, 1)
        _, predicted_artist_idx = torch.max(outputs_artist, 1)

    # Map the predicted indices back to their labels
    idx_to_style = {v: k for k, v in label_map_style.items()}
    idx_to_artist = {v: k for k, v in label_map_artist.items()}
    predicted_style = idx_to_style[predicted_style_idx.item()]
    predicted_artist = idx_to_artist[predicted_artist_idx.item()]

    # Build the enriched prompt and generate a description with GPT-Neo
    enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
    full_prompt = (
        f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
        "Describe its distinctive features, considering both the artist's techniques and the artistic style."
    )

    # tokenizer() returns a proper attention mask; the previous manual
    # comparison against pad_token_id breaks once pad_token == eos_token
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    output = model_gptneo.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=250,  # cap the generated text, not prompt + generation
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.5,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    description_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_style, predicted_artist, description_text


# Gradio interface
def gradio_interface(image):
    if image is None:
        return "No image provided. Please upload an image."
    # gr.Image(type="filepath") passes a file path, so open it directly
    image = Image.open(image).convert("RGB")
    predicted_style, predicted_artist, description = generate_description(image)
    return (
        f"Predicted Style: {predicted_style}\n"
        f"Predicted Artist: {predicted_artist}\n\n"
        f"Description:\n{description}"
    )


iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description.",
)

if __name__ == "__main__":
    iface.launch()
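# Hedged usage sketch: to smoke-test the pipeline without the Gradio UI,
# generate_description can be called directly on a PIL image
# ("sample.jpg" is a placeholder path):
#
#     img = Image.open("sample.jpg").convert("RGB")
#     style, artist, text = generate_description(img)
#     print(style, artist, text, sep="\n")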