#!/usr/bin/env python
# coding: utf-8

# In[1]:
import warnings

import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertForSequenceClassification, AutoModelForSeq2SeqLM
from torchvision import models
from torchvision.transforms import v2

# In[2]:
warnings.filterwarnings("ignore")

# In[3]:
genres = ["Crime", "Thriller", "Fantasy", "Horror", "Sci-Fi", "Comedy", "Documentary", "Adventure", "Film-Noir", "Animation", "Romance", "Drama", "Western", "Musical", "Action", "Mystery", "War", "Children's"]
# Map label indices to genre names for the classification heads.
mapping = dict(enumerate(genres))
mapping
# In[4]:
# Plot generator: LaMini-Flan-T5 answers "What is the story of the movie X?".
tokenizer_gen = AutoTokenizer.from_pretrained("MBZUAI/LaMini-Flan-T5-248M")
model_gen = AutoModelForSeq2SeqLM.from_pretrained("MBZUAI/LaMini-Flan-T5-248M")

# Title branch: DistilBERT with an 18-way multi-label head.
tokenizer1 = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model1 = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", problem_type="multi_label_classification", num_labels=18)
model1.config.id2label = mapping

# Plot branch: a DistilBERT fine-tuned for plot classification.
tokenizer2 = AutoTokenizer.from_pretrained("dduy193/plot-classification")
model2 = AutoModelForSequenceClassification.from_pretrained("dduy193/plot-classification")
model2.config.id2label = mapping

# Poster branch: ResNet-101 with its classifier replaced by an 18-way head
# (`pretrained=` is deprecated in torchvision; `weights=None` is the current spelling).
model3 = models.resnet101(weights=None)
model3.fc = torch.nn.Linear(2048, len(genres))
# In[5]:
class Multimodal(torch.nn.Module):
    """Late-fusion model: per-branch genre logits are projected and summed."""

    def __init__(self, model1, model2, model3):
        super().__init__()
        self.model1 = model1  # title encoder
        self.model2 = model2  # plot encoder
        self.model3 = model3  # poster CNN
        # One linear layer per branch, mapping 18 branch logits to 18 fused logits.
        self.fc1 = torch.nn.Linear(18, 18)
        self.fc2 = torch.nn.Linear(18, 18)
        self.fc3 = torch.nn.Linear(18, 18)

    def forward(self,
                title_input_ids, title_attention_mask,
                plot_input_ids, plot_attention_mask,
                image_input):
        title_output = self.model1(title_input_ids, title_attention_mask)
        plot_output = self.model2(plot_input_ids, plot_attention_mask)
        image_output = self.model3(image_input)
        title_output = self.fc1(title_output.logits)
        plot_output = self.fc2(plot_output.logits)
        image_output = self.fc3(image_output)
        # Element-wise sum fuses the three modality scores.
        return title_output + plot_output + image_output
# **_PLEASE INSTALL THE MODEL CHECKPOINT FROM THE LINK IN README.txt_**

# In[7]:
model = Multimodal(model1, model2, model3)
model.load_state_dict(torch.load('multimodel.pt', map_location=torch.device('cpu')))
model.eval()
device = torch.device('cpu')
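
# In[ ]:
# Optional sanity check (a minimal sketch, not part of the original app): push
# one dummy batch through the fused model and confirm it yields one logit per
# genre. The all-ones token ids are placeholders, not meaningful text.
with torch.no_grad():
    _ids = torch.ones((1, 8), dtype=torch.long)
    _mask = torch.ones((1, 8), dtype=torch.long)
    _img = torch.zeros((1, 3, 224, 224))
    assert model(_ids, _mask, _ids, _mask, _img).shape == (1, 18)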
# In[8]:
def generate_plot(title: str, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, device) -> str:
    """Generate a short plot synopsis for `title` with the seq2seq model."""
    prompt = 'What is the story of the movie {}?'
    model.to(device)
    model.eval()
    input_ids = tokenizer(prompt.format(title), return_tensors='pt').input_ids.to(device)
    # Very low temperature keeps sampling close to greedy decoding.
    output = model.generate(input_ids, max_length=256, do_sample=True, temperature=0.09)
    return tokenizer.decode(output[0], skip_special_tokens=True)
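
# In[ ]:
# Quick demo of the plot generator (a sketch, not part of the original app;
# the output is sampled, so the exact text varies between runs):
print(generate_plot("The Matrix", model_gen, tokenizer_gen, device))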
# In[9]:
def inference(title, image,
              tokenizer1=tokenizer1, tokenizer2=tokenizer2, tokenizer_gen=tokenizer_gen,
              model_gen=model_gen, model=model,
              genres=genres, device=device):
    # Tokenize the title for the DistilBERT title branch.
    title_input = tokenizer1(title, return_tensors='pt', padding=True, truncation=True)
    title_input_ids = title_input['input_ids'].to(device)
    title_attention_mask = title_input['attention_mask'].to(device)
    # Generate a synopsis from the title, then tokenize it for the plot branch.
    plot = generate_plot(title, model_gen, tokenizer_gen, device)
    plot_input = tokenizer2(plot, return_tensors='pt', padding=True, truncation=True)
    plot_input_ids = plot_input['input_ids'].to(device)
    plot_attention_mask = plot_input['attention_mask'].to(device)
    # If no poster is uploaded, feed an all-zero image instead.
    if image is None:
        image_input = torch.zeros((1, 3, 224, 224)).to(device)
    else:
        image_input = image.resize((224, 224))
        # v2.ToTensor() is deprecated in newer torchvision; ToImage() + ToDtype
        # is the modern replacement, but the behavior here is the same.
        image_input = v2.ToTensor()(image_input)
        image_input = image_input.unsqueeze(0)
        image_input = image_input.to(device)
    output = model(title_input_ids, title_attention_mask, plot_input_ids, plot_attention_mask, image_input)
    # Multi-label prediction: sigmoid per genre, then a 0.5 threshold.
    output = torch.sigmoid(output)
    output = output.cpu().detach().numpy()
    output = np.where(output > 0.5, 1, 0)
    output = output.squeeze()
    output = np.where(output == 1)[0]
    return [genres[i] for i in output]
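
# In[ ]:
# Example call (a sketch, not part of the original app; passing image=None
# exercises the zero-poster fallback, so the prediction relies on the title
# and the generated plot alone):
print(inference("The Matrix", None))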
# In[10]:
app = gr.Interface(fn=inference,
                   inputs=[gr.Textbox(label="Title"), gr.Image(type="pil")],
                   outputs="text",
                   title="Movie Genre Classification",
                   description="This model classifies the genre of a movie based on its title and poster.",
                   examples=[["The Matrix", "https://upload.wikimedia.org/wikipedia/en/c/c1/The_Matrix_Poster.jpg"],
                             ["The Dark Knight", "https://upload.wikimedia.org/wikipedia/en/1/1c/The_Dark_Knight_%282008_film%29.jpg"],
                             ["The Godfather", "https://upload.wikimedia.org/wikipedia/en/1/1c/Godfather_ver1.jpg"],
                             ["The Shawshank Redemption", "https://upload.wikimedia.org/wikipedia/en/8/81/ShawshankRedemptionMoviePoster.jpg"],
                             ["The Lord of the Rings: The Return of the King", "https://upload.wikimedia.org/wikipedia/en/2/23/The_Lord_of_the_Rings%2C_TROTK_%282003%29.jpg"],
                             ["The Godfather: Part II", "https://upload.wikimedia.org/wikipedia/en/0/03/Godfather_part_ii.jpg"]])
# In[11]:
app.launch(share=True)