import torch
from PIL import Image
from torchvision import transforms, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import open_clip
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gradio as gr
# Device setup
# Prefer CUDA, then Apple MPS, then CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")
# Data transformation
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the pretrained ResNet18 backbone
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Load datasets for enriched prompts
dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
dataset_desc.columns = dataset_desc.columns.str.lower()
style_desc = pd.read_csv("style_desc.csv", delimiter=';') # CSV containing style-specific descriptions
style_desc.columns = style_desc.columns.str.lower()
# Function to enrich prompts with custom data
def enrich_prompt(artist, style):
    artist_info = dataset_desc.loc[dataset_desc['artists'] == artist, 'description'].values
    style_info = style_desc.loc[style_desc['style'] == style, 'description'].values
    artist_details = artist_info[0] if len(artist_info) > 0 else "Details about the artist are not available."
    style_details = style_info[0] if len(style_info) > 0 else "Details about the style are not available."
    return f"{artist_details} This work exemplifies {style_details}."
# Custom dataset for ResNet18
class ArtDataset:
    def __init__(self, csv_file):
        self.annotations = pd.read_csv(csv_file)
        self.train_data = self.annotations[self.annotations['subset'] == 'train']
        self.test_data = self.annotations[self.annotations['subset'] == 'test']
        self.label_map_style = {style: idx for idx, style in enumerate(self.annotations['genre'].unique())}
        self.label_map_artist = {artist: idx for idx, artist in enumerate(self.annotations['artist'].unique())}

    def get_style_and_artist_mappings(self):
        return self.label_map_style, self.label_map_artist

    def get_train_test_split(self):
        return self.train_data, self.test_data
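
# Illustrative sketch, not used by the app: ArtDataset above only holds metadata,
# so feeding images to a DataLoader would need an actual PyTorch Dataset like this.
# The 'path' column name is an assumption about cleaned_classes.csv, not a given.
from torch.utils.data import Dataset

class ArtImageDataset(Dataset):
    def __init__(self, frame, label_map_style, label_map_artist, transform):
        self.frame = frame.reset_index(drop=True)
        self.label_map_style = label_map_style
        self.label_map_artist = label_map_artist
        self.transform = transform

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        row = self.frame.iloc[idx]
        image = Image.open(row['path']).convert("RGB")  # assumes a 'path' column
        return (self.transform(image),
                self.label_map_style[row['genre']],
                self.label_map_artist[row['artist']])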
# DualOutputResNet model with Dropout
class DualOutputResNet(nn.Module):
    def __init__(self, num_styles, num_artists, dropout_rate=0.5):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # drop the ImageNet head; expose pooled features
        self.dropout = nn.Dropout(dropout_rate)
        self.fc_style = nn.Linear(num_features, num_styles)
        self.fc_artist = nn.Linear(num_features, num_artists)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        style_output = self.fc_style(features)
        artist_output = self.fc_artist(features)
        return style_output, artist_output
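
# Quick shape check (illustrative, sizes are arbitrary): one 224x224 RGB image
# yields one logit vector per head, sized by the style/artist label maps:
#   x = torch.randn(1, 3, 224, 224)
#   style_logits, artist_logits = DualOutputResNet(10, 20)(x)
#   # style_logits.shape == (1, 10); artist_logits.shape == (1, 20)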
# Load dataset
csv_file = "cleaned_classes.csv"
dataset = ArtDataset(csv_file)
label_map_style, label_map_artist = dataset.get_style_and_artist_mappings()
train_data, test_data = dataset.get_train_test_split()
num_styles = len(label_map_style)
num_artists = len(label_map_artist)
# Model setup
model_resnet = DualOutputResNet(num_styles, num_artists).to(device)
# Training utilities (no training loop runs in this script; see the sketch below)
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
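
# Illustrative training step, not executed by this app: a sketch of how the
# optimizer and scheduler above would typically be driven. The DataLoader
# yielding (images, style_labels, artist_labels) and the unweighted loss sum
# are assumptions, not part of the original pipeline.
def train_one_epoch(loader, criterion=nn.CrossEntropyLoss()):
    model_resnet.train()
    total_loss = 0.0
    for images, style_labels, artist_labels in loader:
        images = images.to(device)
        style_labels = style_labels.to(device)
        artist_labels = artist_labels.to(device)
        optimizer.zero_grad()
        style_logits, artist_logits = model_resnet(images)
        # One cross-entropy term per head; a weighted sum is a common variant
        loss = criterion(style_logits, style_labels) + criterion(artist_logits, artist_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step(total_loss / max(len(loader), 1))  # ReduceLROnPlateau tracks this metric
    return total_loss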
# Load GPT-Neo and CLIP
# open_clip model names use hyphens ('ViT-B-32'), and create_model_and_transforms
# returns (model, preprocess_train, preprocess_val); 'openai' weights are assumed here
model_clip, _, preprocess_clip = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', device=device)
tokenizer_clip = open_clip.get_tokenizer('ViT-B-32')
model_clip.eval()
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
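# Note: gpt-neo-1.3B takes roughly 5 GB of memory in fp32 (1.3B params x 4 bytes);
# on a CUDA device, passing torch_dtype=torch.float16 to from_pretrained roughly halves that.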
# Predict style/artist with ResNet, then generate a description with GPT-Neo
# (the CLIP model loaded above is not used here; see the sketch after this function)
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image_tensor = data_transforms(image).unsqueeze(0).to(device)

    # Predict with ResNet in inference mode (dropout off, no gradients)
    model_resnet.eval()
    with torch.no_grad():
        style_logits, artist_logits = model_resnet(image_tensor)
    style_idx = torch.argmax(style_logits, dim=1).item()
    artist_idx = torch.argmax(artist_logits, dim=1).item()

    # Invert the label maps to recover class names from indices
    idx_to_style = {idx: style for style, idx in label_map_style.items()}
    idx_to_artist = {idx: artist for artist, idx in label_map_artist.items()}
    predicted_style = idx_to_style[style_idx]
    predicted_artist = idx_to_artist[artist_idx]

    # Enrich the prompt with dataset descriptions, then generate with GPT-Neo
    prompt = enrich_prompt(predicted_artist, predicted_style)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = model_gptneo.generate(
        input_ids,
        max_length=350,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id  # GPT-Neo has no pad token by default
    )
    description = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_style, predicted_artist, description
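
# Illustrative sketch of using the CLIP model loaded above (predict() itself never
# calls it): zero-shot scoring of candidate style names against the image. The
# prompt template is an assumption; this function is defined but not wired in.
def clip_style_scores(image_path, candidate_styles):
    image = preprocess_clip(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
    text = tokenizer_clip([f"a painting in the style of {s}" for s in candidate_styles]).to(device)
    with torch.no_grad():
        image_features = model_clip.encode_image(image)
        text_features = model_clip.encode_text(text)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    return dict(zip(candidate_styles, similarity.squeeze(0).tolist()))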
# Gradio interface
def gradio_interface(image):
    predicted_style, predicted_artist, description = predict(image)
    return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="AI Artwork Analysis",
    description="Upload an image to predict its artistic style and creator, and generate a detailed description."
)
if __name__ == "__main__":
    iface.launch()